[TORCH][MLIR] Add E2E support for `aten.reshape` op

This commit decomposes the `aten.reshape` op into the `aten.view` op in the
case of a value-tensor-typed operand.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/463/head snapshot-20220202.244
Gaurav Shukla 2021-12-17 21:24:03 +05:30
parent 1b505cbac5
commit 0079901039
5 changed files with 70 additions and 7 deletions

View File

@ -39,7 +39,7 @@ from . import backprop
from . import reduction from . import reduction
from . import argmax from . import argmax
from . import matmul from . import matmul
from . import view from . import reshape_like
from . import scalar from . import scalar
from . import squeeze from . import squeeze
from . import slice_like from . import slice_like

View File

@ -123,3 +123,41 @@ class View1DFoldModule(torch.nn.Module):
@register_test_case(module_factory=lambda: View1DFoldModule()) @register_test_case(module_factory=lambda: View1DFoldModule())
def View1DFoldModule_basic(module, tu: TestUtils): def View1DFoldModule_basic(module, tu: TestUtils):
module.forward(tu.rand(32)) module.forward(tu.rand(32))
# ==============================================================================
class ReshapeExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
])
def forward(self, a):
return a.reshape(12, 32)
@register_test_case(module_factory=lambda: ReshapeExpandModule())
def ReshapeExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(384))
# ==============================================================================
class ReshapeCollapseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.reshape(a, (-1,))
@register_test_case(module_factory=lambda: ReshapeCollapseModule())
def ReshapeCollapseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4))

View File

@ -146,6 +146,25 @@ public:
}; };
} // namespace } // namespace
namespace {
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenReshapeOp op,
PatternRewriter &rewriter) const override {
Value input = op.self();
// TODO: Handle non value tensor type operands.
if (!input.getType().isa<ValueTensorType>()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only value tensor type operands are supported");
}
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), input,
op.shape());
return success();
}
};
} // namespace
// Calculates the softmax function on the given `input` tensor. Softmax(x) = // Calculates the softmax function on the given `input` tensor. Softmax(x) =
// exp(x)/sum(exp(x)). // exp(x)/sum(exp(x)).
template <typename OpTy> template <typename OpTy>
@ -784,6 +803,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenExpandOp>(); target.addIllegalOp<AtenExpandOp>();
patterns.add<DecomposeAtenSizeOp>(context); patterns.add<DecomposeAtenSizeOp>(context);
target.addIllegalOp<AtenSizeOp>(); target.addIllegalOp<AtenSizeOp>();
patterns.add<DecomposeAtenReshapeOp>(context);
target.addIllegalOp<AtenReshapeOp>();
patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context); patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>(); target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
patterns.add<DecomposeAtenTanhBackwardOp>(context); patterns.add<DecomposeAtenTanhBackwardOp>(context);

View File

@ -90,7 +90,7 @@ public:
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) { if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
copyToValueTensorOps.push_back(copyToValueTensor); copyToValueTensorOps.push_back(copyToValueTensor);
} else if (isa<AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, } else if (isa<AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp,
AtenFlattenUsingIntsOp, AtenTransposeIntOp, AtenFlattenUsingIntsOp, AtenTransposeIntOp, AtenReshapeOp,
TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp, TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp, AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp, AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp,

View File

@ -365,9 +365,11 @@ public:
maxDim, maxDim.dim(), maxDim.keepdim(), maxDim, maxDim.dim(), maxDim.keepdim(),
secondResDtype, operands, /*resNum=*/1); secondResDtype, operands, /*resNum=*/1);
} else if (auto view = dyn_cast<AtenViewOp>(op)) { } else if (auto view = dyn_cast<AtenViewOp>(op)) {
return visitReshapeLikeOp(view, operands); return visitReshapeLikeOp(view, operands, view.size());
} else if (auto reshape = dyn_cast<AtenReshapeOp>(op)) {
return visitReshapeLikeOp(reshape, operands, reshape.shape());
} else if (auto resize = dyn_cast<AtenResize_Op>(op)) { } else if (auto resize = dyn_cast<AtenResize_Op>(op)) {
return visitReshapeLikeOp(resize, operands); return visitReshapeLikeOp(resize, operands, resize.size());
} else if (auto transposeInt = dyn_cast<AtenTransposeIntOp>(op)) { } else if (auto transposeInt = dyn_cast<AtenTransposeIntOp>(op)) {
return visitAtenTransposeIntOp(transposeInt, operands); return visitAtenTransposeIntOp(transposeInt, operands);
} else if (auto t = dyn_cast<AtenTOp>(op)) { } else if (auto t = dyn_cast<AtenTOp>(op)) {
@ -567,7 +569,8 @@ private:
template <typename OpTy> template <typename OpTy>
ChangeResult ChangeResult
visitReshapeLikeOp(OpTy op, visitReshapeLikeOp(OpTy op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands); ArrayRef<LatticeElement<ValueKnowledge> *> operands,
Value sizeList);
ChangeResult ChangeResult
visitAtenTransposeIntOp(AtenTransposeIntOp op, visitAtenTransposeIntOp(AtenTransposeIntOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands); ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@ -1375,13 +1378,14 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
// result tensor. // result tensor.
template <typename OpTy> template <typename OpTy>
ChangeResult TypeAnalyzer::visitReshapeLikeOp( ChangeResult TypeAnalyzer::visitReshapeLikeOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands,
Value sizeList) {
auto input = operands[0]->getValue(); auto input = operands[0]->getValue();
auto knowledge = auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op.getContext()); ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
knowledge.dtype = input.dtype; knowledge.dtype = input.dtype;
fillInSizesGivenSizesList(knowledge, op.size()); fillInSizesGivenSizesList(knowledge, sizeList);
return getLatticeElement(op.getResult()).join(knowledge); return getLatticeElement(op.getResult()).join(knowledge);
} }