diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 4834e7af4..859215d28 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -201,9 +201,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, "kernel list size does not match the number of axes"); if (binder.s64IntegerArrayAttr(padding, "pads", {0})) return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); - if (padding.size() != 1 && padding.size() != rank - 2) + if (padding.size() != 1 && padding.size() != 2 * (rank - 2)) return rewriter.notifyMatchFailure( - binder.op, "padding list size does not match the number of axes"); + binder.op, "padding list must contain (begin,end) pair for each " + "spatial axis"); if (binder.s64IntegerArrayAttr(strides, "strides", {1})) return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); if (strides.size() != 1 && strides.size() != rank - 2) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 59049e406..ba4487152 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -274,6 +274,29 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> // ----- +// CHECK-LABEL: func.func @test_maxpool_pad +func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT3_0:.*]] = torch.constant.int 3 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list<int> + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_1:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_2:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_3:.*]] = torch.constant.int 1 + // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_1]], %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[INT2_4:.*]] = torch.constant.int 2 + // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list<int> + // CHECK: %[[EMPTY_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int> + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %arg0, %[[LIST]], %[[LIST3]], %[[LIST2]], %[[EMPTY_LIST]], %[[FALSE]] : !torch.vtensor<[1,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56,56],f32> + // CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> + return %0 : !torch.vtensor<[1,64,56,56],f32> +} + +// ----- + // CHECK-LABEL: @test_gelu_default_1 func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[STR1:.*]] = torch.constant.str "none"