mirror of https://github.com/llvm/torch-mlir
[ONNX][MLIR] Fix padding size constraint for onnx.maxpool op (#2782)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/2792/head
parent
d452c4f4c0
commit
b7a0329676
|
@ -201,9 +201,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
binder.op, "kernel list size does not match the number of axes");
|
binder.op, "kernel list size does not match the number of axes");
|
||||||
if (binder.s64IntegerArrayAttr(padding, "pads", {0}))
|
if (binder.s64IntegerArrayAttr(padding, "pads", {0}))
|
||||||
return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
|
return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
|
||||||
if (padding.size() != 1 && padding.size() != rank - 2)
|
if (padding.size() != 1 && padding.size() != 2 * (rank - 2))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "padding list size does not match the number of axes");
|
binder.op, "padding list must contain (begin,end) pair for each "
|
||||||
|
"spatial axis");
|
||||||
if (binder.s64IntegerArrayAttr(strides, "strides", {1}))
|
if (binder.s64IntegerArrayAttr(strides, "strides", {1}))
|
||||||
return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
|
return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
|
||||||
if (strides.size() != 1 && strides.size() != rank - 2)
|
if (strides.size() != 1 && strides.size() != rank - 2)
|
||||||
|
|
|
@ -274,6 +274,29 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) ->
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_maxpool_pad
|
||||||
|
func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
|
||||||
|
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT3_0:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[INT1_1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[INT1_2:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[INT1_3:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_1]], %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[INT2_4:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[EMPTY_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %arg0, %[[LIST]], %[[LIST3]], %[[LIST2]], %[[EMPTY_LIST]], %[[FALSE]] : !torch.vtensor<[1,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56,56],f32>
|
||||||
|
// CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32>
|
||||||
|
%0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,64,56,56],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @test_gelu_default_1
|
// CHECK-LABEL: @test_gelu_default_1
|
||||||
func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: %[[STR1:.*]] = torch.constant.str "none"
|
// CHECK: %[[STR1:.*]] = torch.constant.str "none"
|
||||||
|
|
Loading…
Reference in New Issue