mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add E2E support for aten.expand
This commit adds decomposition of `aten.Expand` to `aten.BroadcastTo` op. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/399/head snapshot-20211103.62
parent
ef897dbb19
commit
2ce47dc8e4
|
@ -446,6 +446,22 @@ class BroadcastToModule(torch.nn.Module):
|
||||||
def BroadcastToModule_basic(module, tu: TestUtils):
|
def BroadcastToModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 1, 1))
|
module.forward(tu.rand(3, 1, 1))
|
||||||
|
|
||||||
|
class ExpandModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, 1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return x.expand([1, -1, -1, 4])
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ExpandModule())
|
||||||
|
def ExpandModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 1, 1))
|
||||||
|
|
||||||
class OnesModuleInt(torch.nn.Module):
|
class OnesModuleInt(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -142,6 +142,26 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Decompose torch.expand into torch.broadcast_to op.
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenExpandOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
bool implicit = false;
|
||||||
|
if (!matchPattern(op.implicit(), m_TorchConstantBool(&implicit)) ||
|
||||||
|
implicit) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: requires implicit to be false");
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.self(),
|
||||||
|
op.size());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeComplexOpsPass
|
class DecomposeComplexOpsPass
|
||||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||||
|
@ -155,6 +175,8 @@ class DecomposeComplexOpsPass
|
||||||
target.addIllegalOp<AtenSoftmaxIntOp>();
|
target.addIllegalOp<AtenSoftmaxIntOp>();
|
||||||
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
|
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
|
||||||
target.addIllegalOp<AtenLogSoftmaxIntOp>();
|
target.addIllegalOp<AtenLogSoftmaxIntOp>();
|
||||||
|
patterns.add<DecomposeAtenExpandOp>(context);
|
||||||
|
target.addIllegalOp<AtenExpandOp>();
|
||||||
patterns.add<DecomposeAtenMatmulOp>(context);
|
patterns.add<DecomposeAtenMatmulOp>(context);
|
||||||
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
|
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
|
||||||
int lhsRank = getTensorRank(op.self());
|
int lhsRank = getTensorRank(op.self());
|
||||||
|
|
|
@ -92,7 +92,7 @@ public:
|
||||||
} else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
|
} else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
|
||||||
AtenTransposeIntOp, TensorStaticInfoCastOp,
|
AtenTransposeIntOp, TensorStaticInfoCastOp,
|
||||||
AtenBroadcastToOp, AtenContiguousOp, AtenPermuteOp,
|
AtenBroadcastToOp, AtenContiguousOp, AtenPermuteOp,
|
||||||
AtenViewOp>(op)) {
|
AtenViewOp, AtenExpandOp>(op)) {
|
||||||
// AtenContiguousOp might return a view, so this is conservatively
|
// AtenContiguousOp might return a view, so this is conservatively
|
||||||
// correct. We could potentially be more precise and identify the cases
|
// correct. We could potentially be more precise and identify the cases
|
||||||
// that it does not return a view and treat those as having value
|
// that it does not return a view and treat those as having value
|
||||||
|
|
Loading…
Reference in New Issue