From 73b30604da6cd7d24865fd5aa9c855a89c69377e Mon Sep 17 00:00:00 2001
From: Srinath Avadhanula
Date: Mon, 22 Jan 2024 13:57:56 -0500
Subject: [PATCH] Do not try to legalize transposed convolution (#2721)

Currently transposed convolution is not handled correctly by
`TorchToTosa`. This PR allows transposed convolutions to pass through
the conversion so that they can be handled by other conversion passes
later in a pipeline.

An example input which produces a compilation error is:

```
func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
  %true = torch.constant.bool true
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %weight = torch.vtensor.literal(dense<0.0> : tensor<64x64x3x3xf32>) : !torch.vtensor<[64,64,3,3],f32>
  %bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
  %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>
  return %output : !torch.vtensor<[1,64,2,200],f32>
}
```

This MLIR produces an error about a cast operation with a size mismatch
when passed through `torch-to-tosa`:

```
error: 'tensor.cast' op operand type 'tensor<1x64x1x50xf32>' and result type 'tensor<1x64x2x200xf32>' are cast incompatible
```

---------

Co-authored-by: Srinath Avadhanula
---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp    |  8 ++++++++
 .../TorchToTosa/conv2d_transpose.mlir         | 18 ++++++++++++++++++
 2 files changed, 26 insertions(+)
 create mode 100644 test/Conversion/TorchToTosa/conv2d_transpose.mlir

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 919fe73b2..b49c9af8a 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1850,6 +1850,14 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     AtenConvolutionOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
 
+  bool transposed;
+  if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
+    return rewriter.notifyMatchFailure(
+        op, "Unimplemented: non-constant value for transposed not supported");
+  if (transposed)
+    return rewriter.notifyMatchFailure(
+        op, "Unimplemented: transposed convolution not supported");
+
   auto input = adaptor.getInput();
   auto weight = adaptor.getWeight();
 
diff --git a/test/Conversion/TorchToTosa/conv2d_transpose.mlir b/test/Conversion/TorchToTosa/conv2d_transpose.mlir
new file mode 100644
index 000000000..678034cb8
--- /dev/null
+++ b/test/Conversion/TorchToTosa/conv2d_transpose.mlir
@@ -0,0 +1,18 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics
+
+// The following test ensures that a transposed convolution op is not
+// lowered in the torch-to-tosa conversion pass.
+
+func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
+  %true = torch.constant.bool true
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %weight = torch.vtensor.literal(dense<0.0> : tensor<64x64x3x3xf32>) : !torch.vtensor<[64,64,3,3],f32>
+  %bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
+  %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  // expected-error@+1 {{failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal}}
+  %output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>
+  return %output : !torch.vtensor<[1,64,2,200],f32>
+}
+