diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp
index b6e9d9ba9..82002292e 100644
--- a/lib/Conversion/TorchToStablehlo/Linear.cpp
+++ b/lib/Conversion/TorchToStablehlo/Linear.cpp
@@ -591,25 +591,32 @@ public:
     auto weightShape = weightTy.getShape();

     auto nDims = inputTy.getRank();
+    auto weightDims = weightTy.getRank();
+    auto kernelDims = weightDims - 2;
+    auto nSpatialDims = nDims - 2;
     auto convOutTy = outType;

     // Transpose weight
     SmallVector<int64_t> perm(nDims);
     SmallVector<int64_t> transposeShape(nDims);
-    for (int i = 0; i < nDims; i++) {
-      if (i < 2)
-        perm[i] = nDims - 2 + i;
+    // 1d: kernelDims = 1, [0, 1, 2] => [2, 1, 0]
+    // 2d: kernelDims = 2, [0, 1, 2, 3] => [2, 3, 1, 0]
+    // 3d: kernelDims = 3, [0, 1, 2, 3, 4] => [2, 3, 4, 1, 0]
+    for (int i = 0; i < weightDims; i++) {
+      if (i < kernelDims)
+        perm[i] = 2 + i;
       else
-        perm[i] = nDims - i - 1;
+        perm[i] = kernelDims + 1 - i;
       transposeShape[i] = weightShape[perm[i]];
     }
+    auto reverseDim = llvm::to_vector<4>(llvm::seq<int64_t>(0, kernelDims));
     auto transposeTy =
         RankedTensorType::get(transposeShape, weightTy.getElementType());
     auto transposeOp = rewriter.create<stablehlo::TransposeOp>(
         op->getLoc(), transposeTy, weight, perm);
     auto reverseOp = rewriter.create<stablehlo::ReverseOp>(
-        op->getLoc(), transposeOp, ArrayRef<int64_t>{0, 1});
+        op->getLoc(), transposeOp, reverseDim);

     // Prepare for transposed convolution
     SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index ad7889057..a56982714 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9110,6 +9110,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %false, %0, %int1) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
 "    return %1 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.conv_transpose1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
+"    %true = torch.constant.bool true\n"
+"    %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.conv_transpose3d.input\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
+"    %true = torch.constant.bool true\n"
+"    %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten._convolution\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.list<int> {\n"
 "    %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -11797,10 +11807,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose2d.input\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose3d.input\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.convolution\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index f62bbe562..da49e2d77 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3633,6 +3633,25 @@ public:
 };
 } // namespace

+// Decompose aten.conv_transpose1d to aten.convolution
+namespace {
+class DecomposeAtenConvTranspose1dOp
+    : public OpRewritePattern<AtenConvTranspose1dOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenConvTranspose1dOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
+    rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
+        op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
+        op.getStride(), op.getPadding(), op.getDilation(),
+        /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups());
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.conv_transpose2d to aten.convolution
 namespace {
 class DecomposeAtenConvTranspose2dOp
@@ -3652,6 +3671,25 @@ public:
 };
 } // namespace

+// Decompose aten.conv_transpose3d to aten.convolution
+namespace {
+class DecomposeAtenConvTranspose3dOp
+    : public OpRewritePattern<AtenConvTranspose3dInputOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenConvTranspose3dInputOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
+    rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
+        op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
+        op.getStride(), op.getPadding(), op.getDilation(),
+        /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups());
+    return success();
+  }
+};
+} // namespace
+
 // The convolution backward op is decomposed as follows:
 //   inputH, inputW = input.shape[2:]
 //   output_padding_ = [
@@ -7963,7 +8001,9 @@ public:
         DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenConvolutionBackwardOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose1dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose3dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartStepOp>(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 374b0f4e4..ffc45a1be 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -428,7 +428,9 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenConvolutionOverrideableOp>();
   target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
   target.addIllegalOp<AtenConvolutionBackwardOp>();
+  target.addIllegalOp<AtenConvTranspose1dOp>();
   target.addIllegalOp<AtenConvTranspose2dInputOp>();
+  target.addIllegalOp<AtenConvTranspose3dInputOp>();
   target.addIllegalOp<AtenArangeOp>();
   target.addIllegalOp<AtenArangeStartOp>();
   target.addIllegalOp<AtenArangeStartStepOp>();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 7ccbdbee6..9bbeb9bef 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -911,6 +911,9 @@ STABLEHLO_PASS_SET = {
     "Convolution2DStaticModule_basic",
     "ConvolutionBackwardModule2DStatic_basic",
     "ConvolutionModule2DTransposeStridedStatic_basic",
+    "Conv_Transpose1dStaticModule_basic",
+    "Conv_Transpose2dStaticModule_basic",
+    "Conv_Transpose3dStaticModule_basic",
     "ConstantPad2dStaticModule_basic",
     "ConstantPadNdModule_basic",
     "ConstantPadNdPartialStaticModule_basic",
@@ -2662,6 +2665,8 @@ ONNX_XFAIL_SET = {
     "PrimsIotaModule_basic",
     # Failure - unknown
     "BernoulliModule_basic",
+    "Conv_Transpose1dModule_basic",
+    "Conv_Transpose3dModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
     "CopyWithDifferentDTypesAndSizesModule_basic",
     "CopyWithDifferentDTypesModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 01a38c0fe..af8763e97 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1548,6 +1548,12 @@ def aten〇convolution〡shape(input: List[int], weight: List[int], bias: Option
 def aten〇conv1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), dilation: List[int] = (1,), groups: int = 1) -> List[int]:
     return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding, dilation, transposed=False, output_padding=[], groups=1)

+def aten〇conv_transpose1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), output_padding: List[int] = (0,), groups: int = 1, dilation: List[int] = (1,)) -> List[int]:
+    return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, True, output_padding, groups)
+
+def aten〇conv_transpose3d〇input〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), output_padding: List[int] = (0, 0, 0,), groups: int = 1, dilation: List[int] = (1, 1, 1,)) -> List[int]:
+    return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, True, output_padding, groups)
+
 def aten〇_convolution〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
     return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)

@@ -3538,6 +3544,10 @@ def aten〇conv3d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype:
     input_rank, input_dtype = input_rank_dtype
     return input_dtype

+def aten〇conv_transpose1d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), output_padding: List[int] = (0,), groups: int = 1, dilation: List[int] = (1,)) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    return input_dtype
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)]) +
     [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)),
@@ -3549,6 +3559,10 @@ def aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], w
     input_rank, input_dtype = input_rank_dtype
     return input_dtype

+def aten〇conv_transpose3d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), output_padding: List[int] = (0, 0, 0,), groups: int = 1, dilation: List[int] = (1, 1, 1,)) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    return input_dtype
+
 convolution_kwargs = {
     "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1}
 @check_dtype_function(
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
index e99525c32..b157f91ef 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
@@ -760,6 +760,66 @@ def ConvolutionModule2DTransposeNonUnitOutputPadding_basic(module, tu: TestUtils
     module.forward(tu.rand(1, 2, 4, 4), tu.rand(2, 2, 3, 3))


+class Conv_Transpose1dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv_transpose1d(
+            inputVec,
+            weight,
+            bias=None,
+            stride=[2],
+            padding=[1],
+            dilation=[1],
+            output_padding=[0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Conv_Transpose1dModule())
+def Conv_Transpose1dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 2, 6), tu.rand(2, 5, 2))
+
+
+class Conv_Transpose1dStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([5, 2, 6], torch.float32, True),
+            ([2, 5, 2], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv_transpose1d(
+            inputVec,
+            weight,
+            bias=None,
+            stride=[2],
+            padding=[1],
+            dilation=[1],
+            output_padding=[0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Conv_Transpose1dStaticModule())
+def Conv_Transpose1dStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 2, 6), tu.rand(2, 5, 2))
+
+
 class Conv_Transpose2dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -790,6 +850,96 @@ def Conv_Transpose2dModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2))


+class Conv_Transpose2dStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([5, 2, 5, 6], torch.float32, True),
+            ([2, 5, 2, 2], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv_transpose2d(
+            inputVec,
+            weight,
+            bias=None,
+            stride=[2, 2],
+            padding=[1, 1],
+            dilation=[1, 1],
+            output_padding=[0, 0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Conv_Transpose2dStaticModule())
+def Conv_Transpose2dStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2))
+
+
+class Conv_Transpose3dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv_transpose3d(
+            inputVec,
+            weight,
+            bias=None,
+            stride=[2, 2, 2],
+            padding=[1, 1, 1],
+            dilation=[1, 1, 1],
+            output_padding=[0, 0, 0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Conv_Transpose3dModule())
+def Conv_Transpose3dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 2, 5, 6, 7), tu.rand(2, 5, 2, 2, 2))
+
+
+class Conv_Transpose3dStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([5, 2, 5, 6, 7], torch.float32, True),
+            ([2, 5, 2, 2, 2], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv_transpose3d(
+            inputVec,
+            weight,
+            bias=None,
+            stride=[2, 2, 2],
+            padding=[1, 1, 1],
+            dilation=[1, 1, 1],
+            output_padding=[0, 0, 0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Conv_Transpose3dStaticModule())
+def Conv_Transpose3dStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 2, 5, 6, 7), tu.rand(2, 5, 2, 2, 2))
+
+
 class UpSampleNearest2d(torch.nn.Module):
     def __init__(self):
         super().__init__()
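
Reviewer note (not part of the patch): the decomposition rewrites
aten.conv_transpose1d and aten.conv_transpose3d.input into aten.convolution
with transposed=true, forwarding stride/padding/dilation and threading
output_padding and groups through, and the new shape functions delegate the
same way. Per spatial dim the output size is
(in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1;
for the 1d test above that is (6 - 1) * 2 - 2 * 1 + 1 * (2 - 1) + 0 + 1 = 10,
i.e. an output of shape (5, 5, 10). A quick eager-mode sketch of the same
identity (illustration only, using tensors shaped like the test inputs):

    import torch

    inputVec, weight = torch.rand(5, 2, 6), torch.rand(2, 5, 2)
    # aten::conv_transpose1d(input, weight, bias, stride, padding,
    #                        output_padding, groups, dilation)
    direct = torch.ops.aten.conv_transpose1d(
        inputVec, weight, None, [2], [1], [0], 1, [1])
    # aten::convolution(input, weight, bias, stride, padding, dilation,
    #                   transposed, output_padding, groups)
    decomposed = torch.ops.aten.convolution(
        inputVec, weight, None, [2], [1], [1], True, [0], 1)
    assert direct.shape == (5, 5, 10)
    assert torch.allclose(direct, decomposed)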
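
On the StableHLO path, the decomposed aten.convolution with transposed=true is
then lowered by Linear.cpp, whose change generalizes the weight preparation
from the hardcoded 2d case to any rank: move the kernel's spatial dims to the
front, append the two channel dims in swapped order, then flip every spatial
dim with stablehlo.reverse. A Python rendering of just the permutation
arithmetic (an illustration of the C++ loop above, not code from the patch):

    def transpose_conv_weight_perm(weight_dims: int) -> list:
        # Spatial dims first, then the two channel dims swapped,
        # mirroring the perm[] computation in the C++ loop.
        kernel_dims = weight_dims - 2
        return [2 + i if i < kernel_dims else kernel_dims + 1 - i
                for i in range(weight_dims)]

    assert transpose_conv_weight_perm(3) == [2, 1, 0]        # 1d
    assert transpose_conv_weight_perm(4) == [2, 3, 1, 0]     # 2d
    assert transpose_conv_weight_perm(5) == [2, 3, 4, 1, 0]  # 3d
    # reverseDim covers dims [0, kernel_dims), so stablehlo.reverse
    # flips the kernel along every spatial axis after the transpose.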