From a52aded0b90918fec8f6367726524a39b0aacfc3 Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Thu, 2 Dec 2021 22:09:21 -0600 Subject: [PATCH] Add lowering for slice and selectInt (#398) --- e2e_testing/torchscript/basic.py | 45 +++- e2e_testing/torchscript/main.py | 1 + e2e_testing/torchscript/slice_like.py | 227 ++++++++++++++++++ e2e_testing/torchscript/xfail_sets.py | 9 +- .../TorchToLinalg/TorchToLinalg.cpp | 103 ++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 22 ++ .../Transforms/MaximizeValueSemantics.cpp | 4 +- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 16 +- test/Dialect/Torch/refine-types.mlir | 4 +- 9 files changed, 416 insertions(+), 15 deletions(-) create mode 100644 e2e_testing/torchscript/slice_like.py diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py index 00d367ddd..ea43fc4c4 100644 --- a/e2e_testing/torchscript/basic.py +++ b/e2e_testing/torchscript/basic.py @@ -100,6 +100,8 @@ class MmTanhModule(torch.nn.Module): def matmul(self, lhs, rhs): return torch.mm(lhs, rhs) +# ============================================================================== + @register_test_case(module_factory=lambda: MmTanhModule()) def MmTanhModule_basic(module, tu: TestUtils): @@ -192,6 +194,8 @@ class AdaptiveAvgPool2dModule(torch.nn.Module): def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 3, 8, 9)) +# ============================================================================== + class FlattenStaticModule(torch.nn.Module): def __init__(self): @@ -211,6 +215,8 @@ class FlattenStaticModule(torch.nn.Module): def FlattenStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 3, 8, 9, 3, 4)) +# ============================================================================== + class FlattenRank0Module(torch.nn.Module): def __init__(self): @@ -230,6 +236,8 @@ class FlattenRank0Module(torch.nn.Module): def FlattenRank0Module_basic(module, tu: TestUtils): module.forward(torch.tensor(4.0)) +# ============================================================================== + class FlattenDynamicModule(torch.nn.Module): def __init__(self): @@ -249,6 +257,8 @@ class FlattenDynamicModule(torch.nn.Module): def FlattenDynamicModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 3, 8, 9, 3, 4)) +# ============================================================================== + class MaxPool2dModule(torch.nn.Module): def __init__(self): @@ -266,6 +276,8 @@ class MaxPool2dModule(torch.nn.Module): def forward(self, x): return self.mp2d(x) +# ============================================================================== + @register_test_case(module_factory=lambda: MaxPool2dModule()) def MaxPool2dModule_basic(module, tu: TestUtils): @@ -284,6 +296,8 @@ class TransposeIntModule(torch.nn.Module): def forward(self, x): return torch.transpose(x, 0, 1) +# ============================================================================== + @register_test_case(module_factory=lambda: TransposeIntModule()) def TransposeIntModule_basic(module, tu: TestUtils): @@ -305,6 +319,8 @@ class PermuteModule(torch.nn.Module): def PermuteModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 2)) +# ============================================================================== + class TransposeIntNegDimsModule(torch.nn.Module): def __init__(self): super().__init__() @@ -317,6 +333,8 @@ class TransposeIntNegDimsModule(torch.nn.Module): def forward(self, x): return torch.transpose(x, 
-1, -2) +# ============================================================================== + @register_test_case(module_factory=lambda: TransposeIntNegDimsModule()) def TransposeIntNegDimsModule_basic(module, tu: TestUtils): @@ -335,6 +353,8 @@ class PermuteNegativeIndexModule(torch.nn.Module): def forward(self, x): return x.permute(0, -1, 1) +# ============================================================================== + @register_test_case(module_factory=lambda: PermuteNegativeIndexModule()) def PermuteNegativeIndexModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 2)) @@ -357,6 +377,8 @@ class TensorsConcatModule(torch.nn.Module): def TensorsConcatModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2, 4), tu.rand(2, 1, 4), tu.rand(2, 3, 4)) +# ============================================================================== + class GatherModule(torch.nn.Module): def __init__(self): @@ -376,6 +398,8 @@ class GatherModule(torch.nn.Module): def GatherModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]])) +# ============================================================================== + class AddSizeIntModule(torch.nn.Module): def __init__(self): super().__init__() @@ -396,6 +420,8 @@ class AddSizeIntModule(torch.nn.Module): def AddSizeIntModule_basic(module, tu: TestUtils): module.forward(torch.randn(3, 3)) +# ============================================================================== + class AddSizeIntNegDimModule(torch.nn.Module): def __init__(self): @@ -417,6 +443,8 @@ class AddSizeIntNegDimModule(torch.nn.Module): def AddSizeIntNegDimModule_basic(module, tu: TestUtils): module.forward(torch.randn(3, 3)) +# ============================================================================== + class EmbeddingModule(torch.nn.Module): def __init__(self): super().__init__() @@ -438,6 +466,7 @@ class EmbeddingModule(torch.nn.Module): def EmbeddingModule_basic(module, tu: TestUtils): module.forward(torch.randint(100, (3, 3))) +# ============================================================================== class SoftmaxIntModule(torch.nn.Module): def __init__(self): @@ -474,6 +503,8 @@ class _SoftmaxModule(torch.nn.Module): def _SoftmaxModule_basic(module, tu: TestUtils): module.forward(torch.randn(3, 2, 4)) +# ============================================================================== + class SoftmaxIntNegDimModule(torch.nn.Module): def __init__(self): @@ -494,6 +525,8 @@ class SoftmaxIntNegDimModule(torch.nn.Module): def SoftmaxIntNegDimModule_basic(module, tu: TestUtils): module.forward(torch.randn(3, 2, 4)) +# ============================================================================== + class SoftmaxIntArgTypeF64Module(torch.nn.Module): def __init__(self): @@ -513,6 +546,7 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module): def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils): module.forward(torch.randn(3, 2, 4).double()) +# ============================================================================== class BroadcastToModule(torch.nn.Module): def __init__(self): @@ -531,6 +565,8 @@ class BroadcastToModule(torch.nn.Module): def BroadcastToModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 1)) +# ============================================================================== + class ExpandModule(torch.nn.Module): def __init__(self): super().__init__() @@ -548,6 +584,9 @@ class ExpandModule(torch.nn.Module): def ExpandModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 
1)) +# ============================================================================== + + class OnesModuleInt(torch.nn.Module): def __init__(self): super().__init__() @@ -563,6 +602,8 @@ class OnesModuleInt(torch.nn.Module): def OnesModuleInt_basic(module, tu: TestUtils): module.forward() +# ============================================================================== + class OnesModuleFloat(torch.nn.Module): def __init__(self): super().__init__() @@ -594,6 +635,7 @@ class OnesModuleFalsePinMemory(torch.nn.Module): def OnesModuleFalsePinMemory_basic(module, tu: TestUtils): module.forward() +# ============================================================================== class ContiguousModule(torch.nn.Module): def __init__(self): @@ -611,7 +653,7 @@ class ContiguousModule(torch.nn.Module): @register_test_case(module_factory=lambda: ContiguousModule()) def ContiguousModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1)) - + class TensorToInt(torch.nn.Module): def __init__(self): super().__init__() @@ -681,6 +723,7 @@ class NumToTensorFloatModule(torch.nn.Module): def NumToTensorFloatModule_basic(module, tu: TestUtils): module.forward() +# ============================================================================== # This test can be removed once we have one real op returning 3 float32 tensors class ReturnThreeTensorFloat32(torch.nn.Module): diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py index 83ff3dbc2..39ea40243 100644 --- a/e2e_testing/torchscript/main.py +++ b/e2e_testing/torchscript/main.py @@ -42,6 +42,7 @@ from . import matmul from . import view from . import scalar from . import squeeze +from . import slice_like def _get_argparse(): config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external'] diff --git a/e2e_testing/torchscript/slice_like.py b/e2e_testing/torchscript/slice_like.py new file mode 100644 index 000000000..ecc3cdd64 --- /dev/null +++ b/e2e_testing/torchscript/slice_like.py @@ -0,0 +1,227 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. 
+ +import torch + +from torch_mlir_e2e_test.torchscript.framework import TestUtils +from torch_mlir_e2e_test.torchscript.registry import register_test_case +from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export + +# ============================================================================== + +class SliceModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return x[0:5:1, 1:3:1, 2:4:1] + + +@register_test_case(module_factory=lambda: SliceModule()) +def SliceModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,4,7)) + + + +# ============================================================================== + +# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448 +class SliceOutOfUpperBoundIndexModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return x[:8, :5, 8:] + + +@register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexModule()) +def SliceOutOfUpperBoundIndexModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,4,7)) + +# ============================================================================== + +class SliceOutOfLowerBoundEndIndexModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return x[:-8,-7:,:] + + +@register_test_case(module_factory=lambda: SliceOutOfLowerBoundEndIndexModule()) +def SliceOutOfLowerBoundEndIndexModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,4,7)) + +# ============================================================================== + +class SliceOutOfLowerBoundStartIndexModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return x[-8:3:1, 1:3:1, 2:4:1] + + +@register_test_case(module_factory=lambda: SliceOutOfLowerBoundStartIndexModule()) +def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,4,7)) + +# ============================================================================== + +# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448 +class SliceEndSleStartModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return x[:0, 4:3, :-7] + + +@register_test_case(module_factory=lambda: SliceEndSleStartModule()) +def SliceEndSleStartModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,4,7)) + +# ============================================================================== + +# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448 +class SliceStartEqEndModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return x[5:5, 3:3, -1:] + + +@register_test_case(module_factory=lambda: SliceStartEqEndModule()) +def SliceStartEqEndModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,4,7)) + +# ============================================================================== + +class SliceSizeTwoStepModule(torch.nn.Module): + def 
__init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return x[0:5:2, 0:3:2, 0:4:2] + + +@register_test_case(module_factory=lambda: SliceSizeTwoStepModule()) +def SliceSizeTwoStepModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10,5,17)) + +# ============================================================================== + +class SliceNegIdxModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return x[:-1, -2:-1] + + +@register_test_case(module_factory=lambda: SliceNegIdxModule()) +def SliceNegIdxModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3,9)) + +# ============================================================================== + +class SliceSingleIdxModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return x[0] + + +@register_test_case(module_factory=lambda: SliceSingleIdxModule()) +def SliceSingleIdxModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,8)) + +# ============================================================================== + +class SliceWholeTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return x[:, :] + + +@register_test_case(module_factory=lambda: SliceWholeTensorModule()) +def SliceWholeTensorModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6,8)) + +# ============================================================================== + +class SelectIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, x): + return x.select(0,0) + + +@register_test_case(module_factory=lambda: SelectIntModule()) +def SelectIntModule_basic(module, tu: TestUtils): + module.forward(torch.randint(10, (5,5))) + +# ============================================================================== + diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 42b317a63..1c339d6fe 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -17,8 +17,13 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = { "QuantizedMLP_basic", "IouOfModule_basic", } - -REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS +# Fails due to https://github.com/llvm/torch-mlir/issues/448 +SIZE_ZERO_TENSOR_XFAILS = { + "SliceEndSleStartModule_basic", + "SliceStartEqEndModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", +} +REFBACKEND_XFAIL_SET = set.union(COMMON_TORCH_MLIR_LOWERING_XFAILS, SIZE_ZERO_TENSOR_XFAILS) # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. 
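The slice_like.py tests above all reduce to PyTorch's slice clamping rule: negative start/end indices are first shifted by the dimension size, both are clamped into [0, dimSize], and the output size along the sliced dimension is floordiv(end - start + step - 1, step). The xfailed cases are the ones whose clamped slice comes out empty, i.e. the size-zero-tensor problem tracked in https://github.com/llvm/torch-mlir/issues/448. A small Python reference model of that rule, for orientation only (the helper name is made up and is not part of the patch):

    def sliced_dim_size(dim_size, start, end, step=1):
        # Reference model of the aten.slice.Tensor output size along one dim.
        start = 0 if start is None else start
        end = dim_size if end is None else end
        # Negative indices count from the end of the dimension.
        if start < 0:
            start += dim_size
        if end < 0:
            end += dim_size
        # Clamp into [0, dim_size].
        start = min(max(start, 0), dim_size)
        end = min(max(end, 0), dim_size)
        # resultSize = floordiv(end - start + step - 1, step), floored at zero.
        return max((end - start + step - 1) // step, 0)

    # SliceSizeTwoStepModule: x[0:5:2] on a dim of size 10 keeps indices 0, 2, 4.
    assert sliced_dim_size(10, 0, 5, 2) == 3
    # SliceStartEqEndModule: x[5:5] is empty, hence the xfail entry above.
    assert sliced_dim_size(6, 5, 5) == 0

The floordiv part of this formula is the expression the linalg lowering below builds in IR, and the clamping corresponds to its adjustStartOrEnd helper; the max-with-zero guard reflects PyTorch's behavior for empty slices.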
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 39e216dd1..f4a85084e 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -2712,6 +2712,107 @@ public: }; } // namespace +namespace { +class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenSliceTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Location loc = op.getLoc(); + TypeConverter *typeConverter = getTypeConverter(); + + auto input = adaptor.self(); + RankedTensorType inputType = input.getType().cast<RankedTensorType>(); + RankedTensorType resultType = + typeConverter->convertType(op->getResult(0).getType()) + .cast<RankedTensorType>(); + int64_t resultRank = resultType.getRank(); + Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); + Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); + + int64_t dim; + if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) + return op->emitError("unimplemented: dim is not constant"); + + SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input); + Value dimSize = inputShape[dim]; + + auto adjustStartOrEnd = [&](Value startOrEndTorchType, + Value startOrEndBuiltin, Value valueForNone) { + if (startOrEndTorchType.getType().isa<Torch::NoneType>()) + return valueForNone; + auto dimSizeAsInt = castIndexToInt(rewriter, loc, dimSize); + Value startOrEndToPositive = + toPositiveDimDynamic(rewriter, loc, startOrEndBuiltin, dimSizeAsInt); + // startOrEnd < 0 ? 0 : startOrEnd + Value cst0 = rewriter.create<arith::ConstantOp>( + loc, rewriter.getZeroAttr(dimSizeAsInt.getType())); + Value predDimSltZero = rewriter.create<arith::CmpIOp>( + loc, arith::CmpIPredicate::slt, startOrEndToPositive, cst0); + Value startOrEndAtLeastZero = rewriter.create<SelectOp>( + loc, predDimSltZero, cst0, startOrEndToPositive); + // startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd + Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>( + loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt); + Value startOrEndBoundedByDimSize = rewriter.create<SelectOp>( + loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero); + + return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize); + }; + + Value start = adjustStartOrEnd(op.start(), adaptor.start(), zero); + Value end = adjustStartOrEnd(op.end(), adaptor.end(), dimSize); + + int64_t step; + if (!matchPattern(op.step(), m_TorchConstantInt(&step))) { + if (!op.step().getType().isa<Torch::NoneType>()) + return op->emitError("unimplemented: step is not constant"); + step = 1; + } + + // Slice logic: resultSize = floordiv(end - start + step - 1, step) + Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step); + Value len = rewriter.create<arith::SubIOp>(loc, end, start); + Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex); + resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one); + resultSize = + rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex); + + SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input); + resultShape[dim] = resultSize; + + SmallVector<Value> offsets(inputType.getRank(), zero); + SmallVector<Value> strides(inputType.getRank(), one); + offsets[dim] = start; + strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex); + + Value result = rewriter.create<tensor::ExtractSliceOp>( + loc, input, offsets, resultShape, strides); + + // TODO: This code is for selectOp, remove once squeeze dim is added + if (resultRank < inputType.getRank()) { + SmallVector<ReassociationIndices> reassociation(resultRank); + int64_t resultIdx = 0; + for (auto i : llvm::seq<int64_t>(0, inputType.getRank())) { + if (resultIdx < resultRank) + reassociation[resultIdx].push_back(i); + if (i != dim) + resultIdx++; + } + result = rewriter.create<linalg::TensorCollapseShapeOp>(loc, result, + reassociation); + } + rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result); + + return success(); + } +}; +} // namespace + namespace { class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> { public: @@ -3265,6 +3366,8 @@ public: patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp<AtenSliceTensorOp>(); + patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 5cf17340b..f29f9394e 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -126,6 +126,26 @@ public: }; } // namespace +namespace { +class DecomposeAtenSelectIntOp : public OpRewritePattern<AtenSelectIntOp> { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSelectIntOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value one = + rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1)); + Value end = + rewriter.create<AtenAddIntOp>(loc, one.getType(), op.index(), one); + rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(op, op.getResult().getType(), + op.self(), op.dim(), + op.index(), end, one); + + return success(); + } +}; +} // namespace + // Calculates the softmax function on the given `input` tensor. Softmax(x) = // exp(x)/sum(exp(x)). 
template <typename OpTy> @@ -487,6 +507,8 @@ class DecomposeComplexOpsPass target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add<DecomposeAtenSelectIntOp>(context); + target.addIllegalOp<AtenSelectIntOp>(); patterns.add(context); patterns.add(context); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index 8702f9fc6..be25318ee 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -92,8 +92,8 @@ public: } else if (isa(op)) { + AtenPermuteOp, AtenViewOp, AtenExpandOp, AtenFill_ScalarOp, + AtenSliceTensorOp, AtenSelectIntOp>(op)) { // AtenContiguousOp might return a view, so this is conservatively // correct. We could potentially be more precise and identify the cases // that it does not return a view and treat those as having value diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 03f84c85e..bc7f61138 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -390,13 +390,11 @@ public: }; return visitSliceLikeOp(indexSelect, operands, setDim); } else if (auto selectInt = dyn_cast<AtenSelectIntOp>(op)) { - // Select one element from the target dim. All the other dims are the same - // as input. + // Selects along dim at index. Result shape same as input except dim is + // removed. auto setDim = [](int64_t &targetDim, int64_t dim, - ArrayRef<LatticeElement<ValueKnowledge> *> operands) { - targetDim = 1; - }; - return visitSliceLikeOp(selectInt, operands, setDim); + ArrayRef<LatticeElement<ValueKnowledge> *> operands) {}; + return visitSliceLikeOp(selectInt, operands, setDim, /*keepDim=*/false); } else if (auto sliceTensor = dyn_cast<AtenSliceTensorOp>(op)) { // Select several elements from the target dim according to the start, // end, step. All the other dims are the same as input. 
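The RefineTypes change above encodes the shape contract that distinguishes the two ops: aten.select.int drops the indexed dimension entirely (the new keepDim=false path), while aten.slice.Tensor keeps it with an adjusted size. A quick PyTorch check of that contract, illustrative only and not part of the test suite:

    import torch

    x = torch.rand(2, 3, 4)
    # select.int removes the dim: [2, 3, 4] -> [2, 4]
    assert x.select(1, 0).shape == torch.Size([2, 4])
    # slice.Tensor keeps the dim (resized): [2, 3, 4] -> [2, 1, 4]
    assert x[:, 0:1, :].shape == torch.Size([2, 1, 4])

This is also why the decomposition of select into slice relies on the rank-reducing collapse added in the linalg lowering earlier in this patch.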
@@ -540,7 +538,7 @@ private: template <typename OpTy> ChangeResult visitSliceLikeOp(OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands, - SetDimSizeFn setDim); + SetDimSizeFn setDim, bool keepDim = true); ChangeResult visitAtenGatherOp(AtenGatherOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands); @@ -1222,7 +1220,7 @@ ChangeResult TypeAnalyzer::visitTypeConversionOp( template <typename OpTy> ChangeResult TypeAnalyzer::visitSliceLikeOp( OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands, - SetDimSizeFn setDim) { + SetDimSizeFn setDim, bool keepDim) { auto input = operands[0]->getValue(); auto knowledge = ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); @@ -1248,6 +1246,8 @@ ChangeResult TypeAnalyzer::visitSliceLikeOp( } knowledge.sizes = input.sizes; setDim(knowledge.sizes[dim], dim, operands); + if (!keepDim) + knowledge.sizes.erase(knowledge.sizes.begin() + dim); return getLatticeElement(op.getResult()).join(knowledge); } diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir index ec11c38a6..6c475099b 100644 --- a/test/Dialect/Torch/refine-types.mlir +++ b/test/Dialect/Torch/refine-types.mlir @@ -750,8 +750,8 @@ builtin.func @torch.aten.index_select$unknown_dim(%input: !torch.tensor<[2,3,4], // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>, // CHECK-SAME: %[[INDEX:.*]]: !torch.int) -> !torch.tensor { // CHECK: %[[DIM:.*]] = torch.constant.int 1 -// CHECK: %[[RET:.*]] = torch.aten.select.int %[[INPUT]], %[[DIM]], %[[INDEX]] : !torch.tensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.tensor<[2,1,4],f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,1,4],f32> to !torch.tensor +// CHECK: %[[RET:.*]] = torch.aten.select.int %[[INPUT]], %[[DIM]], %[[INDEX]] : !torch.tensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.tensor<[2,4],f32> +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,4],f32> to !torch.tensor // CHECK: return %[[CAST]] : !torch.tensor builtin.func @torch.aten.select.int(%input: !torch.tensor<[2,3,4], f32>, %index: !torch.int) -> !torch.tensor {
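The updated FileCheck expectations reflect the new shape inference: select.int on a [2,3,4] tensor at dim 1 now refines to [2,4] rather than [2,1,4]. At the PyTorch level, the DecomposeAtenSelectIntOp pattern corresponds to the equivalence sketched below; narrow and squeeze are stand-ins used only to mimic "slice plus rank reduction" in eager mode (the pass itself emits aten.slice.Tensor, with the unit dim collapsed later in the linalg lowering):

    import torch

    def select_via_slice(x, dim, index):
        # select(dim, index) == slice(dim, index, index + 1, step=1)
        # followed by dropping the resulting unit dim.
        sliced = x.narrow(dim, index, 1)  # end = index + 1, as in the decomposition
        return sliced.squeeze(dim)

    x = torch.randint(10, (5, 5))  # same input as SelectIntModule_basic
    assert torch.equal(select_via_slice(x, 0, 0), x.select(0, 0))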