diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 510ce8121..421b4edc9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -426,6 +426,342 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } return failure(); }); + patterns.onOp( + "Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return failure(); + if (autoPad != "NOTSET") { + // TODO: Add support for `auto_pad` != "NOTSET" + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + } + + Torch::ValueTensorType resultType; + Value input, weight; + int64_t group; + if (binder.tensorOperands(input, weight) || + binder.s64IntegerAttr(group, "group", 1) || + binder.tensorResultType(resultType)) + return failure(); + + auto weightTensorType = weight.getType().cast(); + if (!weightTensorType || !weightTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected weight type having sizes"); + } + ArrayRef weightShape = weightTensorType.getSizes(); + SmallVector kernelShape; + if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {})) + return failure(); + if (kernelShape.size()) { + if (kernelShape.size() != weightShape.size() - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: kernel_shape list size should have " + "number of values equal to weight_rank - 2"); + } else { + for (unsigned i = 0; i < kernelShape.size(); i++) { + if (weightShape[i + 2] != kernelShape[i]) { + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: kernel_shape value " + "should be equal to the weight tensor shape"); + } + } + } + } + + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(input); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + + SmallVector padding, strides, dilations; + SmallVector defaultPadding, defaultStrides, defaultDilations; + for (unsigned i = 0; i < rank - 2; i++) { + defaultPadding.push_back(0); + defaultStrides.push_back(1); + defaultDilations.push_back(1); + } + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. + if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { + return failure(); + } + if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(dilations, "dilations", + defaultDilations)) { + return failure(); + } + if (dilations.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "dilations list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) { + return failure(); + } + if (strides.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + } + + SmallVector cstPadding, cstStrides, cstDilations, + cstOutputPadding; + if (padding.size() != 2 * (rank - 2)) { + for (int64_t i : padding) { + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + } else { + for (unsigned i = 0; i < padding.size() / 2; i++) { + if (padding[i] != padding[i + (padding.size() / 2)]) { + // TODO: Add support for different padding values for the + // beginning and ending along each spatial axis + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: padding values for the beginning " + "and ending along each spatial axis must be equal"); + } + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + } + } + for (int64_t i : dilations) { + cstDilations.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + for (int64_t i : strides) { + cstStrides.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + cstOutputPadding = {cstZero, cstZero}; + + Value paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value dilationsList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstDilations); + Value stridesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value outputPaddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstOutputPadding); + Value transposed = + rewriter.create(binder.getLoc(), false); + Value bias; + if (binder.op->getNumOperands() == 3) { + if (binder.tensorOperandAtIndex(bias, 2)) { + return failure(); + } + } else { + bias = rewriter.create(binder.getLoc()); + } + Value cstGroup = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(group)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, bias, stridesList, + paddingList, dilationsList, transposed, outputPaddingList, + cstGroup); + return success(); + }); + patterns.onOp( + "ConvTranspose", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return failure(); + if (autoPad != "NOTSET") { + // TODO: Add support for `auto_pad` != "NOTSET" + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + } + SmallVector outputShape; + if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {})) + return failure(); + if (outputShape.size()) { + // TODO: Add support for non-None output_shape value. + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: output_shape should be absent"); + } + Torch::ValueTensorType resultType; + Value input, weight; + int64_t group; + if (binder.tensorOperands(input, weight) || + binder.s64IntegerAttr(group, "group", 1) || + binder.tensorResultType(resultType)) + return failure(); + + auto weightTensorType = weight.getType().cast(); + if (!weightTensorType || !weightTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected weight type having sizes"); + } + ArrayRef weightShape = weightTensorType.getSizes(); + SmallVector kernelShape; + if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {})) + return failure(); + if (kernelShape.size()) { + if (kernelShape.size() != weightShape.size() - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: kernel_shape list size should have " + "number of values equal to weight_rank - 2"); + } else { + for (unsigned i = 0; i < kernelShape.size(); i++) { + if (weightShape[i + 2] != kernelShape[i]) { + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: kernel_shape value " + "should be equal to the weight tensor shape"); + } + } + } + } + + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(input); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + + SmallVector padding, strides, dilations, outputPadding; + SmallVector defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding; + for (unsigned i = 0; i < rank - 2; i++) { + defaultPadding.push_back(0); + defaultStrides.push_back(1); + defaultDilations.push_back(1); + defaultOutputPadding.push_back(0); + } + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. + if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { + return failure(); + } + if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(dilations, "dilations", + defaultDilations)) { + return failure(); + } + if (dilations.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "dilations list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) { + return failure(); + } + if (strides.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(outputPadding, "output_padding", + defaultOutputPadding)) { + return failure(); + } + if (outputPadding.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "output_padding list size does not match the number of axes"); + } + + SmallVector cstPadding, cstStrides, cstDilations, + cstOutputPadding; + if (padding.size() != 2 * (rank - 2)) { + for (int64_t i : padding) { + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + } else { + for (unsigned i = 0; i < padding.size() / 2; i++) { + if (padding[i] != padding[i + (padding.size() / 2)]) { + // TODO: Add support for different padding values for the + // beginning and ending along each spatial axis + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: padding values for the beginning " + "and ending along each spatial axis must be equal"); + } + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + } + } + for (int64_t i : dilations) { + cstDilations.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + for (int64_t i : strides) { + cstStrides.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + for (int64_t i : outputPadding) { + cstOutputPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + + Value paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value dilationsList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstDilations); + Value stridesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value outputPaddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstOutputPadding); + Value transposed = + rewriter.create(binder.getLoc(), true); + Value bias; + if (binder.op->getNumOperands() == 3) { + if (binder.tensorOperandAtIndex(bias, 2)) { + return failure(); + } + } else { + bias = rewriter.create(binder.getLoc()); + } + Value cstGroup = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(group)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, bias, stridesList, + paddingList, dilationsList, transposed, outputPaddingList, + cstGroup); + return success(); + }); patterns.onOp("Cos", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 5f36d1bbf..438d6c7ad 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -462,3 +462,133 @@ func.func @test_averagepool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32> %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> return %0 : !torch.vtensor<[1,3,31,31,31],f32> } + +// CHECK-LABEL: @test_conv_with_strides_no_padding +func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,3,2],f32> + %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> + return %0 : !torch.vtensor<[1,1,3,2],f32> +} + +// CHECK-LABEL: @test_conv_with_strides_padding +func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,3],f32> + %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> + return %0 : !torch.vtensor<[1,1,4,3],f32> +} + +// CHECK-LABEL: @test_convtranspose_dilations +func.func @test_convtranspose_dilations(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,2,2],f32>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[C0_2:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,5,5],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.dilations = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,2,2],f32>) -> !torch.vtensor<[1,1,5,5],f32> + return %0 : !torch.vtensor<[1,1,5,5],f32> +} + +// CHECK-LABEL: @test_convtranspose +func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[C0_2:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,5,5],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,5,5],f32> + return %0 : !torch.vtensor<[1,2,5,5],f32> +} + +// CHECK-LABEL: @test_convtranspose_pad + func.func @test_convtranspose_pad(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,10,8],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.output_padding = [1 : si64, 1 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> + return %0 : !torch.vtensor<[1,2,10,8],f32> + } + +// CHECK-LABEL: @test_convtranspose_pads + func.func @test_convtranspose_pads(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,7,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_0]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,7,3],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.pads = [1 : si64, 2 : si64, 1 : si64, 2 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,7,3],f32> + return %0 : !torch.vtensor<[1,2,7,3],f32> + }