From 1c2778dd56324f1b62a4084a3b1e3087f40a32cd Mon Sep 17 00:00:00 2001 From: Suraj Sudhir Date: Fri, 7 Jun 2024 09:54:39 -0700 Subject: [PATCH] [ONNX] Conv op adds support for asymmetric padding. (#3426) Supports asymmetric padding by performing a torch.nn.functional.pad on the input before performing the convolution. Signed-off-by: Suraj Sudhir --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 94 ++++++++++++++++--- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 6 +- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 4 +- 3 files changed, 85 insertions(+), 19 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index eb6bfbe76..b26e1ea3a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -951,7 +951,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: auto_pad != NOTSET"); } - Torch::ValueTensorType resultType; Value input, weight; int64_t group; @@ -1034,23 +1033,94 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( SmallVector cstPadding, cstStrides, cstDilations, cstOutputPadding; + Value paddedInput = input; + Value paddingList; if (padding.size() != 2 * (rank - 2)) { for (int64_t i : padding) { cstPadding.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } + paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + cstPadding); } else { + // ONNX offers pads in the format listing all starting dims, then all + // ending dims, e.g. {t, l, b, r} for conv2d. Torch by default accepts + // only starting dims, e.g. {t, l}. However, we can support padding at + // the beginning and end of each dimension by first performing + // torch.nn.functional.pad on the input. But this requires the pad + // values to be rearranged since torch pad() takes pads in the order + // rightmost dim start and end, then next to last, and so on, e.g. {l, + // r, t, b}. + bool matchedPads = true; for (unsigned i = 0; i < padding.size() / 2; i++) { if (padding[i] != padding[i + (padding.size() / 2)]) { - // TODO: Add support for different padding values for the - // beginning and ending along each spatial axis - return rewriter.notifyMatchFailure( - binder.op, - "unsupported conversion: padding values for the beginning " - "and ending along each spatial axis must be equal"); + matchedPads = false; + break; } - cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + } + if (matchedPads) { + for (unsigned i = 0; i < padding.size() / 2; i++) { + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + } + paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + cstPadding); + } else { + SmallVector padsRearrange; + SmallVector inputPaddingList; + for (uint32_t i = 0; i < padding.size() / 2; i++) { + padsRearrange.emplace_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + padsRearrange.emplace_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr( + padding[(padding.size() / 2) + i]))); + inputPaddingList.emplace_back( + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0))); + } + // The conv op itself will have no padding since the actual padding + // is performed using the torch.pad preceding it. + paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + inputPaddingList); + Value padsSizeList = + rewriter + .create( + binder.getLoc(), + Torch::ListType::get( + rewriter.getType()), + padsRearrange) + .getResult(); + Value modeVal = rewriter.create( + binder.getLoc(), rewriter.getStringAttr("constant")); + Value constantValue; + auto inputTensorType = + cast(input.getType()); + if (isa(inputTensorType.getDtype())) + constantValue = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + if (isa(inputTensorType.getDtype())) + constantValue = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(0.0f)); + // Pad output shape must be computed explicitly from the pad values + SmallVector newInputShape(inputTensorType.getSizes()); + for (uint32_t i = 0; i < padding.size() / 2; i++) { + newInputShape[2 + i] += + padding[i] + padding[(padding.size() / 2) + i]; + } + auto padTy = rewriter.getType( + newInputShape, inputTensorType.getDtype()); + paddedInput = rewriter.create( + binder.getLoc(), padTy, input, padsSizeList, modeVal, + constantValue); } } for (int64_t i : dilations) { @@ -1065,10 +1135,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.getLoc(), rewriter.getI64IntegerAttr(0)); cstOutputPadding = {cstZero, cstZero}; - Value paddingList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - cstPadding); Value dilationsList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), @@ -1095,7 +1161,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.getLoc(), rewriter.getI64IntegerAttr(group)); rewriter.replaceOpWithNewOp( - binder.op, resultType, input, weight, bias, stridesList, + binder.op, resultType, paddedInput, weight, bias, stridesList, paddingList, dilationsList, transposed, outputPaddingList, cstGroup); return success(); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 1a21d0c9c..3f437fc4c 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -946,12 +946,12 @@ func.func @test_averagepool_with_padding(%arg0: !torch.vtensor<[1,20,64,48],f32> func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C1_0:.*]] = torch.constant.int 1 // CHECK: %[[C2:.*]] = torch.constant.int 2 // CHECK: %[[C2_0:.*]] = torch.constant.int 2 // CHECK: %[[C0_1:.*]] = torch.constant.int 0 - // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_1]] : (!torch.int, !torch.int) -> !torch.list @@ -969,12 +969,12 @@ func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32 func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[C1_1:.*]] = torch.constant.int 1 // CHECK: %[[C1_2:.*]] = torch.constant.int 1 // CHECK: %[[C2:.*]] = torch.constant.int 2 // CHECK: %[[C2_0:.*]] = torch.constant.int 2 // CHECK: %[[C0:.*]] = torch.constant.int 0 - // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list @@ -992,12 +992,12 @@ func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C3:.*]] = torch.constant.int 3 // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C1_0:.*]] = torch.constant.int 1 // CHECK: %[[C2:.*]] = torch.constant.int 2 // CHECK: %[[C2_0:.*]] = torch.constant.int 2 // CHECK: %[[C0:.*]] = torch.constant.int 0 - // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 67b3b45a0..853e151d3 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -60,12 +60,12 @@ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8> // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 // CHECK: %[[INT1_2:.+]] = torch.constant.int 1 // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 - // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]] @@ -99,12 +99,12 @@ func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !t // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8> // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 // CHECK: %[[INT1_2:.+]] = torch.constant.int 1 // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 - // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]]