From 2ce47dc8e4d6abb9cd6ed51ecddc06f002228b2e Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Tue, 2 Nov 2021 22:18:29 +0530
Subject: [PATCH] [TORCH][MLIR] Add E2E support for aten.expand

This commit adds decomposition of the `aten.Expand` op into the
`aten.BroadcastTo` op.

Signed-Off-by: Gaurav Shukla
---
 e2e_testing/torchscript/basic.py                 | 16 ++++++++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp     | 22 +++++++++++++++++++
 .../Transforms/MaximizeValueSemantics.cpp        |  2 +-
 3 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index 5200df6a2..e70ee800d 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -446,6 +446,22 @@ class BroadcastToModule(torch.nn.Module):
 def BroadcastToModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 1, 1))
 
+class ExpandModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, 1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return x.expand([1, -1, -1, 4])
+
+
+@register_test_case(module_factory=lambda: ExpandModule())
+def ExpandModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 1, 1))
 
 class OnesModuleInt(torch.nn.Module):
     def __init__(self):
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index e5aa6503e..a402cc50e 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -142,6 +142,26 @@ public:
 };
 } // namespace
 
+// Decompose torch.expand into torch.broadcast_to op.
+namespace {
+class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenExpandOp op,
+                                PatternRewriter &rewriter) const override {
+    bool implicit = false;
+    if (!matchPattern(op.implicit(), m_TorchConstantBool(&implicit)) ||
+        implicit) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: requires implicit to be false");
+    }
+    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.self(),
+                                                   op.size());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -155,6 +175,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp();
     patterns.add(context);
     target.addIllegalOp();
+    patterns.add<DecomposeAtenExpandOp>(context);
+    target.addIllegalOp<AtenExpandOp>();
     patterns.add<DecomposeAtenMatmulOp>(context);
     target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
       int lhsRank = getTensorRank(op.self());
diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
index 26ad5dd93..1066edc55 100644
--- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
+++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -92,7 +92,7 @@ public:
       } else if (isa(op)) {
-                     AtenViewOp>(op)) {
+                     AtenViewOp, AtenExpandOp>(op)) {
        // AtenContiguousOp might return a view, so this is conservatively
        // correct. We could potentially be more precise and identify the cases
        // that it does not return a view and treat those as having value
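
Note (not part of the patch): the decomposition above relies on `aten.expand`
with `implicit=False` having the same result as `aten.broadcast_to` on the same
target sizes. Below is a minimal PyTorch-level sketch of that equivalence,
assuming standard PyTorch broadcasting behavior and reusing the shapes from
ExpandModule_basic; it is illustrative only and does not belong to any file
touched by this patch.

import torch

x = torch.rand(3, 1, 1)

# In expand, -1 keeps the corresponding input dimension unchanged; the extra
# leading entry prepends a new dimension of size 1.
expanded = x.expand([1, -1, -1, 4])              # shape: (1, 3, 1, 4)

# The same result expressed as an explicit broadcast, with the -1 entries
# resolved to concrete sizes.
broadcast = torch.broadcast_to(x, (1, 3, 1, 4))

assert expanded.shape == broadcast.shape == (1, 3, 1, 4)
assert torch.equal(expanded, broadcast)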