diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index a61f041d8..85dbfdac1 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1690,20 +1690,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( std::string autoPad; if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return failure(); - if (autoPad != "NOTSET") { - // TODO: Add support for `auto_pad` != "NOTSET" - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: auto_pad != NOTSET"); - } - SmallVector outputShape; - if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {})) - return failure(); - if (outputShape.size()) { - // TODO: Add support for non-None output_shape value. - return rewriter.notifyMatchFailure( - binder.op, - "unsupported conversion: output_shape should be absent"); - } Torch::ValueTensorType resultType; Value input, weight; int64_t group; @@ -1737,6 +1723,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } } } + } else { + for (unsigned i = 0; i < weightShape.size() - 2; i++) { + kernelShape.push_back(weightShape[i + 2]); + } } // Determine the rank of input tensor. @@ -1746,7 +1736,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; - SmallVector padding, strides, dilations, outputPadding; + SmallVector padding, strides, dilations, outputPadding, + outputShape; SmallVector defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding; for (unsigned i = 0; i < rank - 2; i++) { @@ -1762,13 +1753,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added // at the beginning of axis i and xi_end, the number of pixels added at // the end of axis i. - if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { - return failure(); - } - if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { - return rewriter.notifyMatchFailure( - binder.op, "padding list size does not match the number of axes"); - } if (binder.s64IntegerArrayAttr(dilations, "dilations", defaultDilations)) { return failure(); @@ -1794,7 +1778,60 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "output_padding list size does not match the number of axes"); } + auto inputTensorType = cast(input.getType()); + if (!inputTensorType || !inputTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input type having sizes"); + } + ArrayRef inputShape = inputTensorType.getSizes(); + if (autoPad == "VALID") { + // Zero padding. + padding = defaultPadding; + } else if (autoPad == "NOTSET") { + // Explicit padding; read pads with defaults. + if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) + return failure(); + } else { // autopad == SAME_UPPER or SAME_LOWER + // Auto-padding; output_shape defaults to input_shape * strides. + SmallVector defaultOutputShape; + for (unsigned i = 0; i < rank - 2; i++) { + defaultOutputShape.push_back(inputShape[2 + i] * strides[i]); + } + if (binder.s64IntegerArrayAttr(outputShape, "output_shape", + defaultOutputShape)) + return failure(); + SmallVector paddingEnd; + for (unsigned i = 0; i < rank - 2; i++) { + int64_t totalPadding = + strides[i] * (inputShape[2 + i] - 1) + outputPadding[i] + + ((kernelShape[i] - 1) * dilations[i] + 1) - outputShape[i]; + if (totalPadding % 2) { + // TODO: Add support for different padding values for the + // beginning and ending along each spatial axis. + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: the combination of stride, " + "input_shape, kernel_shape, dilation, output_padding and " + "output_shape caused auto-padding to produce asymmetric " + "padding which isn't currently supported."); + } + int64_t half = totalPadding / 2; + int64_t remainder = totalPadding - half; + if (autoPad == "SAME_UPPER") { + padding.push_back(half); + paddingEnd.push_back(remainder); + } else { + padding.push_back(remainder); + paddingEnd.push_back(half); + } + } + padding.insert(padding.end(), paddingEnd.begin(), paddingEnd.end()); + } + if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + } SmallVector cstPadding, cstStrides, cstDilations, cstOutputPadding; if (padding.size() != 2 * (rank - 2)) { diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index d9c2df1d8..5e62efa00 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1329,6 +1329,78 @@ func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torc // ----- +// CHECK-LABEL: @test_convtranspose_autopad_same_upper + func.func @test_convtranspose_autopad_same_upper(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "user-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_3:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_4:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,6,6],f32> + %4 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="SAME_UPPER", torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32> + return %4 : !torch.vtensor<[1,2,6,6],f32> + } + +// ----- + +// CHECK-LABEL: @test_convtranspose_autopad_same_lower + func.func @test_convtranspose_autopad_same_lower(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "user-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_3:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_4:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,6,6],f32> + %4 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="SAME_LOWER", torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32> + return %4 : !torch.vtensor<[1,2,6,6],f32> + } + +// ----- + +// CHECK-LABEL: @test_convtranspose_autopad_valid + func.func @test_convtranspose_autopad_valid(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,8,8],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "user-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_2:.*]] = torch.constant.int 2 + // CHECK: %[[C0_3:.*]] = torch.constant.int 0 + // CHECK: %[[C0_4:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_3]], %[[C0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,8,8],f32> + %4 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="VALID", torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,8,8],f32> + return %4 : !torch.vtensor<[1,2,8,8],f32> + } + +// ----- + // CHECK-LABEL: @test_batchnorm_epsilon func.func @test_batchnorm_epsilon(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[FALSE:.*]] = torch.constant.bool false