diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 8919df43a..2712f0964 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1292,14 +1292,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( }); patterns.onOp( "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - std::string autoPad; - if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) - return failure(); - if (autoPad != "NOTSET") { - // TODO: Add support for `auto_pad` != "NOTSET" - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: auto_pad != NOTSET"); - } Torch::ValueTensorType resultType; Value input, weight; int64_t group; @@ -1349,20 +1341,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( defaultStrides.push_back(1); defaultDilations.push_back(1); } - // Padding for the beginning and ending along each spatial axis, it can - // take any value greater than or equal to 0. The value represent the - // number of pixels added to the beginning and end part of the - // corresponding axis. pads format should be as follow [x1_begin, - // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added - // at the beginning of axis i and xi_end, the number of pixels added at - // the end of axis i. - if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { - return failure(); - } - if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { - return rewriter.notifyMatchFailure( - binder.op, "padding list size does not match the number of axes"); - } if (binder.s64IntegerArrayAttr(dilations, "dilations", defaultDilations)) { return failure(); @@ -1379,6 +1357,46 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return rewriter.notifyMatchFailure( binder.op, "strides list size does not match the number of axes"); } + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return failure(); + auto inputTensorType = cast(input.getType()); + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. + if (autoPad == "NOTSET") { + if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { + return failure(); + } + } else if (autoPad == "VALID") { + padding = defaultPadding; + } else { + const bool isSameLower = autoPad == "SAME_LOWER"; + const unsigned spatialRank = rank - 2; + ArrayRef inputShape = inputTensorType.getSizes(); + padding.resize_for_overwrite(2 * spatialRank); + for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) { + const int64_t dilatedKernelSize = + dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1; + int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) / + strides[dimIdx] - + 1) * + strides[dimIdx] + + dilatedKernelSize - inputShape[dimIdx + 2]; + totalPad = totalPad >= 0 ? totalPad : 0; + padding[dimIdx] = + isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2); + padding[spatialRank + dimIdx] = totalPad - padding[dimIdx]; + } + } + if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + } SmallVector cstPadding, cstStrides, cstDilations, cstOutputPadding; @@ -1452,8 +1470,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value modeVal = rewriter.create( binder.getLoc(), rewriter.getStringAttr("constant")); Value constantValue; - auto inputTensorType = - cast(input.getType()); + if (isa(inputTensorType.getDtype())) constantValue = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(0)); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 10cca7f80..6cc0cf0ec 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1062,6 +1062,93 @@ func.func @test_conv_with_asymmetric_padding(%arg0: !torch.vtensor<[1,1,7,5],f32 // ----- +// CHECK-LABEL: @test_conv_with_autopad +func.func @test_conv_with_autopad(%arg0: !torch.vtensor<[1,1,12,7],f32>, %arg1: !torch.vtensor<[1,1,2,3],f32>) -> !torch.vtensor<[1,1,3,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 0 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 4 + // CHECK: %[[C2_0:.*]] = torch.constant.int 3 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,12,7],f32>, !torch.vtensor<[1,1,2,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,3,3],f32> + %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 3 : si64], torch.onnx.auto_pad = "SAME_LOWER", torch.onnx.strides = [4 : si64, 3 : si64]} : (!torch.vtensor<[1,1,12,7],f32>, !torch.vtensor<[1,1,2,3],f32>) -> !torch.vtensor<[1,1,3,3],f32> + return %0 : !torch.vtensor<[1,1,3,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_conv_with_autopad_asymmetric +func.func @test_conv_with_autopad_asymmetric(%arg0: !torch.vtensor<[1,1,15,9],f32>, %arg1: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int0_0:.*]] = torch.constant.int 0 + // CHECK: %[[int1_1:.*]] = torch.constant.int 1 + // CHECK: %[[int0_2:.*]] = torch.constant.int 0 + // CHECK: %[[FakePADS:.*]] = torch.prim.ListConstruct %[[int0]], %[[int0_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OGPADS:.*]] = torch.prim.ListConstruct %[[int1]], %[[int2]], %[[int0_0]], %[[int1_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[str:.*]] = torch.constant.str "constant" + // CHECK: %[[float0:.*]] = torch.constant.float 0.000 + // CHECK: %[[PrePad:.*]] = torch.aten.pad %arg0, %[[OGPADS]], %[[str]], %[[float0]] : !torch.vtensor<[1,1,15,9],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[1,1,16,12],f32> + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C4]], %[[C4_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: %[[Conv:.*]] = torch.aten.convolution %[[PrePad]], %arg1, %[[BIAS]], %[[STRIDE]], %[[FakePADS]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,16,12],f32>, !torch.vtensor<[1,1,4,4],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,3],f32> + // CHECK: return %[[Conv]] + %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [4 : si64, 4 : si64], torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.strides = [4 : si64, 4 : si64]} : (!torch.vtensor<[1,1,15,9],f32>, !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,4,3],f32> + return %0 : !torch.vtensor<[1,1,4,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_conv_with_autopad_asymmetric_lower +func.func @test_conv_with_autopad_asymmetric_lower(%arg0: !torch.vtensor<[1,1,15,9],f32>, %arg1: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int1_0:.*]] = torch.constant.int 1 + // CHECK: %[[int0_1:.*]] = torch.constant.int 0 + // CHECK: %[[int0_2:.*]] = torch.constant.int 0 + // CHECK: %[[FakePADS:.*]] = torch.prim.ListConstruct %[[int0]], %[[int0_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OGPADS:.*]] = torch.prim.ListConstruct %[[int2]], %[[int1]], %[[int1_0]], %[[int0_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[str:.*]] = torch.constant.str "constant" + // CHECK: %[[float0:.*]] = torch.constant.float 0.000 + // CHECK: %[[PrePad:.*]] = torch.aten.pad %arg0, %[[OGPADS]], %[[str]], %[[float0]] : !torch.vtensor<[1,1,15,9],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[1,1,16,12],f32> + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C4]], %[[C4_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: %[[Conv:.*]] = torch.aten.convolution %[[PrePad]], %arg1, %[[BIAS]], %[[STRIDE]], %[[FakePADS]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,16,12],f32>, !torch.vtensor<[1,1,4,4],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,3],f32> + // CHECK: return %[[Conv]] + %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [4 : si64, 4 : si64], torch.onnx.auto_pad = "SAME_LOWER", torch.onnx.strides = [4 : si64, 4 : si64]} : (!torch.vtensor<[1,1,15,9],f32>, !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,4,3],f32> + return %0 : !torch.vtensor<[1,1,4,3],f32> +} + +// ----- + // CHECK-LABEL: @test_conv_with_bias_strides_padding func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C3:.*]] = torch.constant.int 3