[ONNX][MLIR] Fix padding size constraint for onnx.maxpool op (#2782)

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/2792/head
Gaurav Shukla 2024-01-23 19:23:01 +05:30 committed by GitHub
parent d452c4f4c0
commit b7a0329676
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 2 deletions

View File

@@ -201,9 +201,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, "kernel list size does not match the number of axes");
if (binder.s64IntegerArrayAttr(padding, "pads", {0}))
return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
if (padding.size() != 1 && padding.size() != rank - 2)
if (padding.size() != 1 && padding.size() != 2 * (rank - 2))
return rewriter.notifyMatchFailure(
binder.op, "padding list size does not match the number of axes");
binder.op, "padding list must contain (begin,end) pair for each "
"spatial axis");
if (binder.s64IntegerArrayAttr(strides, "strides", {1}))
return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
if (strides.size() != 1 && strides.size() != rank - 2)

View File

@@ -274,6 +274,29 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) ->
// -----
// CHECK-LABEL: func.func @test_maxpool_pad
func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
  // Regression test for the pads-size check: onnx.MaxPool on a 4-D (rank-4)
  // input has 2 spatial axes, so `pads` carries 2 * 2 = 4 values (a
  // begin/end pair per spatial axis). Before this fix the check required
  // pads.size() == rank - 2 (= 2) and rejected this valid op.
  // The lowering is expected to produce torch.aten.max_pool2d with
  // kernel [3,3], padding [1,1,1,1], stride [2,2], dilation defaulting to
  // an empty list, and ceil_mode false — as pinned by the CHECK lines below.
  // CHECK: %[[INT3:.*]] = torch.constant.int 3
  // CHECK: %[[INT3_0:.*]] = torch.constant.int 3
  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[INT1:.*]] = torch.constant.int 1
  // CHECK: %[[INT1_1:.*]] = torch.constant.int 1
  // CHECK: %[[INT1_2:.*]] = torch.constant.int 1
  // CHECK: %[[INT1_3:.*]] = torch.constant.int 1
  // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_1]], %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[INT2:.*]] = torch.constant.int 2
  // CHECK: %[[INT2_4:.*]] = torch.constant.int 2
  // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[EMPTY_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
  // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %arg0, %[[LIST]], %[[LIST3]], %[[LIST2]], %[[EMPTY_LIST]], %[[FALSE]] : !torch.vtensor<[1,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56,56],f32>
  // CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32>
  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32>
  return %0 : !torch.vtensor<[1,64,56,56],f32>
}
// -----
// CHECK-LABEL: @test_gelu_default_1
func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[STR1:.*]] = torch.constant.str "none"