From 77ab31641fde8cabd6403648080c69d315ffd67b Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Fri, 3 Jun 2022 18:08:59 +0530
Subject: [PATCH] [MLIR][TORCH] Add decomposition of aten.numpy_T op

This commit adds the decomposition of `aten.numpy_T` op into
`aten.permute` op.

Signed-Off By: Vivek Khandelwal
---
 e2e_testing/torchscript/xfail_sets.py         |  4 +
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 22 +++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 25 +++++
 .../Transforms/MaximizeValueSemantics.cpp     |  2 +-
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  2 +-
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 13 +++
 .../jit_ir/build_tools/shape_lib_gen.py       |  6 ++
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 .../torch_mlir_e2e_test/test_suite/basic.py   | 98 +++++++++++++++++++
 test/Dialect/Torch/decompose-complex-ops.mlir | 27 +++++
 10 files changed, 198 insertions(+), 2 deletions(-)

diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 9c723ca19..f9a65a810 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -160,4 +160,8 @@ TOSA_PASS_SET = {
     "BaddbmmWithBetaModule_basic",
     "BaddbmmBroadcast1DInputModule_basic",
     "BaddbmmBroadcast2DInputModule_basic",
+    "NumpyTRank1Module_basic",
+    "NumpyTRank2Module_basic",
+    "NumpyTRankNStaticModule_basic",
+    "NumpyTRankNDynamicModule_basic",
 }
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c18a45dce..bf955203a 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6002,6 +6002,28 @@ def Torch_AtenTOp : Torch_Op<"aten.t", [
   }];
 }
 
+def Torch_AtenNumpyTOp : Torch_Op<"aten.numpy_T", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::numpy_T : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNumpyTOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenNumpyTOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenFullOp : Torch_Op<"aten.full", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 10e341274..1db2ed9ec 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1957,6 +1957,29 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
 };
 } // namespace
 
+namespace {
+// Decompose `aten.numpy_T` op into `aten.permute` op.
+class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNumpyTOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.self();
+    int64_t inputRank = getTensorRank(self);
+
+    SmallVector<Value> dimListElements;
+    for (int64_t i = inputRank - 1; i >= 0; i--)
+      dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(i)));
+    Value dimList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
+        dimListElements);
+    rewriter.replaceOpWithNewOp<AtenPermuteOp>(op, op.getType(), self, dimList);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -2102,6 +2125,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp();
     patterns.add(context);
     target.addIllegalOp();
+    patterns.add<DecomposeAtenNumpyTOp>(context);
+    target.addIllegalOp<AtenNumpyTOp>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
index cd866c0fc..6168f3d0b 100644
--- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
+++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -37,7 +37,7 @@ static bool isViewLikeOp(Operation *op) {
              Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
              AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
              AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
-             TensorStaticInfoCastOp, AtenToDtypeLayoutOp>(op);
+             TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp>(op);
 }
 
 namespace {
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index ea0401bbb..5cecc1189 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -642,7 +642,7 @@ ChangeResult TypeAnalyzer::visitOperation(
           AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
           AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
           AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
-          PrimAbsScalarOp>(op)) {
+          PrimAbsScalarOp, AtenNumpyTOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index 03d485859..eae395412 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -5628,6 +5628,19 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.transpose(%arg0, %int0, %int1) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.numpy_T"(%arg0: !torch.list<int>) -> !torch.list<int> {
+    %int0 = torch.constant.int 0
+    %true = torch.constant.bool true
+    %0 = torch.prim.ListConstruct : () -> !torch.list<int>
+    %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
+    torch.prim.Loop %1, %true, init() {
+    ^bb0(%arg1: !torch.int):
+      %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int
+      torch.aten.insert.t %0, %int0, %2 : !torch.list<int>, !torch.int, !torch.int
+      torch.prim.Loop.condition %true, iter()
+    } : (!torch.int, !torch.bool) -> ()
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.matmul"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 2478eadff..afd1d3297 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -529,6 +529,12 @@ def aten〇transpose〇int(self: List[int], dim0: int, dim1: int) -> List[int]:
 def aten〇t(self: List[int]) -> List[int]:
     return upstream_shape_functions.transpose(self, 0, 1)
 
+def aten〇numpy_T(self: List[int]) -> List[int]:
+    result_shape: List[int] = []
+    for i in self:
+        result_shape.insert(0, i)
+    return result_shape
+
 def aten〇matmul(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.matmul(self, other)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 4ca88a3c9..cd9b149af 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -465,6 +465,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
     emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
     emit("aten::t : (Tensor) -> (Tensor)")
+    emit("aten::numpy_T : (Tensor) -> (Tensor)")
     emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
     emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index b186b9b08..290b29358 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -2167,3 +2167,101 @@ class BaddbmmBroadcast2DInputModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: BaddbmmBroadcast2DInputModule())
 def BaddbmmBroadcast2DInputModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 7), tu.rand(5, 2, 9), tu.rand(5, 9, 7))
+
+
+# ==============================================================================
+
+
+class NumpyTRankNStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4, 5, 6], torch.float32, True),
+    ])
+    def forward(self, lhs):
+        return torch.ops.aten.numpy_T(lhs)
+
+
+@register_test_case(module_factory=lambda: NumpyTRankNStaticModule())
+def NumpyTRankNStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5, 6))
+
+
+class NumpyTRankNDynamicModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, lhs):
+        return torch.ops.aten.numpy_T(lhs)
+
+
+@register_test_case(module_factory=lambda: NumpyTRankNDynamicModule())
+def NumpyTRankNDynamicModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5, 6, 2))
+
+
+class NumpyTRank2Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, lhs):
+        return torch.ops.aten.numpy_T(lhs)
+
+
+@register_test_case(module_factory=lambda: NumpyTRank2Module())
+def NumpyTRank2Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+class NumpyTRank1Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, lhs):
+        return torch.ops.aten.numpy_T(lhs)
+
+
+@register_test_case(module_factory=lambda: NumpyTRank1Module())
+def NumpyTRank1Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3))
+
+
+class NumpyTRank0Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.float32, True),
+    ])
+    def forward(self, lhs):
+        return torch.ops.aten.numpy_T(lhs)
+
+
+@register_test_case(module_factory=lambda: NumpyTRank0Module())
+def NumpyTRank0Module_basic(module, tu: TestUtils):
+    module.forward(torch.tensor(7, dtype=torch.float32))
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 59c6e6f9a..b261c96d1 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -1036,3 +1036,30 @@ func.func @torch.aten.floor_divide(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !tor
   %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
+
+// -----
+// CHECK-LABEL: func @torch.aten.numpy_T$rank_two(
+// CHECK-SAME:    %[[SELF:.*]]: !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[4,5],f32> {
+// CHECK:         %[[CST1:.*]] = torch.constant.int 1
+// CHECK:         %[[CST0:.*]] = torch.constant.int 0
+// CHECK:         %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[OUT:.*]] = torch.aten.permute %[[SELF]], %[[DIMS]] : !torch.vtensor<[5,4],f32>, !torch.list<int> -> !torch.vtensor<[4,5],f32>
+// CHECK:         return %[[OUT]] : !torch.vtensor<[4,5],f32>
+func.func @torch.aten.numpy_T$rank_two(%arg0: !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[4,5],f32> {
+  %0 = torch.aten.numpy_T %arg0 : !torch.vtensor<[5,4],f32> -> !torch.vtensor<[4,5],f32>
+  return %0 : !torch.vtensor<[4,5],f32>
+}
+
+// -----
+// CHECK-LABEL: func @torch.aten.numpy_T$rank_three(
+// CHECK-SAME:    %[[SELF:.*]]: !torch.vtensor<[5,4,3],f32>) -> !torch.vtensor<[3,4,5],f32> {
+// CHECK:         %[[CST2:.*]] = torch.constant.int 2
+// CHECK:         %[[CST1:.*]] = torch.constant.int 1
+// CHECK:         %[[CST0:.*]] = torch.constant.int 0
+// CHECK:         %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]], %[[CST1]], %[[CST0]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[OUT:.*]] = torch.aten.permute %[[SELF]], %[[DIMS]] : !torch.vtensor<[5,4,3],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
+// CHECK:         return %[[OUT]] : !torch.vtensor<[3,4,5],f32>
func.func @torch.aten.numpy_T$rank_three(%arg0: !torch.vtensor<[5,4,3],f32>) -> !torch.vtensor<[3,4,5],f32> {
+  %0 = torch.aten.numpy_T %arg0 : !torch.vtensor<[5,4,3],f32> -> !torch.vtensor<[3,4,5],f32>
+  return %0 : !torch.vtensor<[3,4,5],f32>
+}
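
Reviewer note (not part of the patch): the decomposition lowers `aten.numpy_T` to a single `aten.permute` whose dimension list is the reversed order [rank-1, ..., 1, 0]. The eager-mode PyTorch sketch below mirrors that rewrite and compares it against `torch.ops.aten.numpy_T` on the same shapes the e2e tests above use; the helper name `numpy_T_via_permute` is made up for illustration and does not exist anywhere in the patch.

# Sanity-check sketch only; mirrors DecomposeAtenNumpyTOp in eager PyTorch.
import torch

def numpy_T_via_permute(t: torch.Tensor) -> torch.Tensor:
    # Build the dim list [rank-1, ..., 1, 0], as the loop in the C++ pattern
    # does, then express numpy_T as a single permute.
    dims = list(range(t.dim() - 1, -1, -1))
    return torch.ops.aten.permute(t, dims)

if __name__ == "__main__":
    # Shapes taken from the NumpyTRank{1,2,NStatic,NDynamic} tests above.
    for shape in [(3,), (3, 4), (3, 4, 5, 6), (3, 4, 5, 6, 2)]:
        x = torch.rand(shape)
        assert torch.equal(torch.ops.aten.numpy_T(x), numpy_T_via_permute(x))
    # Rank-0 case from NumpyTRank0Module: the dim list is empty, so the
    # permute is a no-op.
    s = torch.tensor(7, dtype=torch.float32)
    assert torch.equal(torch.ops.aten.numpy_T(s), numpy_T_via_permute(s))
    print("aten.numpy_T matches aten.permute with reversed dims")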