[TORCH][MLIR] Add E2E support for `aten.squeeze.dim` op

This commit adds lowering of the `aten.squeeze.dim` op to the
`linalg.TensorCollapseShape` op. The lowering requires the `dim`-th
dimension of the input tensor to be statically known; handling a dynamic
`dim`-th dimension is left as a TODO.
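
For reference, a minimal PyTorch-level sketch of the `aten.squeeze.dim`
semantics that the lowering has to match (the shapes below are illustrative
only, not taken from this change):

    import torch

    a = torch.rand(8, 1, 384, 12, 1)
    # Dim 4 has size 1, so it is removed.
    print(torch.squeeze(a, 4).shape)              # torch.Size([8, 1, 384, 12])
    # Dim 0 has size 8 (not 1), so the op is an identity.
    print(torch.squeeze(a, 0).shape)              # torch.Size([8, 1, 384, 12, 1])
    # Squeezing the only unit dim of a [1] tensor yields a 0-d tensor.
    print(torch.squeeze(torch.rand(1), 0).shape)  # torch.Size([])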

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/473/head
Gaurav Shukla 2021-11-30 20:20:55 +05:30 committed by Gaurav Shukla
parent 8130354c09
commit 5a47f92390
8 changed files with 258 additions and 5 deletions


@@ -119,3 +119,113 @@ class SqueezeBroadcastModule(torch.nn.Module):
def SqueezeModule_broadcast(module, tu: TestUtils):
    module.forward(tu.rand(4, 3), tu.rand())

# ==============================================================================
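
# The squeezed dim (dim 0) is statically known to be 1, so the [1, 7] input
# collapses to shape [7].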
class SqueezeDimStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 7], torch.float32, True),
    ])
    def forward(self, a):
        return torch.squeeze(a, 0)


@register_test_case(
    module_factory=lambda: SqueezeDimStaticModule())
def SqueezeDimModule_static(module, tu: TestUtils):
    module.forward(tu.rand(1, 7))

# ==============================================================================
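
# Dim 4 is statically known to be 1, so it can be squeezed even though the
# other non-unit dims are dynamic.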
class SqueezeDimDynamicModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, 1, 384, -1, 1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.squeeze(a, 4)


@register_test_case(
    module_factory=lambda: SqueezeDimDynamicModule())
def SqueezeDimModule_dynamic(module, tu: TestUtils):
    module.forward(tu.rand(8, 1, 384, 12, 1))

# ==============================================================================
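
# A negative dim is resolved relative to the input rank: dim -6 on this
# rank-6 input refers to dim 0, which has unit size and gets squeezed.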
class SqueezeDimNegDimModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, -1, 1, 384, -1, 1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.squeeze(a, -6)


@register_test_case(
    module_factory=lambda: SqueezeDimNegDimModule())
def SqueezeDimModule_negDim(module, tu: TestUtils):
    module.forward(tu.rand(1, 8, 1, 384, 12, 1))

# ==============================================================================
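
# Dim 0 has size 4 rather than 1, so `squeeze.dim` behaves as an identity
# operation here.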
class SqueezeDimIdentityModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.squeeze(a, 0)


@register_test_case(
    module_factory=lambda: SqueezeDimIdentityModule())
def SqueezeDimModule_identity(module, tu: TestUtils):
    module.forward(tu.rand(4, 1, 3))

# ==============================================================================
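
# Squeezing the single unit dim of a rank-1 tensor of shape [1] collapses it
# to a 0-d tensor.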
class SqueezeDimUnitDimModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.squeeze(a, 0)


@register_test_case(
    module_factory=lambda: SqueezeDimUnitDimModule())
def SqueezeDimModule_unitDim(module, tu: TestUtils):
    module.forward(tu.rand(1))


@@ -1608,6 +1608,21 @@ def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
  let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` type($self) `,` type($target) `,` type($weight) `,` type($reduction) `,` type($ignore_index) `->` type($output) `,` type($total_weight)";
}

def Torch_AtenSqueezeDimOp : Torch_Op<"aten.squeeze.dim", [
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::squeeze.dim : (Tensor, int) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    Torch_IntType:$dim
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $dim attr-dict `:` type($self) `,` type($dim) `->` type($result)";
  let hasFolder = 1;
}

def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
    AllowsTypeRefinement
  ]> {


@@ -2667,6 +2667,76 @@ public:
};
} // namespace

namespace {
class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenSqueezeDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Value input = adaptor.self();
    auto inputType = input.getType().cast<RankedTensorType>();
    int64_t inputRank = inputType.getRank();

    if (inputRank == 0) {
      return rewriter.notifyMatchFailure(
          op, "zero input rank should have been handled by the folder");
    }

    int64_t dim;
    if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
      return rewriter.notifyMatchFailure(op, "dim must be constant");
    dim = toPositiveDim(dim, inputRank);
    if (!isValidDim(dim, inputRank))
      return rewriter.notifyMatchFailure(op, "dim is statically invalid");

    // TODO: Handle the case where the dim(th) dimension is dynamic.
    if (inputType.isDynamicDim(dim)) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: dim(th) dimension is not expected to be dynamic");
    }

    TypeConverter *typeConverter = getTypeConverter();
    auto resultType =
        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
    int64_t resultRank = resultType.getRank();

    // If the dim(th) dimension of operand tensor type is not statically unit,
    // `aten.squeeze` will behave as an identity operation.
    if (inputType.getDimSize(dim) != 1) {
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
      return success();
    }
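
    // Build the reassociation map that merges the squeezed unit dim into an
    // adjacent dim. For example, squeezing dim 4 of a rank-5 input gives
    // [[0], [1], [2], [3, 4]], and squeezing dim 0 of a rank-2 input gives
    // [[0, 1]].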
    SmallVector<ReassociationIndices> reassociationMap(resultRank);
    bool alreadyCrossedSqueezedDim = false;
    for (int i = 0; i != resultRank; i++) {
      if (alreadyCrossedSqueezedDim) {
        reassociationMap[i].push_back(i + 1);
      } else {
        reassociationMap[i].push_back(i);
        if (dim != 0 && i != dim - 1)
          continue;

        alreadyCrossedSqueezedDim = true;
        if (dim == 0)
          reassociationMap[0].push_back(1);
        if (i == dim - 1)
          reassociationMap[i].push_back(dim);
      }
    }
    // Note: In case the operand tensor type is of unit rank and is statically
    // shaped with unit dimension, the `reassociationMap` will be empty and the
    // input will be collapsed to a 0-D tensor.
    rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
        op, resultType, input, reassociationMap);
    return success();
  }
};
} // namespace

namespace {
class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
public:

@@ -3565,6 +3635,8 @@ public:
    patterns.add<ConvertElementwiseOp>(typeConverter, context);
    target.addIllegalOp<AtenSqueezeOp>();
    patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
    target.addIllegalOp<AtenSqueezeDimOp>();
    patterns.add<ConvertAtenSqueezeDimOp>(typeConverter, context);
    target.addIllegalOp<AtenUnsqueezeOp>();
    patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
    target.addIllegalOp<AtenConv2dOp>();


@@ -462,6 +462,18 @@ OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenSqueezeDimOp
//===----------------------------------------------------------------------===//

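// Squeezing any valid dim of a zero-rank tensor is a no-op, so the op folds
// to its input operand.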
OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
  if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
    if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
      return getOperand(0);
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenDimOp
//===----------------------------------------------------------------------===//


@@ -89,11 +89,12 @@ public:
      Operation *op = workList.pop_back_val();
      if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
        copyToValueTensorOps.push_back(copyToValueTensor);
      } else if (isa<AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp,
                     AtenFlattenUsingIntsOp, AtenTransposeIntOp,
                     TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
                     AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
                     AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp>(
                     op)) {
        // AtenContiguousOp might return a view, so this is conservatively
        // correct. We could potentially be more precise and identify the cases
        // that it does not return a view and treat those as having value


@@ -328,6 +328,8 @@ public:
      return visitAtenFlattenUsingIntsOp(flatten, operands);
    } else if (auto squeeze = dyn_cast<AtenSqueezeOp>(op)) {
      return visitAtenSqueezeOp(squeeze, operands);
    } else if (auto squeezeDim = dyn_cast<AtenSqueezeDimOp>(op)) {
      return visitAtenSqueezeDimOp(squeezeDim, operands);
    } else if (auto unsqueeze = dyn_cast<AtenUnsqueezeOp>(op)) {
      return visitAtenUnsqueezeOp(unsqueeze, operands);
    } else if (auto arange = dyn_cast<AtenArangeOp>(op)) {

@@ -514,6 +516,9 @@
  visitAtenSqueezeOp(AtenSqueezeOp op,
                     ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
  visitAtenSqueezeDimOp(AtenSqueezeDimOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
  visitAtenUnsqueezeOp(AtenUnsqueezeOp op,
                       ArrayRef<LatticeElement<ValueKnowledge> *> operands);

@@ -1002,6 +1007,34 @@
  return resultLattice;
}

ChangeResult TypeAnalyzer::visitAtenSqueezeDimOp(
    AtenSqueezeDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto operand = operands[0]->getValue();
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
  knowledge.dtype = operand.dtype;
  int64_t dim;
  if (operand.hasSizes && matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
    int64_t inputRank = operand.sizes.size();
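    // A zero-rank input only accepts dim 0 or dim -1, and squeezing it is a
    // no-op that keeps the empty shape.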
    if (inputRank == 0) {
      if (dim == -1 || dim == 0) {
        knowledge.hasSizes = true;
      }
      return getLatticeElement(op.getResult()).join(knowledge);
    }
    // The dim value is allowed to be in the range `[-inputRank, inputRank)`.
    if (dim < 0)
      dim += inputRank;
    if (0 <= dim && dim < inputRank && operand.sizes[dim] != kUnknownSize) {
      knowledge.hasSizes = true;
      knowledge.sizes = operand.sizes;
      if (operand.sizes[dim] == 1)
        knowledge.sizes.erase(knowledge.sizes.begin() + dim);
    }
  }
  return getLatticeElement(op.getResult()).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
    AtenUnsqueezeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto operand = operands[0]->getValue();


@@ -527,6 +527,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")

    # Misc tensor ops.
    emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
    emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
    emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
    emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")


@@ -613,3 +613,12 @@ func @torch.aten.squeeze$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tenso
  %0 = torch.aten.squeeze %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
  return %0 : !torch.tensor<[],f32>
}

// CHECK-LABEL: func @torch.aten.squeeze.dim$zero_rank(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>
func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
  %int0 = torch.constant.int 0
  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor<[],f32>
  return %0 : !torch.tensor<[],f32>
}