From 59bade337659d5dab541381252636fbe763cf8d7 Mon Sep 17 00:00:00 2001
From: Umang Yadav <29876643+umangyadav@users.noreply.github.com>
Date: Mon, 17 Jun 2024 01:47:16 -0400
Subject: [PATCH] [ONNX] Add missing "Abs" in GlobalLpPool (#3460)

Taking `abs` is required to mimic the same logic as onnx/onnxruntime.
Without `abs`, the lowering would not produce correct results for
negative input values.

Reference code:
https://github.com/microsoft/onnxruntime/blob/f5b6f6dc26a55ddf7523d832ac5dc56930225264/onnxruntime/core/providers/cpu/nn/pool_functors.h#L604
https://github.com/onnx/onnx/blob/375c161c67855fea9612c15b83ebff40fca838a4/onnx/reference/ops/op_lp_pool.py#L31
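For illustration, a minimal NumPy sketch of the GlobalLpPool semantics that
the two references above implement; the helper name `global_lp_pool` and the
sample values are invented for this example and come from neither codebase:

```python
import numpy as np

def global_lp_pool(x: np.ndarray, p: int = 2) -> np.ndarray:
    # Reduce over every spatial dim of an (N, C, D1, ..., Dk) tensor:
    # y = (sum of |x|^p) ** (1/p), kept per batch and channel.
    spatial_axes = tuple(range(2, x.ndim))
    summed = np.sum(np.abs(x) ** p, axis=spatial_axes, keepdims=True)
    return summed ** (1.0 / p)

x = np.array([[[-1.0, -2.0, 3.0]]])  # shape (1, 1, 3)
print(global_lp_pool(x, p=1))        # [[[6.]]], i.e. |-1| + |-2| + |3|
# Without abs the terms cancel for odd p: the sum is 0.0 instead of 6.0,
# and a fractional 1/p power of a negative sum yields NaN in NumPy.
print(np.sum(x))                     # 0.0 -- the wrong basis for the root
```

The lowering below performs the same computation as
abs -> pow(p) -> avg_pool (with an effective divisor of 1) -> pow(1/p).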
"backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_globallppool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 2 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 // CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C5:.*]] = torch.constant.int 5 @@ -1043,8 +1043,9 @@ func.func @test_globallppool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vte // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,5,5],f32> -> !torch.vtensor<[1,3,5,5],f32> // CHECK: %[[CP:.*]] = torch.constant.int 2 - // CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %arg0, %[[CP]] : !torch.vtensor<[1,3,5,5],f32>, !torch.int -> !torch.vtensor<[1,3,5,5],f32> + // CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[CP]] : !torch.vtensor<[1,3,5,5],f32>, !torch.int -> !torch.vtensor<[1,3,5,5],f32> // CHECK: %[[AVGPOOL:.*]] = torch.aten.avg_pool2d %[[POW1]], %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]], %[[C1]] : !torch.vtensor<[1,3,5,5],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,3,1,1],f32> // CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01 // CHECK: torch.aten.pow.Tensor_Scalar %[[AVGPOOL]], %[[INVP]] : !torch.vtensor<[1,3,1,1],f32>, !torch.float -> !torch.vtensor<[1,3,1,1],f32> @@ -1055,7 +1056,7 @@ func.func @test_globallppool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vte // ----- // CHECK-LABEL: @test_globallppool_1d -func.func @test_globallppool_1d(%arg0: !torch.vtensor<[1,3,5],f32>) -> !torch.vtensor<[1,3,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_globallppool_1d(%arg0: !torch.vtensor<[1,3,5],f32>) -> !torch.vtensor<[1,3,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 2 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 // CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C5:.*]] = torch.constant.int 5 @@ -1064,8 +1065,9 @@ func.func @test_globallppool_1d(%arg0: !torch.vtensor<[1,3,5],f32>) -> !torch.vt // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]] : (!torch.int) -> !torch.list // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,5],f32> -> !torch.vtensor<[1,3,5],f32> // CHECK: %[[CP:.*]] = torch.constant.int 2 - // CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %arg0, %[[CP]] : !torch.vtensor<[1,3,5],f32>, !torch.int -> !torch.vtensor<[1,3,5],f32> + // CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[CP]] : !torch.vtensor<[1,3,5],f32>, !torch.int -> !torch.vtensor<[1,3,5],f32> // CHECK: %[[AVGPOOL:.*]] = torch.aten.avg_pool1d %[[POW1]], %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]] : 
   // CHECK: %[[MUL:.*]] = torch.aten.mul.Scalar %[[AVGPOOL]], %[[E1]] : !torch.vtensor<[1,3,1],f32>, !torch.int -> !torch.vtensor<[1,3,1],f32>
   // CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01