mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add E2E support for `aten.squeeze.dim` op
This commit adds lowering of the `aten.squeeze.dim` op into the `linalg.TensorCollapseShape` op. Here, the dim-th dimension of the input tensor is not allowed to be dynamic.
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
parent 8130354c09
commit 5a47f92390
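For context, `aten.squeeze.dim` carries the semantics of `torch.squeeze(input, dim)`: the given dimension is removed only when it is statically of size 1, and the op otherwise acts as an identity. A minimal plain-PyTorch sketch of the behavior exercised by the new e2e tests below (the tensors are only illustrative):

import torch

a = torch.rand(1, 7)
assert torch.squeeze(a, 0).shape == (7,)       # unit dim is removed
b = torch.rand(4, 1, 3)
assert torch.squeeze(b, 0).shape == (4, 1, 3)  # non-unit dim: identity
c = torch.rand(1)
assert torch.squeeze(c, 0).shape == ()         # rank-1 unit tensor becomes 0-D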
@@ -119,3 +119,113 @@ class SqueezeBroadcastModule(torch.nn.Module):
 def SqueezeModule_broadcast(module, tu: TestUtils):
     module.forward(tu.rand(4, 3), tu.rand())
 
+
+# ==============================================================================
+
+
+class SqueezeDimStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 7], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a, 0)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeDimStaticModule())
+def SqueezeDimModule_static(module, tu: TestUtils):
+    module.forward(tu.rand(1, 7))
+
+
+# ==============================================================================
+
+
+class SqueezeDimDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, 1, 384, -1, 1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a, 4)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeDimDynamicModule())
+def SqueezeDimModule_dynamic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 1, 384, 12, 1))
+
+
+# ==============================================================================
+
+
+class SqueezeDimNegDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, -1, 1, 384, -1, 1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a, -6)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeDimNegDimModule())
+def SqueezeDimModule_negDim(module, tu: TestUtils):
+    module.forward(tu.rand(1, 8, 1, 384, 12, 1))
+
+
+# ==============================================================================
+
+
+class SqueezeDimIdentityModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a, 0)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeDimIdentityModule())
+def SqueezeDimModule_identity(module, tu: TestUtils):
+    module.forward(tu.rand(4, 1, 3))
+
+
+# ==============================================================================
+
+
+class SqueezeDimUnitDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a, 0)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeDimUnitDimModule())
+def SqueezeDimModule_unitDim(module, tu: TestUtils):
+    module.forward(tu.rand(1))
@@ -1608,6 +1608,21 @@ def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
   let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` type($self) `,` type($target) `,` type($weight) `,` type($reduction) `,` type($ignore_index) `->` type($output) `,` type($total_weight)";
 }
 
+def Torch_AtenSqueezeDimOp : Torch_Op<"aten.squeeze.dim", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::squeeze.dim : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $dim attr-dict `:` type($self) `,` type($dim) `->` type($result)";
+  let hasFolder = 1;
+}
+
 def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
     AllowsTypeRefinement
   ]> {
@@ -2667,6 +2667,76 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenSqueezeDimOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Value input = adaptor.self();
+    auto inputType = input.getType().cast<RankedTensorType>();
+    int64_t inputRank = inputType.getRank();
+
+    if (inputRank == 0) {
+      return rewriter.notifyMatchFailure(
+          op, "zero input rank should have been handled by the folder");
+    }
+
+    int64_t dim;
+    if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(op, "dim must be constant");
+    dim = toPositiveDim(dim, inputRank);
+    if (!isValidDim(dim, inputRank))
+      return rewriter.notifyMatchFailure(op, "dim is statically invalid");
+
+    // TODO: Handle the case where the dim(th) dimension is dynamic.
+    if (inputType.isDynamicDim(dim)) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: dim(th) dimension is not expected to be dynamic");
+    }
+
+    TypeConverter *typeConverter = getTypeConverter();
+    auto resultType =
+        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+    int64_t resultRank = resultType.getRank();
+
+    // If the dim(th) dimension of operand tensor type is not statically unit,
+    // `aten.squeeze` will behave as an identity operation.
+    if (inputType.getDimSize(dim) != 1) {
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
+      return success();
+    }
+
+    SmallVector<ReassociationIndices> reassociationMap(resultRank);
+    bool alreadyCrossedSqueezedDim = false;
+    for (int i = 0; i != resultRank; i++) {
+      if (alreadyCrossedSqueezedDim) {
+        reassociationMap[i].push_back(i + 1);
+      } else {
+        reassociationMap[i].push_back(i);
+        if (dim != 0 && i != dim - 1)
+          continue;
+
+        alreadyCrossedSqueezedDim = true;
+        if (dim == 0)
+          reassociationMap[0].push_back(1);
+        if (i == dim - 1)
+          reassociationMap[i].push_back(dim);
+      }
+    }
+    // Note: In case the operand tensor type is of unit rank and is statically
+    // shaped with unit dimension, the `reassociationMap` will be empty and the
+    // input will be collapsed to a 0-D tensor.
+    rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
+        op, resultType, input, reassociationMap);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
 public:
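To make the reassociation logic in ConvertAtenSqueezeDimOp easier to follow, here is a hypothetical Python sketch of the same loop (the helper name and the printed cases are illustrative, not part of the patch): every result dimension keeps its own input dimension, and the squeezed unit dimension is folded into the group of result dimension 0 when dim == 0, or into the group of result dimension dim - 1 otherwise.

def squeeze_dim_reassociation(result_rank, dim):
    # Mirrors the loop above: result dim i -> list of input dims collapsed into it.
    groups = [[] for _ in range(result_rank)]
    crossed = False
    for i in range(result_rank):
        if crossed:
            groups[i].append(i + 1)
        else:
            groups[i].append(i)
            if dim != 0 and i != dim - 1:
                continue
            crossed = True
            if dim == 0:
                groups[0].append(1)
            if i == dim - 1:
                groups[i].append(dim)
    return groups

print(squeeze_dim_reassociation(1, 0))  # squeeze dim 0 of [1, 7]         -> [[0, 1]]
print(squeeze_dim_reassociation(4, 4))  # squeeze dim 4 of a rank-5 input -> [[0], [1], [2], [3, 4]]
print(squeeze_dim_reassociation(0, 0))  # unit-rank [1] input             -> [] (empty map, 0-D result)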
@@ -3565,6 +3635,8 @@ public:
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
     target.addIllegalOp<AtenSqueezeOp>();
     patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
+    target.addIllegalOp<AtenSqueezeDimOp>();
+    patterns.add<ConvertAtenSqueezeDimOp>(typeConverter, context);
     target.addIllegalOp<AtenUnsqueezeOp>();
     patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
     target.addIllegalOp<AtenConv2dOp>();
@@ -462,6 +462,18 @@ OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenSqueezeDimOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
+  if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
+    if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
+      return getOperand(0);
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenDimOp
 //===----------------------------------------------------------------------===//
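The folder above only fires for operands whose type is statically known to have rank zero; on a 0-D tensor, `aten.squeeze.dim` is trivially the identity, which the zero-rank canonicalization test at the end of this diff checks at the IR level. A plain-PyTorch illustration of that case (the value is arbitrary):

import torch

x = torch.tensor(3.14)   # 0-D tensor; only dim 0 or -1 is a valid squeeze dim
y = torch.squeeze(x, 0)  # no-op
assert y.shape == () and y.item() == x.item()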
@@ -89,11 +89,12 @@ public:
       Operation *op = workList.pop_back_val();
       if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
         copyToValueTensorOps.push_back(copyToValueTensor);
-      } else if (isa<AtenSqueezeOp, AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
-                     AtenTransposeIntOp, TensorStaticInfoCastOp,
-                     AtenBroadcastToOp, AtenToDtypeOp, AtenContiguousOp,
-                     AtenPermuteOp, AtenViewOp, AtenExpandOp, AtenFill_ScalarOp,
-                     AtenSliceTensorOp, AtenSelectIntOp>(op)) {
+      } else if (isa<AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp,
+                     AtenFlattenUsingIntsOp, AtenTransposeIntOp,
+                     TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
+                     AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
+                     AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp>(
+                     op)) {
         // AtenContiguousOp might return a view, so this is conservatively
         // correct. We could potentially be more precise and identify the cases
         // that it does not return a view and treat those as having value
@@ -328,6 +328,8 @@ public:
       return visitAtenFlattenUsingIntsOp(flatten, operands);
     } else if (auto squeeze = dyn_cast<AtenSqueezeOp>(op)) {
      return visitAtenSqueezeOp(squeeze, operands);
+    } else if (auto squeezeDim = dyn_cast<AtenSqueezeDimOp>(op)) {
+      return visitAtenSqueezeDimOp(squeezeDim, operands);
     } else if (auto unsqueeze = dyn_cast<AtenUnsqueezeOp>(op)) {
       return visitAtenUnsqueezeOp(unsqueeze, operands);
     } else if (auto arange = dyn_cast<AtenArangeOp>(op)) {
@@ -514,6 +516,9 @@ private:
   visitAtenSqueezeOp(AtenSqueezeOp op,
                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult
+  visitAtenSqueezeDimOp(AtenSqueezeDimOp op,
+                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
   visitAtenUnsqueezeOp(AtenUnsqueezeOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -1002,6 +1007,34 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
   return resultLattice;
 }
 
+ChangeResult TypeAnalyzer::visitAtenSqueezeDimOp(
+    AtenSqueezeDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto operand = operands[0]->getValue();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+  knowledge.dtype = operand.dtype;
+  int64_t dim;
+  if (operand.hasSizes && matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
+    int64_t inputRank = operand.sizes.size();
+    if (inputRank == 0) {
+      if (dim == -1 || dim == 0) {
+        knowledge.hasSizes = true;
+      }
+      return getLatticeElement(op.getResult()).join(knowledge);
+    }
+    // The dim value is allowed to be in the range `[-inputRank, inputRank)`.
+    if (dim < 0)
+      dim += inputRank;
+    if (0 <= dim && dim < inputRank && operand.sizes[dim] != kUnknownSize) {
+      knowledge.hasSizes = true;
+      knowledge.sizes = operand.sizes;
+      if (operand.sizes[dim] == 1)
+        knowledge.sizes.erase(knowledge.sizes.begin() + dim);
+    }
+  }
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
     AtenUnsqueezeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto operand = operands[0]->getValue();
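A small Python sketch of the shape-transfer rule that visitAtenSqueezeDimOp implements (the function name and the use of -1 for an unknown size are illustrative): the result shape is inferred only when the operand shape is known, dim is a constant, and the selected dimension is static; a unit dimension is dropped and any other static dimension leaves the shape unchanged.

kUnknownSize = -1

def squeeze_dim_result_sizes(sizes, dim):
    # Returns the refined result sizes, or None when they cannot be inferred.
    rank = len(sizes)
    if rank == 0:
        # On a 0-D tensor only dim 0 or -1 is meaningful; the shape stays [].
        return [] if dim in (0, -1) else None
    if dim < 0:
        dim += rank
    if not (0 <= dim < rank) or sizes[dim] == kUnknownSize:
        return None
    if sizes[dim] == 1:
        return sizes[:dim] + sizes[dim + 1:]
    return list(sizes)

print(squeeze_dim_result_sizes([-1, 1, 384, -1, 1], 4))      # [-1, 1, 384, -1]
print(squeeze_dim_result_sizes([1, -1, 1, 384, -1, 1], -6))  # [-1, 1, 384, -1, 1]
print(squeeze_dim_result_sizes([4, 1, -1], 0))               # [4, 1, -1] (identity)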
@@ -527,6 +527,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
 
     # Misc tensor ops.
+    emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
     emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
     emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
     emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
@@ -613,3 +613,12 @@ func @torch.aten.squeeze$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
   %0 = torch.aten.squeeze %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
   return %0 : !torch.tensor<[],f32>
 }
+
+// CHECK-LABEL:   func @torch.aten.squeeze.dim$zero_rank(
+// CHECK-SAME:        %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
+// CHECK-NEXT:      return %[[ARG]] : !torch.tensor<[],f32>
+func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor<[],f32>
+  return %0 : !torch.tensor<[],f32>
+}