diff --git a/e2e_testing/torchscript/squeeze.py b/e2e_testing/torchscript/squeeze.py
index 27da26557..62b3c7a4a 100644
--- a/e2e_testing/torchscript/squeeze.py
+++ b/e2e_testing/torchscript/squeeze.py
@@ -119,3 +119,113 @@ class SqueezeBroadcastModule(torch.nn.Module):
 def SqueezeModule_broadcast(module, tu: TestUtils):
     module.forward(tu.rand(4, 3), tu.rand())
+
+# ==============================================================================
+
+
+class SqueezeDimStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 7], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a, 0)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeDimStaticModule())
+def SqueezeDimModule_static(module, tu: TestUtils):
+    module.forward(tu.rand(1, 7))
+
+
+# ==============================================================================
+
+
+class SqueezeDimDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, 1, 384, -1, 1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a, 4)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeDimDynamicModule())
+def SqueezeDimModule_dynamic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 1, 384, 12, 1))
+
+
+# ==============================================================================
+
+
+class SqueezeDimNegDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, -1, 1, 384, -1, 1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a, -6)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeDimNegDimModule())
+def SqueezeDimModule_negDim(module, tu: TestUtils):
+    module.forward(tu.rand(1, 8, 1, 384, 12, 1))
+
+
+# ==============================================================================
+
+
+class SqueezeDimIdentityModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a, 0)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeDimIdentityModule())
+def SqueezeDimModule_identity(module, tu: TestUtils):
+    module.forward(tu.rand(4, 1, 3))
+
+
+# ==============================================================================
+
+
+class SqueezeDimUnitDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.squeeze(a, 0)
+
+
+@register_test_case(
+    module_factory=lambda: SqueezeDimUnitDimModule())
+def SqueezeDimModule_unitDim(module, tu: TestUtils):
+    module.forward(tu.rand(1))
+
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index 380d07e26..0feac941a 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1608,6 +1608,21 @@ def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
   let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` type($self) `,` type($target) `,` type($weight) `,` type($reduction) `,` type($ignore_index) `->` type($output) `,` type($total_weight)";
 }
 
+def Torch_AtenSqueezeDimOp : Torch_Op<"aten.squeeze.dim", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::squeeze.dim : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $dim attr-dict `:` type($self) `,` type($dim) `->` type($result)";
+  let hasFolder = 1;
+}
+
 def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
     AllowsTypeRefinement
   ]> {
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index b5ea1232f..4b81a9147 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -2667,6 +2667,76 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenSqueezeDimOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Value input = adaptor.self();
+    auto inputType = input.getType().cast<RankedTensorType>();
+    int64_t inputRank = inputType.getRank();
+
+    if (inputRank == 0) {
+      return rewriter.notifyMatchFailure(
+          op, "zero input rank should have been handled by the folder");
+    }
+
+    int64_t dim;
+    if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(op, "dim must be constant");
+    dim = toPositiveDim(dim, inputRank);
+    if (!isValidDim(dim, inputRank))
+      return rewriter.notifyMatchFailure(op, "dim is statically invalid");
+
+    // TODO: Handle the case where the dim(th) dimension is dynamic.
+    if (inputType.isDynamicDim(dim)) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: dim(th) dimension is not expected to be dynamic");
+    }
+
+    TypeConverter *typeConverter = getTypeConverter();
+    auto resultType =
+        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+    int64_t resultRank = resultType.getRank();
+
+    // If the dim(th) dimension of operand tensor type is not statically unit,
+    // `aten.squeeze` will behave as an identity operation.
+    if (inputType.getDimSize(dim) != 1) {
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
+      return success();
+    }
+
+    SmallVector<ReassociationIndices> reassociationMap(resultRank);
+    bool alreadyCrossedSqueezedDim = false;
+    for (int i = 0; i != resultRank; i++) {
+      if (alreadyCrossedSqueezedDim) {
+        reassociationMap[i].push_back(i + 1);
+      } else {
+        reassociationMap[i].push_back(i);
+        if (dim != 0 && i != dim - 1)
+          continue;
+
+        alreadyCrossedSqueezedDim = true;
+        if (dim == 0)
+          reassociationMap[0].push_back(1);
+        if (i == dim - 1)
+          reassociationMap[i].push_back(dim);
+      }
+    }
+    // Note: In case the operand tensor type is of unit rank and is statically
+    // shaped with unit dimension, the `reassociationMap` will be empty and the
+    // input will be collapsed to a 0-D tensor.
+    rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
+        op, resultType, input, reassociationMap);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
 public:
@@ -3565,6 +3635,8 @@ public:
     patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
     target.addIllegalOp<AtenSqueezeOp>();
     patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
+    target.addIllegalOp<AtenSqueezeDimOp>();
+    patterns.add<ConvertAtenSqueezeDimOp>(typeConverter, context);
     target.addIllegalOp<AtenUnsqueezeOp>();
     patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
     target.addIllegalOp<AtenConv2dOp>();
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index a49947385..76e9d8da0 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -462,6 +462,18 @@ OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenSqueezeDimOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
+  if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
+    if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
+      return getOperand(0);
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenDimOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
index be25318ee..8bcff35c2 100644
--- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
+++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -89,11 +89,12 @@ public:
       Operation *op = workList.pop_back_val();
       if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
         copyToValueTensorOps.push_back(copyToValueTensor);
-      } else if (isa<AtenSqueezeOp, AtenFlattenUsingIntsOp>(op)) {
+      } else if (isa<AtenSqueezeOp, AtenSqueezeDimOp, AtenFlattenUsingIntsOp>(
+                     op)) {
         // AtenContiguousOp might return a view, so this is conservatively
         // correct. We could potentially be more precise and identify the cases
         // that it does not return a view and treat those as having value
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 2d2b0e941..8491ff79a 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -328,6 +328,8 @@ public:
       return visitAtenFlattenUsingIntsOp(flatten, operands);
     } else if (auto squeeze = dyn_cast<AtenSqueezeOp>(op)) {
       return visitAtenSqueezeOp(squeeze, operands);
+    } else if (auto squeezeDim = dyn_cast<AtenSqueezeDimOp>(op)) {
+      return visitAtenSqueezeDimOp(squeezeDim, operands);
     } else if (auto unsqueeze = dyn_cast<AtenUnsqueezeOp>(op)) {
       return visitAtenUnsqueezeOp(unsqueeze, operands);
     } else if (auto arange = dyn_cast<AtenArangeOp>(op)) {
@@ -514,6 +516,9 @@ private:
   visitAtenSqueezeOp(AtenSqueezeOp op,
                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult
+  visitAtenSqueezeDimOp(AtenSqueezeDimOp op,
+                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
   visitAtenUnsqueezeOp(AtenUnsqueezeOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -1002,6 +1007,34 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
   return resultLattice;
 }
 
+ChangeResult TypeAnalyzer::visitAtenSqueezeDimOp(
+    AtenSqueezeDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto operand = operands[0]->getValue();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+  knowledge.dtype = operand.dtype;
+  int64_t dim;
+  if (operand.hasSizes && matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
+    int64_t inputRank = operand.sizes.size();
+    if (inputRank == 0) {
+      if (dim == -1 || dim == 0) {
+        knowledge.hasSizes = true;
+      }
+      return getLatticeElement(op.getResult()).join(knowledge);
+    }
+    // The dim value is allowed to be in the range `[-inputRank, inputRank)`.
+    if (dim < 0)
+      dim += inputRank;
+    if (0 <= dim && dim < inputRank && operand.sizes[dim] != kUnknownSize) {
+      knowledge.hasSizes = true;
+      knowledge.sizes = operand.sizes;
+      if (operand.sizes[dim] == 1)
+        knowledge.sizes.erase(knowledge.sizes.begin() + dim);
+    }
+  }
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
     AtenUnsqueezeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto operand = operands[0]->getValue();
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 1db64c282..98ba5c72b 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -527,6 +527,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
 
     # Misc tensor ops.
+ emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::unsqueeze : (Tensor, int) -> (Tensor)") emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)") diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 353238240..dd426e20c 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -613,3 +613,12 @@ func @torch.aten.squeeze$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tenso %0 = torch.aten.squeeze %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32> return %0 : !torch.tensor<[],f32> } + +// CHECK-LABEL: func @torch.aten.squeeze.dim$zero_rank( +// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> { +// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32> +func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor<[],f32> + return %0 : !torch.tensor<[],f32> +}