mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add E2E support for `aten.reshape` op
This commit decomposes `aten.reshape` into the `aten.view` op when the operand is a value tensor.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>

Refs: pull/463/head, snapshot-20220202.244
parent 1b505cbac5
commit 0079901039
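To see why this decomposition is sound for value tensors, here is a minimal eager-mode PyTorch sketch (illustration only, not part of the commit): on a contiguous input, `reshape` and `view` produce identical results, so once value semantics rule out aliasing concerns, `aten.reshape` can be lowered directly to `aten.view`.

import torch

# With value semantics there is no aliasing to preserve, so reshape
# can be rewritten as view; on contiguous inputs they agree exactly.
a = torch.rand(384)
assert torch.equal(a.reshape(12, 32), a.view(12, 32))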
@@ -39,7 +39,7 @@ from . import backprop
 from . import reduction
 from . import argmax
 from . import matmul
-from . import view
+from . import reshape_like
 from . import scalar
 from . import squeeze
 from . import slice_like
@@ -123,3 +123,41 @@ class View1DFoldModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: View1DFoldModule())
 def View1DFoldModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(32))
+
+# ==============================================================================
+
+class ReshapeExpandModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.reshape(12, 32)
+
+@register_test_case(module_factory=lambda: ReshapeExpandModule())
+def ReshapeExpandModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(384))
+
+# ==============================================================================
+
+class ReshapeCollapseModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return torch.reshape(a, (-1,))
+
+@register_test_case(module_factory=lambda: ReshapeCollapseModule())
+def ReshapeCollapseModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4))
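Run eagerly, the two new test modules compute the following (a sketch of the expected reference behavior; the e2e harness compares the compiled result against the eager one):

import torch

# Eager reference behavior for the two new tests: one rank-expanding
# reshape and one rank-collapsing reshape.
expand = torch.rand(384).reshape(12, 32)           # 384 elements -> (12, 32)
collapse = torch.reshape(torch.rand(2, 4), (-1,))  # (2, 4) -> (8,)
print(expand.shape, collapse.shape)  # torch.Size([12, 32]) torch.Size([8])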
@@ -146,6 +146,25 @@ public:
 };
 } // namespace
 
+namespace {
+class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.self();
+    // TODO: Handle non value tensor type operands.
+    if (!input.getType().isa<ValueTensorType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only value tensor type operands are supported");
+    }
+    rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), input,
+                                            op.shape());
+    return success();
+  }
+};
+} // namespace
+
 // Calculates the softmax function on the given `input` tensor. Softmax(x) =
 // exp(x)/sum(exp(x)).
 template <typename OpTy>
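The value-tensor guard matters because, outside value semantics, `reshape` and `view` differ observably in PyTorch: `view` always aliases its input (and fails on incompatible strides), while `reshape` may copy. A small eager-mode sketch of that difference (illustrative, not from the commit):

import torch

# On a non-contiguous tensor, reshape must copy, so later in-place
# mutation of the source is not visible through the result.
x = torch.rand(4, 3).t()        # transposed -> non-contiguous
y = x.reshape(12)               # copies; x.view(12) would raise here
x.add_(1.0)                     # mutate the source in place
print(torch.equal(y, x.reshape(12)))  # False: y kept the old values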
@@ -784,6 +803,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<AtenExpandOp>();
     patterns.add<DecomposeAtenSizeOp>(context);
     target.addIllegalOp<AtenSizeOp>();
+    patterns.add<DecomposeAtenReshapeOp>(context);
+    target.addIllegalOp<AtenReshapeOp>();
     patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
     target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
     patterns.add<DecomposeAtenTanhBackwardOp>(context);
@@ -90,7 +90,7 @@ public:
     if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
       copyToValueTensorOps.push_back(copyToValueTensor);
     } else if (isa<AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp,
-                   AtenFlattenUsingIntsOp, AtenTransposeIntOp,
+                   AtenFlattenUsingIntsOp, AtenTransposeIntOp, AtenReshapeOp,
                    TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
                    AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
                    AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp,
@@ -365,9 +365,11 @@ public:
           maxDim, maxDim.dim(), maxDim.keepdim(),
           secondResDtype, operands, /*resNum=*/1);
     } else if (auto view = dyn_cast<AtenViewOp>(op)) {
-      return visitReshapeLikeOp(view, operands);
+      return visitReshapeLikeOp(view, operands, view.size());
+    } else if (auto reshape = dyn_cast<AtenReshapeOp>(op)) {
+      return visitReshapeLikeOp(reshape, operands, reshape.shape());
     } else if (auto resize = dyn_cast<AtenResize_Op>(op)) {
-      return visitReshapeLikeOp(resize, operands);
+      return visitReshapeLikeOp(resize, operands, resize.size());
     } else if (auto transposeInt = dyn_cast<AtenTransposeIntOp>(op)) {
       return visitAtenTransposeIntOp(transposeInt, operands);
     } else if (auto t = dyn_cast<AtenTOp>(op)) {
@@ -567,7 +569,8 @@ private:
   template <typename OpTy>
   ChangeResult
   visitReshapeLikeOp(OpTy op,
-                     ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+                     ArrayRef<LatticeElement<ValueKnowledge> *> operands,
+                     Value sizeList);
   ChangeResult
   visitAtenTransposeIntOp(AtenTransposeIntOp op,
                           ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -1375,13 +1378,14 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
 // result tensor.
 template <typename OpTy>
 ChangeResult TypeAnalyzer::visitReshapeLikeOp(
-    OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+    OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands,
+    Value sizeList) {
   auto input = operands[0]->getValue();
   auto knowledge =
       ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
   knowledge.dtype = input.dtype;
 
-  fillInSizesGivenSizesList(knowledge, op.size());
+  fillInSizesGivenSizesList(knowledge, sizeList);
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
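This refactor threads each op's own size operand (`view.size()`, `reshape.shape()`, `resize.size()`) into the shared `visitReshapeLikeOp` instead of hard-coding `op.size()`. A hypothetical Python sketch of the kind of size-list-driven shape inference this feeds into `fillInSizesGivenSizesList` (names and helper invented for illustration; not the C++ code):

# Hypothetical sketch: statically known, non-negative entries in the size
# list become known result dims; anything else (including values that are
# not compile-time constants) stays unknown at this stage.
kUnknownSize = -1

def infer_result_sizes(size_list):
    return [s if isinstance(s, int) and s >= 0 else kUnknownSize
            for s in size_list]

print(infer_result_sizes([12, 32]))    # [12, 32]: fully static shape
print(infer_result_sizes([None, 32]))  # [-1, 32]: dynamic leading dim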