mirror of https://github.com/llvm/torch-mlir
[ONNX] Add missing "Abs" in GlobalLpPool (#3460)
Taking `abs` is required to mimic same logic as onnx/onnxruntime. Without `abs`, it wouldn't produce correct results for negative values. Reference code :pull/3461/mergef5b6f6dc26/onnxruntime/core/providers/cpu/nn/pool_functors.h (L604)
375c161c67/onnx/reference/ops/op_lp_pool.py (L31)
parent
4555629246
commit
59bade3376
|
@ -1631,7 +1631,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
return failure();
|
||||
});
|
||||
patterns.onOp(
|
||||
"GlobalLpPool", 1,
|
||||
"GlobalLpPool", 2,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value operand;
|
||||
|
@ -1647,6 +1647,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
}
|
||||
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
||||
unsigned inputRank = inputShape.size();
|
||||
// only handle 2D, 3D and 5D pooling cases
|
||||
if (inputRank > 5 or inputRank < 3) {
|
||||
return failure();
|
||||
}
|
||||
if (!resultType || !resultType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected result type having sizes");
|
||||
|
@ -1693,11 +1697,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||
Value cstCeilMode = cstFalse;
|
||||
Value cstCountIncludePad = cstFalse;
|
||||
Value abs = rewriter.create<Torch::AtenAbsOp>(binder.getLoc(),
|
||||
inputTensorType, operand);
|
||||
Value pv = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), p));
|
||||
Value pow = rewriter.create<Torch::AtenPowTensorScalarOp>(
|
||||
binder.getLoc(), inputTensorType, operand, pv);
|
||||
binder.getLoc(), inputTensorType, abs, pv);
|
||||
Value avgPool;
|
||||
if (inputRank == 3) {
|
||||
avgPool = rewriter.create<Torch::AtenAvgPool1dOp>(
|
||||
|
@ -1710,13 +1716,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
|
||||
paddingList, cstCeilMode, cstCountIncludePad,
|
||||
/*divisor_override=*/cstOne);
|
||||
} else if (inputRank == 5) {
|
||||
} else { // inputRank == 5
|
||||
avgPool = rewriter.create<Torch::AtenAvgPool3dOp>(
|
||||
binder.getLoc(), resultType, pow, kernelSizeList, stridesList,
|
||||
paddingList, cstCeilMode, cstCountIncludePad,
|
||||
/*divisor_override=*/cstOne);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
Value invP = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
|
|
|
@ -1032,7 +1032,7 @@ func.func @test_globalmaxpool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f32>)
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_globallppool
|
||||
func.func @test_globallppool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
func.func @test_globallppool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 2 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C5:.*]] = torch.constant.int 5
|
||||
|
@ -1043,8 +1043,9 @@ func.func @test_globallppool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vte
|
|||
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,5,5],f32> -> !torch.vtensor<[1,3,5,5],f32>
|
||||
// CHECK: %[[CP:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %arg0, %[[CP]] : !torch.vtensor<[1,3,5,5],f32>, !torch.int -> !torch.vtensor<[1,3,5,5],f32>
|
||||
// CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[CP]] : !torch.vtensor<[1,3,5,5],f32>, !torch.int -> !torch.vtensor<[1,3,5,5],f32>
|
||||
// CHECK: %[[AVGPOOL:.*]] = torch.aten.avg_pool2d %[[POW1]], %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]], %[[C1]] : !torch.vtensor<[1,3,5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,3,1,1],f32>
|
||||
// CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01
|
||||
// CHECK: torch.aten.pow.Tensor_Scalar %[[AVGPOOL]], %[[INVP]] : !torch.vtensor<[1,3,1,1],f32>, !torch.float -> !torch.vtensor<[1,3,1,1],f32>
|
||||
|
@ -1055,7 +1056,7 @@ func.func @test_globallppool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vte
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_globallppool_1d
|
||||
func.func @test_globallppool_1d(%arg0: !torch.vtensor<[1,3,5],f32>) -> !torch.vtensor<[1,3,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
func.func @test_globallppool_1d(%arg0: !torch.vtensor<[1,3,5],f32>) -> !torch.vtensor<[1,3,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 2 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C5:.*]] = torch.constant.int 5
|
||||
|
@ -1064,8 +1065,9 @@ func.func @test_globallppool_1d(%arg0: !torch.vtensor<[1,3,5],f32>) -> !torch.vt
|
|||
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,5],f32> -> !torch.vtensor<[1,3,5],f32>
|
||||
// CHECK: %[[CP:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %arg0, %[[CP]] : !torch.vtensor<[1,3,5],f32>, !torch.int -> !torch.vtensor<[1,3,5],f32>
|
||||
// CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[CP]] : !torch.vtensor<[1,3,5],f32>, !torch.int -> !torch.vtensor<[1,3,5],f32>
|
||||
// CHECK: %[[AVGPOOL:.*]] = torch.aten.avg_pool1d %[[POW1]], %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]] : !torch.vtensor<[1,3,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,3,1],f32>
|
||||
// CHECK: %[[MUL:.*]] = torch.aten.mul.Scalar %[[AVGPOOL]], %[[E1]] : !torch.vtensor<[1,3,1],f32>, !torch.int -> !torch.vtensor<[1,3,1],f32>
|
||||
// CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01
|
||||
|
|
Loading…
Reference in New Issue