mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add E2E support for `aten.reshape` op
This commit decomposes `aten.reshape` into `aten.view` op in the case of value tensor type operand. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/463/head snapshot-20220202.244
parent
1b505cbac5
commit
0079901039
|
@ -39,7 +39,7 @@ from . import backprop
|
||||||
from . import reduction
|
from . import reduction
|
||||||
from . import argmax
|
from . import argmax
|
||||||
from . import matmul
|
from . import matmul
|
||||||
from . import view
|
from . import reshape_like
|
||||||
from . import scalar
|
from . import scalar
|
||||||
from . import squeeze
|
from . import squeeze
|
||||||
from . import slice_like
|
from . import slice_like
|
||||||
|
|
|
@ -123,3 +123,41 @@ class View1DFoldModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: View1DFoldModule())
|
@register_test_case(module_factory=lambda: View1DFoldModule())
|
||||||
def View1DFoldModule_basic(module, tu: TestUtils):
|
def View1DFoldModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(32))
|
module.forward(tu.rand(32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ReshapeExpandModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return a.reshape(12, 32)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReshapeExpandModule())
|
||||||
|
def ReshapeExpandModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(384))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ReshapeCollapseModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.reshape(a, (-1,))
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReshapeCollapseModule())
|
||||||
|
def ReshapeCollapseModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 4))
|
|
@ -146,6 +146,25 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenReshapeOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Value input = op.self();
|
||||||
|
// TODO: Handle non value tensor type operands.
|
||||||
|
if (!input.getType().isa<ValueTensorType>()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: only value tensor type operands are supported");
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), input,
|
||||||
|
op.shape());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
|
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
|
||||||
// exp(x)/sum(exp(x)).
|
// exp(x)/sum(exp(x)).
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
|
@ -784,6 +803,8 @@ class DecomposeComplexOpsPass
|
||||||
target.addIllegalOp<AtenExpandOp>();
|
target.addIllegalOp<AtenExpandOp>();
|
||||||
patterns.add<DecomposeAtenSizeOp>(context);
|
patterns.add<DecomposeAtenSizeOp>(context);
|
||||||
target.addIllegalOp<AtenSizeOp>();
|
target.addIllegalOp<AtenSizeOp>();
|
||||||
|
patterns.add<DecomposeAtenReshapeOp>(context);
|
||||||
|
target.addIllegalOp<AtenReshapeOp>();
|
||||||
patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
|
patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
|
||||||
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
|
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
|
||||||
patterns.add<DecomposeAtenTanhBackwardOp>(context);
|
patterns.add<DecomposeAtenTanhBackwardOp>(context);
|
||||||
|
|
|
@ -90,7 +90,7 @@ public:
|
||||||
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
|
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
|
||||||
copyToValueTensorOps.push_back(copyToValueTensor);
|
copyToValueTensorOps.push_back(copyToValueTensor);
|
||||||
} else if (isa<AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp,
|
} else if (isa<AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp,
|
||||||
AtenFlattenUsingIntsOp, AtenTransposeIntOp,
|
AtenFlattenUsingIntsOp, AtenTransposeIntOp, AtenReshapeOp,
|
||||||
TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
|
TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
|
||||||
AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
|
AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
|
||||||
AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp,
|
AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp,
|
||||||
|
|
|
@ -365,9 +365,11 @@ public:
|
||||||
maxDim, maxDim.dim(), maxDim.keepdim(),
|
maxDim, maxDim.dim(), maxDim.keepdim(),
|
||||||
secondResDtype, operands, /*resNum=*/1);
|
secondResDtype, operands, /*resNum=*/1);
|
||||||
} else if (auto view = dyn_cast<AtenViewOp>(op)) {
|
} else if (auto view = dyn_cast<AtenViewOp>(op)) {
|
||||||
return visitReshapeLikeOp(view, operands);
|
return visitReshapeLikeOp(view, operands, view.size());
|
||||||
|
} else if (auto reshape = dyn_cast<AtenReshapeOp>(op)) {
|
||||||
|
return visitReshapeLikeOp(reshape, operands, reshape.shape());
|
||||||
} else if (auto resize = dyn_cast<AtenResize_Op>(op)) {
|
} else if (auto resize = dyn_cast<AtenResize_Op>(op)) {
|
||||||
return visitReshapeLikeOp(resize, operands);
|
return visitReshapeLikeOp(resize, operands, resize.size());
|
||||||
} else if (auto transposeInt = dyn_cast<AtenTransposeIntOp>(op)) {
|
} else if (auto transposeInt = dyn_cast<AtenTransposeIntOp>(op)) {
|
||||||
return visitAtenTransposeIntOp(transposeInt, operands);
|
return visitAtenTransposeIntOp(transposeInt, operands);
|
||||||
} else if (auto t = dyn_cast<AtenTOp>(op)) {
|
} else if (auto t = dyn_cast<AtenTOp>(op)) {
|
||||||
|
@ -567,7 +569,8 @@ private:
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
ChangeResult
|
ChangeResult
|
||||||
visitReshapeLikeOp(OpTy op,
|
visitReshapeLikeOp(OpTy op,
|
||||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
ArrayRef<LatticeElement<ValueKnowledge> *> operands,
|
||||||
|
Value sizeList);
|
||||||
ChangeResult
|
ChangeResult
|
||||||
visitAtenTransposeIntOp(AtenTransposeIntOp op,
|
visitAtenTransposeIntOp(AtenTransposeIntOp op,
|
||||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||||
|
@ -1375,13 +1378,14 @@ ChangeResult TypeAnalyzer::visitReductionAlongDimIntOp(
|
||||||
// result tensor.
|
// result tensor.
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
ChangeResult TypeAnalyzer::visitReshapeLikeOp(
|
ChangeResult TypeAnalyzer::visitReshapeLikeOp(
|
||||||
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands,
|
||||||
|
Value sizeList) {
|
||||||
auto input = operands[0]->getValue();
|
auto input = operands[0]->getValue();
|
||||||
auto knowledge =
|
auto knowledge =
|
||||||
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||||
knowledge.dtype = input.dtype;
|
knowledge.dtype = input.dtype;
|
||||||
|
|
||||||
fillInSizesGivenSizesList(knowledge, op.size());
|
fillInSizesGivenSizesList(knowledge, sizeList);
|
||||||
return getLatticeElement(op.getResult()).join(knowledge);
|
return getLatticeElement(op.getResult()).join(knowledge);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue