[ONNX] Add missing "Abs" in GlobalLpPool (#3460)

Taking `abs` is required to mimic the same logic as onnx/onnxruntime.
Without `abs`, GlobalLpPool produces incorrect results for negative input values.
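
As a quick numeric check (a minimal NumPy sketch, not part of this change; the input values and `p = 3` are made up for illustration), GlobalLpPool computes `y = (sum |x|^p)^(1/p)` over the spatial dims, and dropping the `abs` visibly breaks odd `p`:

```python
import numpy as np

# GlobalLpPool over the spatial axes: y = (sum |x|^p) ** (1/p).
x = np.array([[[[-3.0, 1.0], [2.0, -2.0]]]])  # NCHW, single channel
p = 3

with_abs = np.sum(np.abs(x) ** p, axis=(2, 3)) ** (1.0 / p)
without_abs = np.sum(x ** p, axis=(2, 3)) ** (1.0 / p)

print(with_abs)     # [[3.5303...]] -- matches the onnx/onnxruntime reference
print(without_abs)  # [[nan]] -- sum is (-27 + 1 + 8 - 8) = -26, negative base
```

For the default `p = 2` the two happen to agree, since squaring already discards the sign; the bug only shows up for odd `p`.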

Reference code:

f5b6f6dc26/onnxruntime/core/providers/cpu/nn/pool_functors.h (L604)

375c161c67/onnx/reference/ops/op_lp_pool.py (L31)
Umang Yadav 2024-06-17 01:47:16 -04:00 committed by GitHub
parent 4555629246
commit 59bade3376
2 changed files with 15 additions and 9 deletions

@@ -1631,7 +1631,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         return failure();
       });
   patterns.onOp(
-      "GlobalLpPool", 1,
+      "GlobalLpPool", 2,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value operand;
@@ -1647,6 +1647,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         }
         ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
         unsigned inputRank = inputShape.size();
+        // Only handle rank-3, -4, and -5 inputs (1-D, 2-D, and 3-D pooling).
+        if (inputRank < 3 || inputRank > 5) {
+          return failure();
+        }
         if (!resultType || !resultType.hasSizes()) {
           return rewriter.notifyMatchFailure(
               binder.op, "Expected result type having sizes");
@@ -1693,11 +1697,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
         Value cstCeilMode = cstFalse;
         Value cstCountIncludePad = cstFalse;
+        Value abs = rewriter.create<Torch::AtenAbsOp>(binder.getLoc(),
+                                                      inputTensorType, operand);
         Value pv = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getType<Torch::IntType>(),
             rewriter.getIntegerAttr(rewriter.getIntegerType(64), p));
         Value pow = rewriter.create<Torch::AtenPowTensorScalarOp>(
-            binder.getLoc(), inputTensorType, operand, pv);
+            binder.getLoc(), inputTensorType, abs, pv);
         Value avgPool;
         if (inputRank == 3) {
           avgPool = rewriter.create<Torch::AtenAvgPool1dOp>(
@@ -1710,13 +1716,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
               paddingList, cstCeilMode, cstCountIncludePad,
               /*divisor_override=*/cstOne);
-        } else if (inputRank == 5) {
+        } else { // inputRank == 5
           avgPool = rewriter.create<Torch::AtenAvgPool3dOp>(
               binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
               paddingList, cstCeilMode, cstCountIncludePad,
               /*divisor_override=*/cstOne);
-        } else {
-          return failure();
         }
         Value invP = rewriter.create<Torch::ConstantFloatOp>(
             binder.getLoc(), rewriter.getType<Torch::FloatType>(),
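
For intuition, the lowered op sequence is abs → pow(p) → sum-pool → pow(1/p). Below is a rough PyTorch mirror of the 2-D path (illustrative only; `global_lp_pool2d` is a hypothetical helper, and `divisor_override=1` turns average pooling into sum pooling, matching the `/*divisor_override=*/cstOne` above):

```python
import torch
import torch.nn.functional as F

def global_lp_pool2d(x: torch.Tensor, p: int = 2) -> torch.Tensor:
    """Mirror of the lowering: (sum over H,W of |x|^p) ** (1/p), NCHW -> NC11."""
    kh, kw = x.shape[-2], x.shape[-1]
    powed = x.abs() ** p                        # aten.abs + aten.pow.Tensor_Scalar
    summed = F.avg_pool2d(powed, kernel_size=(kh, kw),
                          divisor_override=1)   # divisor 1 => sum pooling
    return summed ** (1.0 / p)                  # final pow with exponent 1/p

# Sanity check against a direct computation of the definition.
x = torch.randn(1, 3, 5, 5)
ref = x.abs().pow(2).sum(dim=(-2, -1), keepdim=True).pow(0.5)
assert torch.allclose(global_lp_pool2d(x, p=2), ref, atol=1e-6)
```

The 1-D path differs slightly: `aten.avg_pool1d` has no `divisor_override` argument, which is why the lowering multiplies the averaged result back by the kernel size (the `%[[MUL]]`/`%[[E1]]` lines in the 1-D test below).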

@@ -1032,7 +1032,7 @@ func.func @test_globalmaxpool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f32>)
 // -----
 // CHECK-LABEL: @test_globallppool
-func.func @test_globallppool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+func.func @test_globallppool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 2 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[C0:.*]] = torch.constant.int 0
   // CHECK: %[[C1:.*]] = torch.constant.int 1
   // CHECK: %[[C5:.*]] = torch.constant.int 5
@@ -1043,8 +1043,9 @@ func.func @test_globallppool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vte
   // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,5,5],f32> -> !torch.vtensor<[1,3,5,5],f32>
   // CHECK: %[[CP:.*]] = torch.constant.int 2
-  // CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %arg0, %[[CP]] : !torch.vtensor<[1,3,5,5],f32>, !torch.int -> !torch.vtensor<[1,3,5,5],f32>
+  // CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[CP]] : !torch.vtensor<[1,3,5,5],f32>, !torch.int -> !torch.vtensor<[1,3,5,5],f32>
   // CHECK: %[[AVGPOOL:.*]] = torch.aten.avg_pool2d %[[POW1]], %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]], %[[C1]] : !torch.vtensor<[1,3,5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,3,1,1],f32>
   // CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01
   // CHECK: torch.aten.pow.Tensor_Scalar %[[AVGPOOL]], %[[INVP]] : !torch.vtensor<[1,3,1,1],f32>, !torch.float -> !torch.vtensor<[1,3,1,1],f32>
@@ -1055,7 +1056,7 @@ func.func @test_globallppool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vte
 // -----
 // CHECK-LABEL: @test_globallppool_1d
-func.func @test_globallppool_1d(%arg0: !torch.vtensor<[1,3,5],f32>) -> !torch.vtensor<[1,3,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+func.func @test_globallppool_1d(%arg0: !torch.vtensor<[1,3,5],f32>) -> !torch.vtensor<[1,3,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 2 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[C0:.*]] = torch.constant.int 0
   // CHECK: %[[C1:.*]] = torch.constant.int 1
   // CHECK: %[[C5:.*]] = torch.constant.int 5
@@ -1064,8 +1065,9 @@ func.func @test_globallppool_1d(%arg0: !torch.vtensor<[1,3,5],f32>) -> !torch.vt
   // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]] : (!torch.int) -> !torch.list<int>
   // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]] : (!torch.int) -> !torch.list<int>
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,5],f32> -> !torch.vtensor<[1,3,5],f32>
   // CHECK: %[[CP:.*]] = torch.constant.int 2
-  // CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %arg0, %[[CP]] : !torch.vtensor<[1,3,5],f32>, !torch.int -> !torch.vtensor<[1,3,5],f32>
+  // CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[CP]] : !torch.vtensor<[1,3,5],f32>, !torch.int -> !torch.vtensor<[1,3,5],f32>
   // CHECK: %[[AVGPOOL:.*]] = torch.aten.avg_pool1d %[[POW1]], %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]] : !torch.vtensor<[1,3,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,3,1],f32>
   // CHECK: %[[MUL:.*]] = torch.aten.mul.Scalar %[[AVGPOOL]], %[[E1]] : !torch.vtensor<[1,3,1],f32>, !torch.int -> !torch.vtensor<[1,3,1],f32>
   // CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01