Add lowering for slice and selectInt (#398)

pull/458/head snapshot-20211203.121
Daniel Garvey 2021-12-02 22:09:21 -06:00 committed by GitHub
parent 46a2189a41
commit a52aded0b9
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 416 additions and 15 deletions

View File

@@ -100,6 +100,8 @@ class MmTanhModule(torch.nn.Module):
def matmul(self, lhs, rhs):
return torch.mm(lhs, rhs)
# ==============================================================================
@register_test_case(module_factory=lambda: MmTanhModule())
def MmTanhModule_basic(module, tu: TestUtils):
@@ -192,6 +194,8 @@ class AdaptiveAvgPool2dModule(torch.nn.Module):
def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9))
# ==============================================================================
class FlattenStaticModule(torch.nn.Module):
def __init__(self):
@@ -211,6 +215,8 @@ class FlattenStaticModule(torch.nn.Module):
def FlattenStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9, 3, 4))
# ==============================================================================
class FlattenRank0Module(torch.nn.Module):
def __init__(self):
@@ -230,6 +236,8 @@ class FlattenRank0Module(torch.nn.Module):
def FlattenRank0Module_basic(module, tu: TestUtils):
module.forward(torch.tensor(4.0))
# ==============================================================================
class FlattenDynamicModule(torch.nn.Module):
def __init__(self):
@@ -249,6 +257,8 @@ class FlattenDynamicModule(torch.nn.Module):
def FlattenDynamicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9, 3, 4))
# ==============================================================================
class MaxPool2dModule(torch.nn.Module):
def __init__(self):
@@ -266,6 +276,8 @@ class MaxPool2dModule(torch.nn.Module):
def forward(self, x):
return self.mp2d(x)
# ==============================================================================
@register_test_case(module_factory=lambda: MaxPool2dModule())
def MaxPool2dModule_basic(module, tu: TestUtils):
@@ -284,6 +296,8 @@ class TransposeIntModule(torch.nn.Module):
def forward(self, x):
return torch.transpose(x, 0, 1)
# ==============================================================================
@register_test_case(module_factory=lambda: TransposeIntModule())
def TransposeIntModule_basic(module, tu: TestUtils):
@@ -305,6 +319,8 @@ class PermuteModule(torch.nn.Module):
def PermuteModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 2))
# ==============================================================================
class TransposeIntNegDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -317,6 +333,8 @@ class TransposeIntNegDimsModule(torch.nn.Module):
def forward(self, x):
return torch.transpose(x, -1, -2)
# ==============================================================================
@register_test_case(module_factory=lambda: TransposeIntNegDimsModule())
def TransposeIntNegDimsModule_basic(module, tu: TestUtils):
@@ -335,6 +353,8 @@ class PermuteNegativeIndexModule(torch.nn.Module):
def forward(self, x):
return x.permute(0, -1, 1)
# ==============================================================================
@register_test_case(module_factory=lambda: PermuteNegativeIndexModule())
def PermuteNegativeIndexModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 2))
@@ -357,6 +377,8 @@ class TensorsConcatModule(torch.nn.Module):
def TensorsConcatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2, 4), tu.rand(2, 1, 4), tu.rand(2, 3, 4))
# ==============================================================================
class GatherModule(torch.nn.Module):
def __init__(self):
@@ -376,6 +398,8 @@ class GatherModule(torch.nn.Module):
def GatherModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]]))
# ==============================================================================
class AddSizeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -396,6 +420,8 @@ class AddSizeIntModule(torch.nn.Module):
def AddSizeIntModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3))
# ==============================================================================
class AddSizeIntNegDimModule(torch.nn.Module):
def __init__(self):
@@ -417,6 +443,8 @@ class AddSizeIntNegDimModule(torch.nn.Module):
def AddSizeIntNegDimModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3))
# ==============================================================================
class EmbeddingModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -438,6 +466,7 @@ class EmbeddingModule(torch.nn.Module):
def EmbeddingModule_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (3, 3)))
# ==============================================================================
class SoftmaxIntModule(torch.nn.Module):
def __init__(self):
@@ -474,6 +503,8 @@ class _SoftmaxModule(torch.nn.Module):
def _SoftmaxModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4))
# ==============================================================================
class SoftmaxIntNegDimModule(torch.nn.Module):
def __init__(self):
@@ -494,6 +525,8 @@ class SoftmaxIntNegDimModule(torch.nn.Module):
def SoftmaxIntNegDimModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4))
# ==============================================================================
class SoftmaxIntArgTypeF64Module(torch.nn.Module):
def __init__(self):
@@ -513,6 +546,7 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module):
def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4).double())
# ==============================================================================
class BroadcastToModule(torch.nn.Module):
def __init__(self):
@@ -531,6 +565,8 @@ class BroadcastToModule(torch.nn.Module):
def BroadcastToModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 1))
# ==============================================================================
class ExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -548,6 +584,9 @@ class ExpandModule(torch.nn.Module):
def ExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 1))
# ==============================================================================
class OnesModuleInt(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -563,6 +602,8 @@ class OnesModuleInt(torch.nn.Module):
def OnesModuleInt_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class OnesModuleFloat(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -594,6 +635,7 @@ class OnesModuleFalsePinMemory(torch.nn.Module):
def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class ContiguousModule(torch.nn.Module):
def __init__(self):
@@ -681,6 +723,7 @@ class NumToTensorFloatModule(torch.nn.Module):
def NumToTensorFloatModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
# This test can be removed once we have one real op returning 3 float32 tensors
class ReturnThreeTensorFloat32(torch.nn.Module):

View File

@@ -42,6 +42,7 @@ from . import matmul
from . import view
from . import scalar
from . import squeeze
from . import slice_like
def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']

View File

@@ -0,0 +1,227 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class SliceModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return x[0:5:1, 1:3:1, 2:4:1]
@register_test_case(module_factory=lambda: SliceModule())
def SliceModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6,4,7))
# ==============================================================================
# This test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return x[:8, :5, 8:]
@register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexModule())
def SliceOutOfUpperBoundIndexModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6,4,7))
# ==============================================================================
class SliceOutOfLowerBoundEndIndexModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return x[:-8,-7:,:]
@register_test_case(module_factory=lambda: SliceOutOfLowerBoundEndIndexModule())
def SliceOutOfLowerBoundEndIndexModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6,4,7))
# ==============================================================================
class SliceOutOfLowerBoundStartIndexModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return x[-8:3:1, 1:3:1, 2:4:1]
@register_test_case(module_factory=lambda: SliceOutOfLowerBoundStartIndexModule())
def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6,4,7))
# ==============================================================================
# This test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
class SliceEndSleStartModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return x[:0, 4:3, :-7]
@register_test_case(module_factory=lambda: SliceEndSleStartModule())
def SliceEndSleStartModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6,4,7))
# ==============================================================================
# This test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
class SliceStartEqEndModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return x[5:5, 3:3, -1:]
@register_test_case(module_factory=lambda: SliceStartEqEndModule())
def SliceStartEqEndModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6,4,7))
# ==============================================================================
class SliceSizeTwoStepModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return x[0:5:2, 0:3:2, 0:4:2]
@register_test_case(module_factory=lambda: SliceSizeTwoStepModule())
def SliceSizeTwoStepModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10,5,17))
# ==============================================================================
class SliceNegIdxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return x[:-1, -2:-1]
@register_test_case(module_factory=lambda: SliceNegIdxModule())
def SliceNegIdxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3,9))
# ==============================================================================
class SliceSingleIdxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return x[0]
@register_test_case(module_factory=lambda: SliceSingleIdxModule())
def SliceSingleIdxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6,8))
# ==============================================================================
class SliceWholeTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return x[:, :]
@register_test_case(module_factory=lambda: SliceWholeTensorModule())
def SliceWholeTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6,8))
# ==============================================================================
class SelectIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return x.select(0,0)
@register_test_case(module_factory=lambda: SelectIntModule())
def SelectIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (5,5)))
# ==============================================================================
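For readers skimming the new test file, the snippet below is a standalone illustration in plain PyTorch (not part of the committed tests) of the clamping and empty-slice behaviour the edge-case modules above exercise; the size-zero results are the cases tracked by issue 448.

import torch

x = torch.rand(6, 4, 7)
# Out-of-range bounds are clamped to the dimension size, as with Python lists.
print(x[:8, :5, 8:].shape)     # torch.Size([6, 4, 0]) -- dim 2 collapses to 0
# start == end (or end <= start) produces an empty dimension.
print(x[5:5, 3:3, -1:].shape)  # torch.Size([0, 0, 1])
# A step of 2 keeps every other element: size = ceil((end - start) / step).
print(x[0:5:2].shape)          # torch.Size([3, 4, 7])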

View File

@@ -17,8 +17,13 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"QuantizedMLP_basic",
"IouOfModule_basic",
}
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
# Fails due to https://github.com/llvm/torch-mlir/issues/448
SIZE_ZERO_TENSOR_XFAILS = {
"SliceEndSleStartModule_basic",
"SliceStartEqEndModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
}
REFBACKEND_XFAIL_SET = set.union(COMMON_TORCH_MLIR_LOWERING_XFAILS, SIZE_ZERO_TENSOR_XFAILS)
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.

View File

@@ -2712,6 +2712,107 @@ public:
};
} // namespace
namespace {
class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenSliceTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
TypeConverter *typeConverter = getTypeConverter();
auto input = adaptor.self();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
int64_t resultRank = resultType.getRank();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return op->emitError("unimplemented: dim is not constant");
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
Value dimSize = inputShape[dim];
auto adjustStartOrEnd = [&](Value startOrEndTorchType,
Value startOrEndBuiltin, Value valueForNone) {
if (startOrEndTorchType.getType().isa<Torch::NoneType>())
return valueForNone;
auto dimSizeAsInt = castIndexToInt(rewriter, loc, dimSize);
Value startOrEndToPositive =
toPositiveDimDynamic(rewriter, loc, startOrEndBuiltin, dimSizeAsInt);
// startOrEnd < 0 ? 0 : startOrEnd
Value cst0 = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, startOrEndToPositive, cst0);
Value startOrEndAtLeastZero = rewriter.create<SelectOp>(
loc, predDimSltZero, cst0, startOrEndToPositive);
// startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt);
Value startOrEndBoundedByDimSize = rewriter.create<SelectOp>(
loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero);
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
};
Value start = adjustStartOrEnd(op.start(), adaptor.start(), zero);
Value end = adjustStartOrEnd(op.end(), adaptor.end(), dimSize);
int64_t step;
if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
if (!op.step().getType().isa<Torch::NoneType>())
return op->emitError("unimplemented: step is not constant");
step = 1;
}
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
resultSize =
rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
resultShape[dim] = resultSize;
SmallVector<Value> offsets(inputType.getRank(), zero);
SmallVector<Value> strides(inputType.getRank(), one);
offsets[dim] = start;
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
Value result = rewriter.create<tensor::ExtractSliceOp>(
loc, input, offsets, resultShape, strides);
// TODO: This code handles the rank reduction needed for the select op; remove it once a squeeze-dim lowering is added.
if (resultRank < inputType.getRank()) {
SmallVector<ReassociationIndices> reassociation(resultRank);
int64_t resultIdx = 0;
for (auto i : llvm::seq<int64_t>(0, inputType.getRank())) {
if (resultIdx < resultRank)
reassociation[resultIdx].push_back(i);
if (i != dim)
resultIdx++;
}
result = rewriter.create<linalg::TensorCollapseShapeOp>(loc, result,
reassociation);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
}
};
} // namespace
namespace {
class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
public:
@@ -3265,6 +3366,8 @@ public:
patterns.add<ConvertAtenFill_ScalarOp>(typeConverter, context);
target.addIllegalOp<AtenNumelOp>();
patterns.add<ConvertAtenNumelOp>(typeConverter, context);
target.addIllegalOp<AtenSliceTensorOp>();
patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
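The heart of the new ConvertAtenSliceTensorOp pattern is the bound adjustment and size formula in the hunk above. The following is a hedged Python sketch of that arithmetic (function and variable names here are illustrative only, not the actual C++ helpers):

def slice_result_size(dim_size, start, end, step=1):
    # Mirrors adjustStartOrEnd: None falls back to the default bound,
    # negative values are shifted by dim_size, then clamped to [0, dim_size].
    def clamp(value, default):
        if value is None:
            return default
        if value < 0:
            value += dim_size
        return min(max(value, 0), dim_size)
    start = clamp(start, 0)
    end = clamp(end, dim_size)
    # resultSize = floordiv(end - start + step - 1, step)
    return (end - start + step - 1) // step

assert slice_result_size(6, 0, 5, 1) == 5     # x[0:5:1] on a dim of size 6
assert slice_result_size(7, 8, None, 1) == 0  # x[8:] clamps start to 7
assert slice_result_size(5, 0, 3, 2) == 2     # x[0:3:2] keeps indices 0 and 2
# Note: when end < start the formula goes negative instead of clamping to 0,
# which appears to be related to the empty-slice tests xfailed under issue 448.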

View File

@@ -126,6 +126,26 @@ public:
};
} // namespace
namespace {
class DecomposeAtenSelectIntOp : public OpRewritePattern<AtenSelectIntOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSelectIntOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value end =
rewriter.create<AtenAddIntOp>(loc, one.getType(), op.index(), one);
rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(op, op.getResult().getType(),
op.self(), op.dim(),
op.index(), end, one);
return success();
}
};
} // namespace
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
// exp(x)/sum(exp(x)).
template <typename OpTy>
@@ -487,6 +507,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenAddmmOp>();
patterns.add<DecomposeAtenMeanOp>(context);
target.addIllegalOp<AtenMeanOp>();
patterns.add<DecomposeAtenSelectIntOp>(context);
target.addIllegalOp<AtenSelectIntOp>();
patterns.add<DecomposeAtenMatmulOp>(context);
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
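DecomposeAtenSelectIntOp rewrites aten.select.int into an aten.slice.Tensor whose end is index + 1; the slice lowering then drops the selected dimension because the result rank is smaller than the input rank. A minimal PyTorch illustration of the equivalence (narrow/squeeze are used only to demonstrate the semantics; the pattern itself emits just the slice op):

import torch

x = torch.randint(10, (5, 5))
dim, index = 0, 2
selected = x.select(dim, index)                   # shape (5,)
via_slice = x.narrow(dim, index, 1).squeeze(dim)  # slice [index:index+1], then drop dim
assert torch.equal(selected, via_slice)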

View File

@@ -92,8 +92,8 @@ public:
} else if (isa<AtenSqueezeOp, AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
AtenTransposeIntOp, TensorStaticInfoCastOp,
AtenBroadcastToOp, AtenToDtypeOp, AtenContiguousOp,
AtenPermuteOp, AtenViewOp, AtenExpandOp,
AtenFill_ScalarOp>(op)) {
AtenPermuteOp, AtenViewOp, AtenExpandOp, AtenFill_ScalarOp,
AtenSliceTensorOp, AtenSelectIntOp>(op)) {
// AtenContiguousOp might return a view, so this is conservatively
// correct. We could potentially be more precise and identify the cases
// that it does not return a view and treat those as having value
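AtenSliceTensorOp and AtenSelectIntOp are added to this list because, in eager PyTorch, both return views that alias the input's storage, so treating them like the other view-producing ops here is the conservative choice. A standalone illustration (not part of the commit):

import torch

x = torch.zeros(4, 4)
v = x[1:3]              # slicing returns a view
v[0, 0] = 1.0
assert x[1, 0].item() == 1.0   # the write is visible through the original tensor

s = x.select(0, 2)      # select also returns a view
s[1] = 2.0
assert x[2, 1].item() == 2.0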

View File

@@ -390,13 +390,11 @@ public:
};
return visitSliceLikeOp(indexSelect, operands, setDim);
} else if (auto selectInt = dyn_cast<AtenSelectIntOp>(op)) {
// Select one element from the target dim. All the other dims are the same
// as input.
// Slices along dim at index. Result shape same as input except dim is
// removed.
auto setDim = [](int64_t &targetDim, int64_t dim,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
targetDim = 1;
};
return visitSliceLikeOp(selectInt, operands, setDim);
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {};
return visitSliceLikeOp(selectInt, operands, setDim, /*keepDim=*/false);
} else if (auto sliceTensor = dyn_cast<AtenSliceTensorOp>(op)) {
// Select several elements from the target dim according to the start,
// end, step. All the other dims are the same as input.
@@ -540,7 +538,7 @@ private:
template <typename OpTy>
ChangeResult
visitSliceLikeOp(OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands,
SetDimSizeFn setDim);
SetDimSizeFn setDim, bool keepDim = true);
ChangeResult
visitAtenGatherOp(AtenGatherOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -1222,7 +1220,7 @@ ChangeResult TypeAnalyzer::visitTypeConversionOp(
template <typename OpTy>
ChangeResult TypeAnalyzer::visitSliceLikeOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands,
SetDimSizeFn setDim) {
SetDimSizeFn setDim, bool keepDim) {
auto input = operands[0]->getValue();
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
@@ -1248,6 +1246,8 @@ ChangeResult TypeAnalyzer::visitSliceLikeOp(
}
knowledge.sizes = input.sizes;
setDim(knowledge.sizes[dim], dim, operands);
if (!keepDim)
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
return getLatticeElement(op.getResult()).join(knowledge);
}
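An illustrative pseudo-model (not the actual C++ lattice code) of the shape rule the new keepDim flag implements: slice keeps the sliced dimension with a possibly unknown size, while select erases it, matching the [2,3,4] -> [2,4] change in the FileCheck test below. None stands for an unknown dimension size.

def slice_like_shape(in_sizes, dim, keep_dim, new_dim_size=None):
    sizes = list(in_sizes)
    sizes[dim] = new_dim_size      # setDim: unknown unless statically provable
    if not keep_dim:               # select.int: drop the dimension entirely
        del sizes[dim]
    return sizes

assert slice_like_shape([2, 3, 4], dim=1, keep_dim=False) == [2, 4]       # select.int
assert slice_like_shape([2, 3, 4], dim=1, keep_dim=True) == [2, None, 4]  # slice.Tensor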

View File

@@ -750,8 +750,8 @@ builtin.func @torch.aten.index_select$unknown_dim(%input: !torch.tensor<[2,3,4],f32>,
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
// CHECK-SAME: %[[INDEX:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[DIM:.*]] = torch.constant.int 1
// CHECK: %[[RET:.*]] = torch.aten.select.int %[[INPUT]], %[[DIM]], %[[INDEX]] : !torch.tensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.tensor<[2,1,4],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,1,4],f32> to !torch.tensor
// CHECK: %[[RET:.*]] = torch.aten.select.int %[[INPUT]], %[[DIM]], %[[INDEX]] : !torch.tensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.tensor<[2,4],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,4],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
builtin.func @torch.aten.select.int(%input: !torch.tensor<[2,3,4], f32>, %index: !torch.int) -> !torch.tensor {