[TORCH][MLIR] Add E2E support for aten.expand

This commit adds decomposition of `aten.Expand` to `aten.BroadcastTo`
op.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/399/head snapshot-20211103.62
Gaurav Shukla 2021-11-02 22:18:29 +05:30 committed by Gaurav Shukla
parent ef897dbb19
commit 2ce47dc8e4
3 changed files with 39 additions and 1 deletions

View File

@ -446,6 +446,22 @@ class BroadcastToModule(torch.nn.Module):
def BroadcastToModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 1))
class ExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, 1], torch.float32, True),
])
def forward(self, x):
return x.expand([1, -1, -1, 4])
@register_test_case(module_factory=lambda: ExpandModule())
def ExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 1))
class OnesModuleInt(torch.nn.Module):
def __init__(self):

View File

@ -142,6 +142,26 @@ public:
};
} // namespace
// Decompose torch.expand into torch.broadcast_to op.
namespace {
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenExpandOp op,
PatternRewriter &rewriter) const override {
bool implicit = false;
if (!matchPattern(op.implicit(), m_TorchConstantBool(&implicit)) ||
implicit) {
return rewriter.notifyMatchFailure(
op, "unimplemented: requires implicit to be false");
}
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.self(),
op.size());
return success();
}
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -155,6 +175,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenSoftmaxIntOp>();
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
target.addIllegalOp<AtenLogSoftmaxIntOp>();
patterns.add<DecomposeAtenExpandOp>(context);
target.addIllegalOp<AtenExpandOp>();
patterns.add<DecomposeAtenMatmulOp>(context);
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
int lhsRank = getTensorRank(op.self());

View File

@ -92,7 +92,7 @@ public:
} else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
AtenTransposeIntOp, TensorStaticInfoCastOp,
AtenBroadcastToOp, AtenContiguousOp, AtenPermuteOp,
AtenViewOp>(op)) {
AtenViewOp, AtenExpandOp>(op)) {
// AtenContiguousOp might return a view, so this is conservatively
// correct. We could potentially be more precise and identify the cases
// that it does not return a view and treat those as having value