From c38308f3ef0a93dd0909ac9db86facec5b66f313 Mon Sep 17 00:00:00 2001
From: Alex Tsao <814943412@qq.com>
Date: Mon, 22 Aug 2022 11:17:36 +0800
Subject: [PATCH] Add lowering for _convolution.deprecated (#1259)

* Add lowering for _convolution.deprecated
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  34 ++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  |  17 +--
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |   2 +-
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp |   4 +
 .../jit_ir/build_tools/shape_lib_gen.py       |   5 +-
 .../jit_ir/build_tools/torch_ods_gen.py       |   1 +
 python/torch_mlir_e2e_test/test_suite/conv.py | 112 ++++++++++++++++++
 7 files changed, 166 insertions(+), 9 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 6b9ead8d5..235d910c9 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3444,6 +3444,40 @@ def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
   }];
 }
 
+def Torch_Aten_ConvolutionDeprecatedOp : Torch_Op<"aten._convolution.deprecated", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$dilation,
+    Torch_BoolType:$transposed,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups,
+    Torch_BoolType:$benchmark,
+    Torch_BoolType:$deterministic,
+    Torch_BoolType:$cudnn_enabled
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_ConvolutionDeprecatedOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 12, 1);
+    }
+    void Aten_ConvolutionDeprecatedOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 12, 1);
+    }
+  }];
+}
+
 def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 555b96853..b36938afc 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -927,13 +927,14 @@ public:
 };
 } // namespace
 
-// Decompose aten.convolution_overrideable to aten.convolution
+// Decompose aten._convolution-like to aten.convolution
 namespace {
-class DecomposeAten_ConvolutionOp
-    : public OpRewritePattern<Aten_ConvolutionOp> {
+template <typename ConvolutionLikeOp>
+class DecomposeAten_ConvolutionLikeOp
+    : public OpRewritePattern<ConvolutionLikeOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(Aten_ConvolutionOp op,
+  using OpRewritePattern<ConvolutionLikeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ConvolutionLikeOp op,
                                 PatternRewriter &rewriter) const override {
 
     rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
@@ -2542,8 +2543,10 @@ public:
     patterns.add<...>(context);
     target.addIllegalOp<...>();
     patterns.add<...>(context);
-    target.addIllegalOp<Aten_ConvolutionOp>();
-    patterns.add<DecomposeAten_ConvolutionOp>(context);
+    target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
+    patterns.add<DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>,
+                 DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
+        context);
     target.addIllegalOp<...>();
     patterns.add<...>(context);
     patterns.add<...>(context);
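The templated pattern above maps both aten._convolution and the new
aten._convolution.deprecated onto aten.convolution, simply dropping the trailing
benchmark, deterministic, and cudnn_enabled flags. Those flags only steer backend
selection in eager PyTorch, so dropping them preserves the computed result. A quick
eager-mode sanity check of that assumption (illustrative snippet, not part of the
patch; relies on PyTorch binding a 12-argument call to the deprecated overload):

    import torch

    x = torch.randn(1, 3, 8, 8)
    w = torch.randn(4, 3, 3, 3)
    # Twelve trailing arguments (no allow_tf32) resolve to the
    # _convolution.deprecated overload; the three flags are all False here.
    dep = torch.ops.aten._convolution(x, w, None, [1, 1], [0, 0], [1, 1],
                                      False, [0, 0], 1, False, False, False)
    ref = torch.ops.aten.convolution(x, w, None, [1, 1], [0, 0], [1, 1],
                                     False, [0, 0], 1)
    assert torch.allclose(dep, ref)
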
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 80ae46614..8696e3c4c 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -712,7 +712,7 @@ void TypeAnalysis::visitOperation(Operation *op,
 
   // Promote the two dtypes assuming non-zero rank.
   if (isa<...,
-          Aten_ConvolutionOp, AtenConvolutionOverrideableOp>(op)) {
+          Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenConvolutionOverrideableOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index ff42c75bd..7f4355315 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6341,6 +6341,10 @@ module {
     %0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten._convolution.deprecated"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.list<int> {
+    %0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.flip"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
     return %arg0 : !torch.list<int>
   }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 7ec3b67f8..2d8ba3b6e 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -940,7 +940,10 @@ def aten〇convolution(input: List[int], weight: List[int], bias: Optional[List[
 
 def aten〇_convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
     return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
-
+
+def aten〇_convolution〇deprecated(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]:
+    return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
+
 def aten〇flip(self: List[int], dims: List[int]) -> List[int]:
     return self
 
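Both aten〇_convolution and the new aten〇_convolution〇deprecated shape functions
forward to aten〇convolution, which encodes the standard convolution output-size
rule per spatial dimension. A standalone sketch of that rule (the helper name is
illustrative, not a torch-mlir API):

    # out = (in + 2*padding - dilation*(kernel - 1) - 1) // stride + 1
    def conv_out_dim(in_dim: int, kernel: int, stride: int,
                     padding: int, dilation: int) -> int:
        return (in_dim + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

    # The e2e tests below use 10x10 inputs, 2x2 kernels, stride 3, padding 2,
    # so each (3, 3, 10, 10) input yields a (3, 3, 5, 5) result:
    assert conv_out_dim(10, 2, 3, 2, 1) == 5
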
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index a2151fdec..49b21799c 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -337,6 +337,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
     emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
     emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
+    emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
     emit("aten::flip : (Tensor, int[]) -> (Tensor)")
     emit(
         "aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py
index d3bc77e71..cb92710d2 100644
--- a/python/torch_mlir_e2e_test/test_suite/conv.py
+++ b/python/torch_mlir_e2e_test/test_suite/conv.py
@@ -406,6 +406,118 @@ class _Convolution2DTF32Module(torch.nn.Module):
 def _Convolution2DTF32Module_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
 
+class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten._convolution(inputVec,
+                                           weight,
+                                           bias=None,
+                                           stride=[3, 3],
+                                           padding=[2, 2],
+                                           dilation=[1, 1],
+                                           transposed=False,
+                                           output_padding=[0, 0],
+                                           groups=1,
+                                           benchmark=False,
+                                           deterministic=False,
+                                           cudnn_enabled=False)
+
+@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule())
+def _ConvolutionDeprecated2DAllFalseModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
+
+class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten._convolution(inputVec,
+                                           weight,
+                                           bias=None,
+                                           stride=[3, 3],
+                                           padding=[2, 2],
+                                           dilation=[1, 1],
+                                           transposed=False,
+                                           output_padding=[0, 0],
+                                           groups=1,
+                                           benchmark=True,
+                                           deterministic=False,
+                                           cudnn_enabled=False)
+
+@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule())
+def _ConvolutionDeprecated2DBenchmarkModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
+
+class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten._convolution(inputVec,
+                                           weight,
+                                           bias=None,
+                                           stride=[3, 3],
+                                           padding=[2, 2],
+                                           dilation=[1, 1],
+                                           transposed=False,
+                                           output_padding=[0, 0],
+                                           groups=1,
+                                           benchmark=False,
+                                           deterministic=True,
+                                           cudnn_enabled=False)
+
+@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule())
+def _ConvolutionDeprecated2DDeterministicModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
+
+class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten._convolution(inputVec,
+                                           weight,
+                                           bias=None,
+                                           stride=[3, 3],
+                                           padding=[2, 2],
+                                           dilation=[1, 1],
+                                           transposed=False,
+                                           output_padding=[0, 0],
+                                           groups=1,
+                                           benchmark=False,
+                                           deterministic=False,
+                                           cudnn_enabled=True)
+
+@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
+def _ConvolutionDeprecated2DCudnnModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
+
 class ConvolutionModule2DGroups(torch.nn.Module):
     def __init__(self):
         super().__init__()
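A note on the new tests: each module calls torch.ops.aten._convolution with twelve
trailing arguments and no allow_tf32, which PyTorch's overload resolution binds to
_convolution.deprecated, since the default _convolution overload requires a
thirteenth allow_tf32 argument. The resolved overload and its schema can be
inspected directly (illustrative snippet; assumes a PyTorch build that exposes
overloads as attributes on the op packet):

    import torch

    # Overloads are addressable by name on the op packet.
    dep = torch.ops.aten._convolution.deprecated
    print(dep._schema)  # aten::_convolution.deprecated(Tensor input, ...) -> Tensor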