From 04c64793501e2147f470028305aebadb342778e5 Mon Sep 17 00:00:00 2001
From: Umang Yadav <29876643+umangyadav@users.noreply.github.com>
Date: Fri, 14 Jun 2024 12:11:18 -0400
Subject: [PATCH] [ONNX] Add onnx parser for LpPool operator (#3449)

Lowers onnx.LpPool onto torch.aten.avg_pool{1,2,3}d by computing
pow(avg_pool(pow(abs(x), p)) * N, 1/p), where N is the number of
elements in the pooling window.

Similar to https://github.com/llvm/torch-mlir/pull/3435
Solves https://github.com/nod-ai/SHARK-Turbine/issues/728
---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 116 ++++++++++++++++++
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   |  56 +++++++++
 2 files changed, 172 insertions(+)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 555f7f650..87afc46bd 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1687,6 +1687,122 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         return success();
       });
 
+  patterns.onOp(
+      "LpPool", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        std::string autoPad;
+        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
+          return failure();
+        if (autoPad != "NOTSET") {
+          // TODO: Add support for `auto_pad` != "NOTSET"
+          return rewriter.notifyMatchFailure(
+              binder.op, "unsupported conversion: auto_pad != NOTSET");
+        }
+
+        Torch::ValueTensorType resultType;
+        Value operand;
+        int64_t ceilMode, p;
+        if (binder.tensorOperand(operand) ||
+            binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) ||
+            binder.s64IntegerAttr(p, "p", 2) ||
+            binder.tensorResultType(resultType))
+          return failure();
+        // Determine the rank of the input tensor.
+        std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
+        if (!maybeRank)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Unimplemented: unranked tensor");
+        unsigned rank = *maybeRank;
+        // Only 1D, 2D, and 3D LpPool are supported.
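+        // The input is laid out as (N, C, spatial dims), so ranks 3, 4, and
+        // 5 correspond to 1D, 2D, and 3D pooling respectively.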
+        if (rank > 5 or rank < 3) {
+          return failure();
+        }
+
+        SmallVector<int64_t> kernel, padding, strides, dilations;
+        SmallVector<int64_t> defaultPadding(2 * (rank - 2), 0);
+        if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}) ||
+            binder.s64IntegerArrayAttr(padding, "pads", defaultPadding) ||
+            binder.s64IntegerArrayAttr(
+                strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1)) ||
+            binder.s64IntegerArrayAttr(dilations, "dilations", {})) {
+          return failure();
+        }
+        if (kernel.size() != rank - 2) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "kernel list size does not match the number of axes");
+        }
+        if (padding.size() != 2 * (rank - 2)) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "padding list size does not match twice the number of axes");
+        }
+        if (strides.size() != rank - 2) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "strides list size does not match the number of axes");
+        }
+        if (dilations.size() > 0) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "dilation is not supported by torch.aten.avgpool op "
+                         "and therefore it is not supported for LpPool.");
+        }
+
+        SmallVector<Value> cstKernel, cstPadding, cstStrides;
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
+        Value numElements = cstOne;
+        for (int64_t i : kernel) {
+          cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+          numElements = rewriter.create<Torch::AtenMulOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              cstKernel.back(), numElements);
+        }
+        Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstKernel);
+        Value paddingList = createConstantIntList(binder, rewriter, padding);
+        Value stridesList = createConstantIntList(binder, rewriter, strides);
+        Value cstCeilMode =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
+        // ONNX LpPool doesn't have a countIncludePad attribute, but set it
+        // to true so that in the 1D case the multiplication by numElements
+        // correctly undoes the division. For the 2D/3D case, the division is
+        // avoided via divisor_override.
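+        // For example, with kernel_shape = [2, 2] (numElements = 4),
+        // avg_pool2d with divisor_override = 1 yields the raw window sum of
+        // |x|^p instead of its mean.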
+        Value cstCountIncludePad =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
+        Value pv = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), p));
+        auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
+        Value abs = rewriter.create<Torch::AtenAbsOp>(binder.getLoc(),
+                                                      inputTensorType, operand);
+        Value pow = rewriter.create<Torch::AtenPowTensorScalarOp>(
+            binder.getLoc(), inputTensorType, abs, pv);
+        Value avgPool;
+        if (rank == 3) {
+          avgPool = rewriter.create<Torch::AtenAvgPool1dOp>(
+              binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
+              paddingList, cstCeilMode, cstCountIncludePad);
+          avgPool = rewriter.create<Torch::AtenMulScalarOp>(
+              binder.getLoc(), resultType, avgPool, numElements);
+        } else if (rank == 4) {
+          avgPool = rewriter.create<Torch::AtenAvgPool2dOp>(
+              binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
+              paddingList, cstCeilMode, cstCountIncludePad,
+              /*divisor_override=*/cstOne);
+        } else { // rank == 5
+          avgPool = rewriter.create<Torch::AtenAvgPool3dOp>(
+              binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
+              paddingList, cstCeilMode, cstCountIncludePad,
+              /*divisor_override=*/cstOne);
+        }
+        Value invP = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(double{1.0 / p}));
+        rewriter.replaceOpWithNewOp<Torch::AtenPowTensorScalarOp>(
+            binder.op, resultType, avgPool, invP);
+        return success();
+      });
+
   patterns.onOp(
       "LayerNormalization", 17,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index c1fff157b..fc79f88b1 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -274,6 +274,62 @@ func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.
 
 // -----
 
+// CHECK-LABEL: func.func @test_lppool_2d
+func.func @test_lppool_2d(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64} {
+  // CHECK: %[[I1:.*]] = torch.constant.int 1
+  // CHECK: %[[I2:.*]] = torch.constant.int 2
+  // CHECK: %[[NE:.*]] = torch.aten.mul %[[I2]], %[[I1]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[I2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[NE1:.*]] = torch.aten.mul %[[I2_1]], %[[NE]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[K:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_2:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_3:.*]] = torch.constant.int 0
+  // CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_1]], %[[I0_2]], %[[I0_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[STR:.*]] = torch.prim.ListConstruct %[[I1_1]], %[[I1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[CEIL:.*]] = torch.constant.bool false
+  // CHECK: %[[CIP:.*]] = torch.constant.bool true
+  // CHECK: %[[P:.*]] = torch.constant.int 2
+  // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,32,32],f32> -> !torch.vtensor<[1,3,32,32],f32>
+  // CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[P]] : !torch.vtensor<[1,3,32,32],f32>, !torch.int -> !torch.vtensor<[1,3,32,32],f32>
+  // CHECK: %[[AVG:.*]] = torch.aten.avg_pool2d %[[POW]], %[[K]], %[[STR]], %[[PAD]], %[[CEIL]], %[[CIP]], %[[I1]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,3,31,31],f32>
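+  // Passing %[[I1]] (the constant 1) as divisor_override makes avg_pool2d
+  // compute the plain window sum of |x|^p.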
+  // CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01
+  // CHECK: torch.aten.pow.Tensor_Scalar %[[AVG]], %[[INVP]] : !torch.vtensor<[1,3,31,31],f32>, !torch.float -> !torch.vtensor<[1,3,31,31],f32>
+  %0 = torch.operator "onnx.LpPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32>
+  return %0 : !torch.vtensor<[1,3,31,31],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lppool_1d
+func.func @test_lppool_1d(%arg0: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64} {
+  // CHECK: %[[I1:.*]] = torch.constant.int 1
+  // CHECK: %[[I2:.*]] = torch.constant.int 2
+  // CHECK: %[[NE:.*]] = torch.aten.mul %[[I2]], %[[I1]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[K:.*]] = torch.prim.ListConstruct %[[I2]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[STR:.*]] = torch.prim.ListConstruct %[[I1_1]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[CEIL:.*]] = torch.constant.bool false
+  // CHECK: %[[CIP:.*]] = torch.constant.bool true
+  // CHECK: %[[P:.*]] = torch.constant.int 2
+  // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,32],f32> -> !torch.vtensor<[1,3,32],f32>
+  // CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[P]] : !torch.vtensor<[1,3,32],f32>, !torch.int -> !torch.vtensor<[1,3,32],f32>
+  // CHECK: %[[AVG:.*]] = torch.aten.avg_pool1d %[[POW]], %[[K]], %[[STR]], %[[PAD]], %[[CEIL]], %[[CIP]] : !torch.vtensor<[1,3,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,3,31],f32>
+  // CHECK: %[[POW_0:.*]] = torch.aten.mul.Scalar %[[AVG]], %[[NE]] : !torch.vtensor<[1,3,31],f32>, !torch.int -> !torch.vtensor<[1,3,31],f32>
+  // CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01
+  // CHECK: torch.aten.pow.Tensor_Scalar %[[POW_0]], %[[INVP]] : !torch.vtensor<[1,3,31],f32>, !torch.float -> !torch.vtensor<[1,3,31],f32>
+  %0 = torch.operator "onnx.LpPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32>
+  return %0 : !torch.vtensor<[1,3,31],f32>
+}
+
+// -----
+
 // CHECK-LABEL : func.func @test_layer_norm
 func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[3,4],f32>, %arg2: !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4], f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {