mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add E2E support for aten.expand
This commit adds a decomposition of the `aten.expand` op into the `aten.broadcast_to` op.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
parent ef897dbb19
commit 2ce47dc8e4
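As background for the change, here is a minimal sketch in plain PyTorch (not part of the commit) of the equivalence the decomposition relies on: `expand` with a size list, where -1 keeps the corresponding input dimension, produces the same view as `broadcast_to` with the resolved target shape.

import torch

x = torch.rand(3, 1, 1)

# aten.expand: a new leading dimension is added, and -1 keeps the input size.
expanded = x.expand([1, -1, -1, 4])              # shape (1, 3, 1, 4)

# aten.broadcast_to yields the same view for the resolved target shape.
broadcast = torch.broadcast_to(x, (1, 3, 1, 4))

assert expanded.shape == broadcast.shape == (1, 3, 1, 4)
assert torch.equal(expanded, broadcast)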
@@ -446,6 +446,22 @@ class BroadcastToModule(torch.nn.Module):
 def BroadcastToModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 1, 1))
 
+class ExpandModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, 1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return x.expand([1, -1, -1, 4])
+
+
+@register_test_case(module_factory=lambda: ExpandModule())
+def ExpandModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 1, 1))
 
 class OnesModuleInt(torch.nn.Module):
     def __init__(self):
@@ -142,6 +142,26 @@ public:
 };
 } // namespace
 
+// Decompose torch.expand into torch.broadcast_to op.
+namespace {
+class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenExpandOp op,
+                                PatternRewriter &rewriter) const override {
+    bool implicit = false;
+    if (!matchPattern(op.implicit(), m_TorchConstantBool(&implicit)) ||
+        implicit) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: requires implicit to be false");
+    }
+    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.self(),
+                                                   op.size());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
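The pattern only fires when the `implicit` operand is a constant false; otherwise it reports a match failure, and the operands are forwarded unchanged to `aten.broadcast_to`. A hypothetical Python analogue of that guard (`decompose_expand` is an illustrative name, not an API from this commit):

import torch

def decompose_expand(self_tensor, size, implicit=False):
    # Mirror of the match guard: bail out unless implicit is known to be false.
    if implicit:
        raise NotImplementedError("unimplemented: requires implicit to be false")
    # The size list is forwarded as-is; -1 entries are legal here because
    # ATen's broadcast_to is itself defined in terms of expand.
    return torch.broadcast_to(self_tensor, size)

print(decompose_expand(torch.rand(3, 1, 1), [1, -1, -1, 4]).shape)  # (1, 3, 1, 4)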
@@ -155,6 +175,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<AtenSoftmaxIntOp>();
     patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
     target.addIllegalOp<AtenLogSoftmaxIntOp>();
+    patterns.add<DecomposeAtenExpandOp>(context);
+    target.addIllegalOp<AtenExpandOp>();
     patterns.add<DecomposeAtenMatmulOp>(context);
     target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
       int lhsRank = getTensorRank(op.self());
@@ -92,7 +92,7 @@ public:
     } else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
                    AtenTransposeIntOp, TensorStaticInfoCastOp,
                    AtenBroadcastToOp, AtenContiguousOp, AtenPermuteOp,
-                   AtenViewOp>(op)) {
+                   AtenViewOp, AtenExpandOp>(op)) {
       // AtenContiguousOp might return a view, so this is conservatively
       // correct. We could potentially be more precise and identify the cases
       // that it does not return a view and treat those as having value
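The last hunk records that `aten.expand`, like `aten.broadcast_to` and the other listed ops, returns a view of its input rather than a fresh tensor. A quick PyTorch check of that aliasing behavior (illustrative, not code from the commit):

import torch

x = torch.rand(3, 1, 1)
y = x.expand([1, -1, -1, 4])

# expand copies no data: the result aliases the input's storage.
assert y.data_ptr() == x.data_ptr()

# Writes through the base tensor are visible through the view.
x[0, 0, 0] = 42.0
assert y[0, 0, 0, 0].item() == 42.0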