From 4c0cd5c23d7df32303416b7065015676fa82e206 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Mon, 14 Mar 2022 13:42:37 +0530
Subject: [PATCH] [MLIR][TORCH] Add E2E support for aten.expand_as op

This commit decomposes the `aten.expand_as` op into the
`aten.broadcast_to` op: the target shape is computed from the `other`
operand with `aten.size`, and `self` is then broadcast to that shape.

Signed-Off By: Vivek Khandelwal
---
 e2e_testing/torchscript/basic.py              | 39 +++++++++++++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 19 +++++++++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  6 +--
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp |  4 ++
 .../jit_ir/build_tools/shape_lib_gen.py       |  3 ++
 test/Dialect/Torch/decompose-complex-ops.mlir | 17 ++++++++
 6 files changed, 85 insertions(+), 3 deletions(-)

diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index 2d166ffbc..21a463246 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -1452,3 +1452,42 @@ class BincountMinlengthModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: BincountMinlengthModule())
 def BincountMinlengthModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(5, (20,)))
+
+# ==============================================================================
+
+class ExpandAsFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, 1, 1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.expand_as(x, y)
+
+
+@register_test_case(module_factory=lambda: ExpandAsFloatModule())
+def ExpandAsFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 1, 1), tu.rand(3, 4, 5))
+
+
+class ExpandAsIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 1, 1], torch.int64, True),
+        ([-1, -1, -1], torch.int64, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.expand_as(x, y)
+
+
+@register_test_case(module_factory=lambda: ExpandAsIntModule())
+def ExpandAsIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(100, (1, 1, 1)), torch.randint(200, (4, 5, 6)))
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 7814b1273..1eda64812 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1429,6 +1429,23 @@ public:
 };
 } // namespace
 
+namespace {
+class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
+  using OpRewritePattern<AtenExpandAsOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenExpandAsOp op,
+                                PatternRewriter &rewriter) const override {
+
+    auto sizeListType =
+        Torch::ListType::get(Torch::IntType::get(op.getContext()));
+    Value sizeList =
+        rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.other());
+    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.self(),
+                                                   sizeList);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -1534,6 +1551,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp();
     patterns.add(context);
     target.addIllegalOp();
+    patterns.add<DecomposeAtenExpandAsOp>(context);
+    target.addIllegalOp<AtenExpandAsOp>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 5ba4463c3..79105abe0 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -515,9 +515,9 @@ ChangeResult TypeAnalyzer::visitOperation(
     AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp,
     AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp,
     AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
-    AtenExpandOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
-    AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp>(
-        op)) {
+    AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
+    AtenConstantPadNdOp, AtenIndexTensorOp,
+    ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp>(op)) {
     ValueKnowledge knowledge =
         ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
     knowledge.dtype = operands[0]->getValue().dtype;
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index 128cc1254..165317ac8 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -1077,6 +1077,10 @@ module {
     }
     return %6 : !torch.list<int>
   }
+  func @"__torch_mlir_shape_fn.aten.expand_as"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func @"__torch_mlir_shape_fn.aten.broadcast_to"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.expand(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 800dd1984..7545eb08a 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -530,6 +530,9 @@ def aten〇embedding(weight: List[int], indices: List[int], padding_idx: int = -
 def aten〇expand(self: List[int], size: List[int], implicit: bool = False) -> List[int]:
     return upstream_shape_helpers.expand(self, size)
 
+def aten〇expand_as(self: List[int], other: List[int]) -> List[int]:
+    return upstream_shape_helpers.unary(other)
+
 def aten〇broadcast_to(self: List[int], size: List[int]) -> List[int]:
     return upstream_shape_helpers.expand(self, size)
 
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 033549399..222016d2c 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -662,3 +662,20 @@ func @torch.aten.index_put(%input: !torch.vtensor<[?],f32>, %index: !torch.vtens
   %0 = torch.aten.index_put %input, %indices, %values, %accumulate : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool -> !torch.vtensor<[?],f32>
   return %0 : !torch.vtensor<[?],f32>
 }
+
+// -----
+// CHECK-LABEL:   func @torch.aten.expand_as(
+// CHECK-SAME:        %[[INP:.*]]: !torch.vtensor<[?,1,1],f32>, %[[OTHER:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
+// CHECK:           %[[INT0:.*]] = torch.constant.int 0
+// CHECK:           %[[DIM0:.*]] = torch.aten.size.int %[[OTHER]], %[[INT0]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
+// CHECK:           %[[INT1:.*]] = torch.constant.int 1
+// CHECK:           %[[DIM1:.*]] = torch.aten.size.int %[[OTHER]], %[[INT1]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
+// CHECK:           %[[INT2:.*]] = torch.constant.int 2
+// CHECK:           %[[DIM2:.*]] = torch.aten.size.int %[[OTHER]], %[[INT2]] : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.int
+// CHECK:           %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK:           %[[RES:.*]] = torch.aten.broadcast_to %[[INP]], %[[SIZE]] : !torch.vtensor<[?,1,1],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
+// CHECK:           return %[[RES]] : !torch.vtensor<[?,?,?],f32>
+func @torch.aten.expand_as(%arg0: !torch.vtensor<[?,1,1],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
+  %0 = torch.aten.expand_as %arg0, %arg1 : !torch.vtensor<[?,1,1],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?],f32>
+}
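---

For reviewers, the decomposition can be sanity-checked at the eager PyTorch
level: `expand_as(self, other)` is equivalent to broadcasting `self` to
`other`'s runtime size, which is exactly the `aten.size` + `aten.broadcast_to`
sequence that `DecomposeAtenExpandAsOp` emits. A minimal sketch, assuming only
stock PyTorch and reusing the shapes from the ExpandAsFloatModule test above:

    import torch

    x = torch.rand(3, 1, 1)
    y = torch.rand(3, 4, 5)

    # What the decomposition emits: read other's shape (aten.size) and
    # broadcast self to it (aten.broadcast_to).
    decomposed = torch.broadcast_to(x, y.size())

    assert torch.equal(torch.ops.aten.expand_as(x, y), decomposed)

Broadcasting to `other`'s runtime size (rather than a static size list) keeps
the decomposition shape-agnostic, which matches the new shape function that
simply forwards `other`'s shape via `upstream_shape_helpers.unary`.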