mirror of https://github.com/llvm/torch-mlir
[TOSA] Add legalization for aten.index_select (#3760)
- Add Torch to TOSA legalization for aten.index_select - Fix createOneDimTfIndices function in TosaLegalizeCommon.cpp to correctly convert Torch indices to TF-style indices, which is used in convertGatherNdOp - Update e2e tests in xfail_sets.py - Update basic.mlir with new LIT test for aten.index_select Signed-off-by: Justin Ngo <justin.ngo@arm.com> Change-Id: I52519246183949353a3cf22f0a685fe3df8ec8ff Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3766/head
parent
2374b9e02d
commit
e9ed4af9ce
|
@ -3821,6 +3821,124 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
||||
AtenIndexSelectOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Not a tensor type.
|
||||
auto input = adaptor.getSelf();
|
||||
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
||||
if (!inputType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only RankedTensorType inputs are currently supported");
|
||||
|
||||
auto index = adaptor.getIndex();
|
||||
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||
|
||||
if (!indexType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only RankedTensorType indices are currently supported");
|
||||
|
||||
auto inputShape = inputType.getShape();
|
||||
int inputRank = inputType.getRank();
|
||||
|
||||
if (indexType.getRank() == 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Rank 0 index tensor is currently not supported");
|
||||
|
||||
// Dynamic shape check
|
||||
if (!inputType.hasStaticShape() || !indexType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "AtenIndexSelectOp: support for dynamic input "
|
||||
"shape not implemented");
|
||||
|
||||
// index i64 to i32 for tosa compatible
|
||||
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
||||
index = rewriter.create<tosa::CastOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(indexType.getShape(),
|
||||
rewriter.getIntegerType(32)),
|
||||
index);
|
||||
}
|
||||
|
||||
// Get positive dim
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Value `dim` should be a torch constant int");
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
if (!isValidDim(dim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "Value `dim` is invalid");
|
||||
|
||||
// Get the output type
|
||||
auto outType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
// Reshape and expand the index tensor to have same rank and same dimensions
|
||||
// (except for the targeted dim) as the input
|
||||
//
|
||||
// For example:
|
||||
// Input shape = (4, 5, 6)
|
||||
// Index vector shape = (2)
|
||||
// Targeted dim = 1
|
||||
// Reshaped and expanded index vector shape = (4, 2, 6)
|
||||
//
|
||||
// By reshaping and expanding the index vector, we can supply it into the
|
||||
// gather op to mimic the functionality of aten.index_select
|
||||
SmallVector<int64_t> indicesInputRankShape;
|
||||
for (int64_t i = 0; i < inputRank; i++) {
|
||||
if (i == dim) {
|
||||
indicesInputRankShape.push_back(indexType.getShape()[0]);
|
||||
} else {
|
||||
indicesInputRankShape.push_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
auto indicesInputRankType =
|
||||
RankedTensorType::get(makeShapeLLVMCompatible(indicesInputRankShape),
|
||||
rewriter.getIntegerType(32));
|
||||
|
||||
auto reshapedIndices = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(), indicesInputRankType, index,
|
||||
rewriter.getDenseI64ArrayAttr(indicesInputRankShape));
|
||||
|
||||
SmallVector<int64_t> tileShape(indicesInputRankShape);
|
||||
SmallVector<int64_t> expandedIndicesShape(indicesInputRankShape);
|
||||
for (int64_t i = 0; i < inputRank; i++) {
|
||||
if (tileShape[i] == 1 && i != dim) {
|
||||
tileShape[i] = inputShape[i];
|
||||
expandedIndicesShape[i] = inputShape[i];
|
||||
} else {
|
||||
tileShape[i] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
auto tileType =
|
||||
RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape),
|
||||
rewriter.getIntegerType(32));
|
||||
|
||||
auto expandedIndices = rewriter.create<tosa::TileOp>(
|
||||
op->getLoc(), tileType, reshapedIndices.getResult(),
|
||||
rewriter.getDenseI64ArrayAttr(tileShape));
|
||||
|
||||
// convert torch style index and dim into tf style indices
|
||||
// tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
|
||||
auto indicesTf = tosa::convertTorchIndexToTfIndices(
|
||||
rewriter, op, input, expandedIndices.getResult(), dim);
|
||||
if (!indicesTf)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Convert TorchIndex To TfIndices failed");
|
||||
|
||||
// do the tf gathernd algorithm with tf style indices as input.
|
||||
auto result =
|
||||
tosa::convertGatherNdOp(rewriter, op, outType, input, indicesTf.value());
|
||||
|
||||
if (!result) {
|
||||
return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");
|
||||
}
|
||||
rewriter.replaceOp(op, {result.value()});
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
||||
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
|
||||
|
@ -6240,6 +6358,7 @@ public:
|
|||
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
|
||||
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
|
|
|
@ -23,6 +23,15 @@ namespace tosa {
|
|||
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
// This function is a helper for `convertTorchIndexToTfIndices`.
|
||||
//
|
||||
// We convert PyTorch index to TensorFlow-style indices so that we can use
|
||||
// `convertGatherNdOp` and `convertScatterNdOp` functions, which lower Gather
|
||||
// and Scatter operators to TOSA using TensorFlow-style indices.
|
||||
// The difference between PyTorch/ONNX Gather/Scatter and TensorFlow
|
||||
// Gather/Scatter ops is that PyTorch/ONNX take in the dimension that you want
|
||||
// to gather/scatter elements, while in TensorFlow, the indices point directly
|
||||
// to positions that you want to gather/scatter elements.
|
||||
std::optional<Value>
|
||||
createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
|
||||
SmallVector<int64_t> indicesOneDimShape, int32_t dim,
|
||||
|
@ -30,49 +39,55 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
|
|||
unsigned indexRank = indexShape.size();
|
||||
SmallVector<int32_t> indicesVec; // input vec to create tosaConstant
|
||||
SmallVector<int32_t> indicesMetaElement; // torch.meshgrid inputs
|
||||
int indicesMetaElementRepeatTimes{1}; // For torch.stack(torch.meshgrid)
|
||||
|
||||
// Create torch.meshgrid inputs
|
||||
// Example: indexShape=[1,4,2]
|
||||
// dim0: indicesMetaElement = torch.arange(0, 1) = [0]
|
||||
// dim1: indicesMetaElement = torch.arange(0, 4) = [0,1,2,3]
|
||||
// dim2: indicesMetaElement = torch.arange(0, 2) = [0,1]
|
||||
for (int i = 0; i < indexShape[dim]; i++) {
|
||||
for (int i = 0; i < indexShape[dim]; i++)
|
||||
indicesMetaElement.push_back(i);
|
||||
}
|
||||
|
||||
// Compute total number of meta element repeat times:
|
||||
// = product(indexShape[0:dim]) x product(indexShape[dim+1:-1]), skip dim
|
||||
// dim0: indicesMetaElementRepeatTimes = 1 x 4*2 = 8
|
||||
// dim1: indicesMetaElementRepeatTimes = 1 *1 x 2 = 2
|
||||
// dim2: indicesMetaElementRepeatTimes = 1 *1*4 = 4
|
||||
for (int i = 0; i < static_cast<int>(indexRank); i++) {
|
||||
if (i == dim) {
|
||||
continue;
|
||||
} else {
|
||||
indicesMetaElementRepeatTimes *= indexShape[i];
|
||||
}
|
||||
}
|
||||
int preDimMetaElementRepeatTimes = 1;
|
||||
int postDimMetaElementRepeatTimes = 1;
|
||||
|
||||
if (dim != static_cast<int>(indexShape.size()) - 1) {
|
||||
// Create one dim indices for index except for last dim
|
||||
// Create indices raw vector.
|
||||
// torch.stack(torch.meshgrid)
|
||||
// dim0: indicesVec = [0 0 0 0 0 0 0 0]
|
||||
// dim0: indicesVec = [0 0 1 1 2 2 3 3]
|
||||
for (size_t elementId = 0; elementId < indicesMetaElement.size();
|
||||
elementId++) {
|
||||
for (int i = 0; i < indicesMetaElementRepeatTimes; i++) {
|
||||
indicesVec.push_back(indicesMetaElement[elementId]);
|
||||
}
|
||||
}
|
||||
} else { // Create the one dim indices for last dim of index
|
||||
// Create indices raw vector
|
||||
// dim2: indicesVec= [0 1 0 1 0 1 0 1]
|
||||
// Caution: indicesVec != [0 0 0 0 1 1 1 1]
|
||||
for (int i = 0; i < indicesMetaElementRepeatTimes; i++) {
|
||||
// Compute total number of times meta element range should repeat
|
||||
// = product(indexShape[0:dim])
|
||||
// dim0: preDimMetaElementRepeatTimes = 1
|
||||
// dim1: preDimMetaElementRepeatTimes = 1
|
||||
// dim2: preDimMetaElementRepeatTimes = 1 x 4 = 4
|
||||
for (int i = 0; i < dim; i++)
|
||||
preDimMetaElementRepeatTimes *= indexShape[i];
|
||||
|
||||
// Compute total number of times meta element repeat
|
||||
// = product(indexShape[dim+1:indexRank])
|
||||
// dim0: postDimMetaElementRepeatTimes = 4 x 2 = 8
|
||||
// dim1: postDimMetaElementRepeatTimes = 2
|
||||
// dim2: postDimMetaElementRepeatTimes = 1
|
||||
for (int i = dim + 1; i < static_cast<int>(indexRank); i++)
|
||||
postDimMetaElementRepeatTimes *= indexShape[i];
|
||||
|
||||
// Example using dim1:
|
||||
// preDimMetaElementRepeatTimes = 1
|
||||
// postDimMetaElementRepeatTimes = 2
|
||||
// Using postDimMetaElementRepeatTimes, we get the meta element range:
|
||||
// [0 0 1 1 2 2 3 3]
|
||||
// Using preDimMetaElementRepeatTimes, we get the full one dim indices:
|
||||
// [0 0 1 1 2 2 3 3]
|
||||
//
|
||||
// Let's use a clearer example:
|
||||
// indexShape = [3, 4, 2]
|
||||
// Target dim = 1
|
||||
// => preDimMetaElementRepeatTimes = 3
|
||||
// postDimMetaElementRepeatTimes = 2
|
||||
// Using postDimMetaElementRepeatTimes, we get the meta element range:
|
||||
// [0 0 1 1 2 2]
|
||||
// Using preDimMetaElementRepeatTimes, we get the full one dim indices:
|
||||
// [0 0 1 1 2 2 0 0 1 1 2 2 0 0 1 1 2 2]
|
||||
for (int i = 0; i < preDimMetaElementRepeatTimes; i++) {
|
||||
for (size_t elementId = 0; elementId < indicesMetaElement.size();
|
||||
elementId++) {
|
||||
for (int j = 0; j < postDimMetaElementRepeatTimes; j++) {
|
||||
indicesVec.push_back(indicesMetaElement[elementId]);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1663,6 +1663,17 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
|||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"AtenLinalgCrossBroadcast_basic",
|
||||
"AtenLinalgCrossCustomDim_basic",
|
||||
"AtenLinalgCrossFloat_basic",
|
||||
"AtenLinalgCrossInt_basic",
|
||||
"AtenLinalgCrossNegativeDim_basic",
|
||||
"BinaryCrossEntropyWithLogitsStaticModule_basic",
|
||||
"IndexSelectNegativeDimModule_basic",
|
||||
"IndexSelectSingleIdxModule_basic",
|
||||
"IndexSelectTwoIdxModule_basic",
|
||||
"IndexSelectWholeDimensionModule_basic",
|
||||
"IndexSelectWholeTensorModule_basic",
|
||||
"DiagonalWithStaticShapeModule_basic",
|
||||
"EinsumStaticDiagonalDimensionModule_basic",
|
||||
"ElementwiseAtenFloorDivideBroadcastModule_basic",
|
||||
|
@ -2342,6 +2353,13 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
}
|
||||
) - {
|
||||
### Test failing in make_fx_tosa but not in tosa
|
||||
"ChunkListUnpackUneven_Module_basic",
|
||||
"ChunkListUnpack_Module_basic",
|
||||
"SplitTensorGetItem_Module_basic",
|
||||
"SplitTensorLastSmallerModule_basic",
|
||||
"SplitTensorListUnpackModule_basic",
|
||||
"SplitTensorNegativeDimModule_basic",
|
||||
"SplitWithSizesListUnpackModule_basic",
|
||||
# Dynamic shape, has extra unsupported broadcast ops
|
||||
"Matmul_3d",
|
||||
"MatmulStaticBroadcast_basic",
|
||||
|
@ -3205,6 +3223,17 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
|||
}
|
||||
|
||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||
"ChunkListUnpackDynamic_Module_basic",
|
||||
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||
"ChunkListUnpackUneven_Module_basic",
|
||||
"ChunkListUnpack_Module_basic",
|
||||
"SplitTensorGetItem_Module_basic",
|
||||
"SplitTensorLastSmallerModule_basic",
|
||||
"SplitTensorListUnpackModule_basic",
|
||||
"SplitTensorNegativeDimModule_basic",
|
||||
"SplitWithSizesListUnpackModule_basic",
|
||||
"SplitWithSizes_Module_basic",
|
||||
"ElementwiseCreateComplexModule_basic",
|
||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||
"AtenPolarDoubleModule_basic",
|
||||
"AtenPolarFloatModule_basic",
|
||||
|
@ -3302,12 +3331,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"AtenIntTensorCharDtypeModule_basic",
|
||||
"AtenItemFpOpModule_basic",
|
||||
"AtenItemIntOpModule_basic",
|
||||
"AtenLinalgCrossBroadcast_basic",
|
||||
"AtenLinalgCrossCustomDim_basic",
|
||||
"AtenLinalgCrossDynamic_basic",
|
||||
"AtenLinalgCrossFloat_basic",
|
||||
"AtenLinalgCrossInt_basic",
|
||||
"AtenLinalgCrossNegativeDim_basic",
|
||||
"AtenMatmulQMixedSigni8Transpose_basic",
|
||||
"AtenMatmulQMixedSigni8_basic",
|
||||
"AtenMatmulQint8MV_basic",
|
||||
|
@ -3551,15 +3574,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"IndexPutImpl3DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImplIndexWithNoneModule_basic",
|
||||
"IndexSelectDynamicIndexSizeModule_basic",
|
||||
"IndexSelectDynamicInputSizeModule_basic",
|
||||
"IndexSelectDynamicModulebasic",
|
||||
"IndexSelectNegativeDimModule_basic",
|
||||
"IndexSelectRank0IdxModule_basic",
|
||||
"IndexSelectSingleIdxModule_basic",
|
||||
"IndexSelectTwoIdxModule_basic",
|
||||
"IndexSelectWholeDimensionModule_basic",
|
||||
"IndexSelectWholeTensorModule_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
|
@ -3848,6 +3863,8 @@ ONNX_TOSA_CRASHING_SET = {
|
|||
}
|
||||
|
||||
ONNX_TOSA_XFAIL_SET = {
|
||||
"ElementwiseCreateComplexModule_basic",
|
||||
"ReduceAllDimFloatModule_basic",
|
||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||
"HstackBasicComplexModule_basic",
|
||||
|
@ -4269,7 +4286,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseWhereSelfModule_basic",
|
||||
"EmbeddingModule1DIndices_basic",
|
||||
"EmbeddingModuleF16_basic",
|
||||
"EmbeddingModuleI32Static_basic",
|
||||
"EmbeddingModuleI32_basic",
|
||||
"EmbeddingModuleI64_basic",
|
||||
"EmptyLikeMemoryFormatModule_basic",
|
||||
|
@ -4363,12 +4379,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"IndexSelectDynamicIndexSizeModule_basic",
|
||||
"IndexSelectDynamicInputSizeModule_basic",
|
||||
"IndexSelectDynamicModulebasic",
|
||||
"IndexSelectNegativeDimModule_basic",
|
||||
"IndexSelectRank0IdxModule_basic",
|
||||
"IndexSelectSingleIdxModule_basic",
|
||||
"IndexSelectTwoIdxModule_basic",
|
||||
"IndexSelectWholeDimensionModule_basic",
|
||||
"IndexSelectWholeTensorModule_basic",
|
||||
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
|
||||
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
|
||||
"IndexTensorHackedTwinModule3dInput_basic",
|
||||
|
@ -4386,10 +4396,8 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"IndexTensorMultiInputOneDim_basic",
|
||||
"IndexTensorMultiInputThreeIndexers_basic",
|
||||
"IndexTensorMultiInput_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"IndexTensorSelectDimModule_basic",
|
||||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
||||
"IndexTensorStaticModule_basic",
|
||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
|
@ -4688,7 +4696,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ScatterValueFloatModule_basic",
|
||||
"ScatterValueIntModule_basic",
|
||||
"SelectIntModule_basic",
|
||||
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||
"SelectScattertModule_basic",
|
||||
"SelectScattertStaticModule_basic",
|
||||
"SignAndLogarithmOfDeterminantModule_F32",
|
||||
|
|
|
@ -1885,3 +1885,35 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) ->
|
|||
%0 = torch.aten.diagonal %arg0, %offset, %dim1, %dim2 : !torch.vtensor<[3,4,5,6],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[5,6,2],si32>
|
||||
return %0 : !torch.vtensor<[5,6,2],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.index_select(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5,6],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2],si64> -> tensor<2xi64>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
|
||||
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
|
||||
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
|
||||
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
|
||||
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
|
||||
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
|
||||
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
|
||||
// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
|
||||
// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
|
||||
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
|
||||
// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
|
||||
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
|
||||
// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
|
||||
// CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32>
|
||||
return %0 : !torch.vtensor<[4,5,2],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue