mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add E2E support for `aten.squeeze.dim` op
This commit adds lowering of `aten.squeeze.dim` op into `linalg.TensorCollapseShape` op. Here, the dim(th) dimension of the input tensor is not supposed to be dynamic. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/473/head
parent
8130354c09
commit
5a47f92390
|
@ -119,3 +119,113 @@ class SqueezeBroadcastModule(torch.nn.Module):
|
|||
def SqueezeModule_broadcast(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 3), tu.rand())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SqueezeDimStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, 7], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a, 0)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimStaticModule())
|
||||
def SqueezeDimModule_static(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SqueezeDimDynamicModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, 1, 384, -1, 1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a, 4)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimDynamicModule())
|
||||
def SqueezeDimModule_dynamic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(8, 1, 384, 12, 1))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SqueezeDimNegDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, -1, 1, 384, -1, 1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a, -6)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimNegDimModule())
|
||||
def SqueezeDimModule_negDim(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 8, 1, 384, 12, 1))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SqueezeDimIdentityModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4, 1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a, 0)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimIdentityModule())
|
||||
def SqueezeDimModule_identity(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 1, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SqueezeDimUnitDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a, 0)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimUnitDimModule())
|
||||
def SqueezeDimModule_unitDim(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1))
|
||||
|
||||
|
|
|
@ -1608,6 +1608,21 @@ def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
|
|||
let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` type($self) `,` type($target) `,` type($weight) `,` type($reduction) `,` type($ignore_index) `->` type($output) `,` type($total_weight)";
|
||||
}
|
||||
|
||||
def Torch_AtenSqueezeDimOp : Torch_Op<"aten.squeeze.dim", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::squeeze.dim : (Tensor, int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $dim attr-dict `:` type($self) `,` type($dim) `->` type($result)";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
|
|
|
@ -2667,6 +2667,76 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenSqueezeDimOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Value input = adaptor.self();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
int64_t inputRank = inputType.getRank();
|
||||
|
||||
if (inputRank == 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "zero input rank should have been handled by the folder");
|
||||
}
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(op, "dim must be constant");
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
if (!isValidDim(dim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
// TODO: Handle the case where the dim(th) dimension is dynamic.
|
||||
if (inputType.isDynamicDim(dim)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: dim(th) dimension is not expected to be dynamic");
|
||||
}
|
||||
|
||||
TypeConverter *typeConverter = getTypeConverter();
|
||||
auto resultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
int64_t resultRank = resultType.getRank();
|
||||
|
||||
// If the dim(th) dimension of operand tensor type is not statically unit,
|
||||
// `aten.squeeze` will behave as an identity operation.
|
||||
if (inputType.getDimSize(dim) != 1) {
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<ReassociationIndices> reassociationMap(resultRank);
|
||||
bool alreadyCrossedSqueezedDim = false;
|
||||
for (int i = 0; i != resultRank; i++) {
|
||||
if (alreadyCrossedSqueezedDim) {
|
||||
reassociationMap[i].push_back(i + 1);
|
||||
} else {
|
||||
reassociationMap[i].push_back(i);
|
||||
if (dim != 0 && i != dim - 1)
|
||||
continue;
|
||||
|
||||
alreadyCrossedSqueezedDim = true;
|
||||
if (dim == 0)
|
||||
reassociationMap[0].push_back(1);
|
||||
if (i == dim - 1)
|
||||
reassociationMap[i].push_back(dim);
|
||||
}
|
||||
}
|
||||
// Note: In case the operand tensor type is of unit rank and is statically
|
||||
// shaped with unit dimension, the `reassociationMap` will be empty and the
|
||||
// input will be collapsed to a 0-D tensor.
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
|
||||
op, resultType, input, reassociationMap);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
|
||||
public:
|
||||
|
@ -3565,6 +3635,8 @@ public:
|
|||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSqueezeOp>();
|
||||
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSqueezeDimOp>();
|
||||
patterns.add<ConvertAtenSqueezeDimOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenUnsqueezeOp>();
|
||||
patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenConv2dOp>();
|
||||
|
|
|
@ -462,6 +462,18 @@ OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSqueezeDimOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
|
||||
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
|
||||
return getOperand(0);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenDimOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -89,11 +89,12 @@ public:
|
|||
Operation *op = workList.pop_back_val();
|
||||
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
|
||||
copyToValueTensorOps.push_back(copyToValueTensor);
|
||||
} else if (isa<AtenSqueezeOp, AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
|
||||
AtenTransposeIntOp, TensorStaticInfoCastOp,
|
||||
AtenBroadcastToOp, AtenToDtypeOp, AtenContiguousOp,
|
||||
AtenPermuteOp, AtenViewOp, AtenExpandOp, AtenFill_ScalarOp,
|
||||
AtenSliceTensorOp, AtenSelectIntOp>(op)) {
|
||||
} else if (isa<AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp,
|
||||
AtenFlattenUsingIntsOp, AtenTransposeIntOp,
|
||||
TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
|
||||
AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
|
||||
AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp>(
|
||||
op)) {
|
||||
// AtenContiguousOp might return a view, so this is conservatively
|
||||
// correct. We could potentially be more precise and identify the cases
|
||||
// that it does not return a view and treat those as having value
|
||||
|
|
|
@ -328,6 +328,8 @@ public:
|
|||
return visitAtenFlattenUsingIntsOp(flatten, operands);
|
||||
} else if (auto squeeze = dyn_cast<AtenSqueezeOp>(op)) {
|
||||
return visitAtenSqueezeOp(squeeze, operands);
|
||||
} else if (auto squeezeDim = dyn_cast<AtenSqueezeDimOp>(op)) {
|
||||
return visitAtenSqueezeDimOp(squeezeDim, operands);
|
||||
} else if (auto unsqueeze = dyn_cast<AtenUnsqueezeOp>(op)) {
|
||||
return visitAtenUnsqueezeOp(unsqueeze, operands);
|
||||
} else if (auto arange = dyn_cast<AtenArangeOp>(op)) {
|
||||
|
@ -514,6 +516,9 @@ private:
|
|||
visitAtenSqueezeOp(AtenSqueezeOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ChangeResult
|
||||
visitAtenSqueezeDimOp(AtenSqueezeDimOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ChangeResult
|
||||
visitAtenUnsqueezeOp(AtenUnsqueezeOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
|
||||
|
@ -1002,6 +1007,34 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
|
|||
return resultLattice;
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAtenSqueezeDimOp(
|
||||
AtenSqueezeDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto operand = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||
knowledge.dtype = operand.dtype;
|
||||
int64_t dim;
|
||||
if (operand.hasSizes && matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
|
||||
int64_t inputRank = operand.sizes.size();
|
||||
if (inputRank == 0) {
|
||||
if (dim == -1 || dim == 0) {
|
||||
knowledge.hasSizes = true;
|
||||
}
|
||||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
// The dim value is allowed to be in the range `[-inputRank, inputRank)`.
|
||||
if (dim < 0)
|
||||
dim += inputRank;
|
||||
if (0 <= dim && dim < inputRank && operand.sizes[dim] != kUnknownSize) {
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes = operand.sizes;
|
||||
if (operand.sizes[dim] == 1)
|
||||
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
|
||||
}
|
||||
}
|
||||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
|
||||
AtenUnsqueezeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto operand = operands[0]->getValue();
|
||||
|
|
|
@ -527,6 +527,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
|
||||
|
||||
# Misc tensor ops.
|
||||
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
|
||||
emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
|
||||
|
|
|
@ -613,3 +613,12 @@ func @torch.aten.squeeze$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tenso
|
|||
%0 = torch.aten.squeeze %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
|
||||
return %0 : !torch.tensor<[],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.squeeze.dim$zero_rank(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>
|
||||
func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor<[],f32>
|
||||
return %0 : !torch.tensor<[],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue