[TOSA] Add legalization for aten.index_select (#3760)

- Add Torch to TOSA legalization for aten.index_select
- Fix createOneDimTfIndices function in TosaLegalizeCommon.cpp to
correctly convert Torch indices to TF-style indices, which is used in
convertGatherNdOp
- Update e2e tests in xfail_sets.py
- Update basic.mlir with new LIT test for aten.index_select

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I52519246183949353a3cf22f0a685fe3df8ec8ff

Justin Ngo 2024-10-04 12:24:22 -07:00 committed by GitHub
parent 2374b9e02d
commit e9ed4af9ce
4 changed files with 229 additions and 56 deletions
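
For context, here is a minimal PyTorch sketch (not part of the patch; values chosen to mirror the (4, 5, 6) example in the new pattern's comments) of the semantics being legalized: aten.index_select is equivalent to broadcasting the 1-D index up to the input's rank and then gathering along dim, which is the strategy the lowering implements with tosa.reshape, tosa.tile, and a gather_nd-style gather.

# Illustrative PyTorch sketch of the semantics being legalized (not the pass itself).
import torch

x = torch.arange(4 * 5 * 6, dtype=torch.float32).reshape(4, 5, 6)
index = torch.tensor([3, 0])   # 1-D index vector, length 2
dim = 1

reference = torch.index_select(x, dim, index)             # shape (4, 2, 6)

# Same result via "expand the index to the input rank, then gather along dim",
# mirroring the reshape/tile/gather pipeline used by the TOSA pattern.
expanded = index.view(1, -1, 1).expand(4, index.numel(), 6)
gathered = torch.gather(x, dim, expanded)

assert torch.equal(reference, gathered)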


@@ -3821,6 +3821,124 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
AtenIndexSelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto input = adaptor.getSelf();
auto inputType = dyn_cast<RankedTensorType>(input.getType());
if (!inputType)
return rewriter.notifyMatchFailure(
op, "Only RankedTensorType inputs are currently supported");
auto index = adaptor.getIndex();
auto indexType = dyn_cast<RankedTensorType>(index.getType());
if (!indexType)
return rewriter.notifyMatchFailure(
op, "Only RankedTensorType indices are currently supported");
auto inputShape = inputType.getShape();
int inputRank = inputType.getRank();
if (indexType.getRank() == 0)
return rewriter.notifyMatchFailure(
op, "Rank 0 index tensor is currently not supported");
// Dynamic shape check
if (!inputType.hasStaticShape() || !indexType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "AtenIndexSelectOp: support for dynamic input "
"shape not implemented");
// Cast index from i64 to i32 for TOSA compatibility
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(indexType.getShape(),
rewriter.getIntegerType(32)),
index);
}
// Get positive dim
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "Value `dim` should be a torch constant int");
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "Value `dim` is invalid");
// Get the output type
auto outType = getTypeConverter()->convertType(op.getType());
// Reshape and expand the index tensor to have same rank and same dimensions
// (except for the targeted dim) as the input
//
// For example:
// Input shape = (4, 5, 6)
// Index vector shape = (2)
// Targeted dim = 1
// Reshaped and expanded index vector shape = (4, 2, 6)
//
// By reshaping and expanding the index vector, we can supply it into the
// gather op to mimic the functionality of aten.index_select
SmallVector<int64_t> indicesInputRankShape;
for (int64_t i = 0; i < inputRank; i++) {
if (i == dim) {
indicesInputRankShape.push_back(indexType.getShape()[0]);
} else {
indicesInputRankShape.push_back(1);
}
}
auto indicesInputRankType =
RankedTensorType::get(makeShapeLLVMCompatible(indicesInputRankShape),
rewriter.getIntegerType(32));
auto reshapedIndices = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), indicesInputRankType, index,
rewriter.getDenseI64ArrayAttr(indicesInputRankShape));
SmallVector<int64_t> tileShape(indicesInputRankShape);
SmallVector<int64_t> expandedIndicesShape(indicesInputRankShape);
for (int64_t i = 0; i < inputRank; i++) {
if (tileShape[i] == 1 && i != dim) {
tileShape[i] = inputShape[i];
expandedIndicesShape[i] = inputShape[i];
} else {
tileShape[i] = 1;
}
}
auto tileType =
RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape),
rewriter.getIntegerType(32));
auto expandedIndices = rewriter.create<tosa::TileOp>(
op->getLoc(), tileType, reshapedIndices.getResult(),
rewriter.getDenseI64ArrayAttr(tileShape));
// Convert the Torch-style index and dim into TF-style indices,
// e.g. tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
auto indicesTf = tosa::convertTorchIndexToTfIndices(
rewriter, op, input, expandedIndices.getResult(), dim);
if (!indicesTf)
return rewriter.notifyMatchFailure(
op, "Convert TorchIndex To TfIndices failed");
// Apply the TF GatherND algorithm with the TF-style indices as input.
auto result =
tosa::convertGatherNdOp(rewriter, op, outType, input, indicesTf.value());
if (!result) {
return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");
}
rewriter.replaceOp(op, {result.value()});
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
@@ -6240,6 +6358,7 @@ public:
INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
INSERT_ATENOP_PATTERN(AtenTrilOp);
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \

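The two helpers invoked at the end of the new pattern, convertTorchIndexToTfIndices and convertGatherNdOp, expand the tiled index tensor into per-element coordinate tuples and then flatten those tuples into offsets over a flattened input. Below is a hedged NumPy sketch of that flattening, reusing the shapes and strides that appear in the LIT test further down (input 4x5x6, dim = 2, strides [30, 6, 1]); the concrete index values are invented for illustration.

# NumPy model of the TF-style gather_nd lowering (illustrative, static shapes only).
import numpy as np

x = np.arange(4 * 5 * 6, dtype=np.float32).reshape(4, 5, 6)
index = np.array([4, 1], dtype=np.int32)   # aten.index_select index, dim = 2

# TF-style indices: one (i, j, index[k]) coordinate tuple per output element.
tf_indices = np.stack(
    np.meshgrid(np.arange(4), np.arange(5), index, indexing="ij"), axis=-1
)                                           # shape (4, 5, 2, 3)

# Flatten coordinates into offsets with the row-major strides of (4, 5, 6),
# then gather from the flattened input and restore the output shape.
strides = np.array([30, 6, 1], dtype=np.int32)
offsets = (tf_indices.reshape(-1, 3) * strides).sum(axis=1)   # 40 flat offsets
result = x.reshape(-1)[offsets].reshape(4, 5, 2)

assert np.array_equal(result, x[:, :, index])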

@@ -23,6 +23,15 @@ namespace tosa {
using namespace mlir::torch::Torch;
// This function is a helper for `convertTorchIndexToTfIndices`.
//
// We convert PyTorch index to TensorFlow-style indices so that we can use
// `convertGatherNdOp` and `convertScatterNdOp` functions, which lower Gather
// and Scatter operators to TOSA using TensorFlow-style indices.
// The difference between PyTorch/ONNX Gather/Scatter and TensorFlow
// Gather/Scatter ops is that PyTorch/ONNX take in the dimension along which
// you want to gather/scatter elements, while in TensorFlow the indices point
// directly to the positions from/to which you want to gather/scatter elements.
std::optional<Value>
createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
SmallVector<int64_t> indicesOneDimShape, int32_t dim,
@@ -30,49 +39,55 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
unsigned indexRank = indexShape.size();
SmallVector<int32_t> indicesVec; // input vec to create tosaConstant
SmallVector<int32_t> indicesMetaElement; // torch.meshgrid inputs
int indicesMetaElementRepeatTimes{1}; // For torch.stack(torch.meshgrid)
// Create torch.meshgrid inputs
// Example: indexShape=[1,4,2]
// dim0: indicesMetaElement = torch.arange(0, 1) = [0]
// dim1: indicesMetaElement = torch.arange(0, 4) = [0,1,2,3]
// dim2: indicesMetaElement = torch.arange(0, 2) = [0,1]
for (int i = 0; i < indexShape[dim]; i++) {
for (int i = 0; i < indexShape[dim]; i++)
indicesMetaElement.push_back(i);
}
// Compute total number of meta element repeat times:
// = product(indexShape[0:dim]) x product(indexShape[dim+1:-1]), skip dim
// dim0: indicesMetaElementRepeatTimes = 1 x 4*2 = 8
// dim1: indicesMetaElementRepeatTimes = 1 *1 x 2 = 2
// dim2: indicesMetaElementRepeatTimes = 1 *1*4 = 4
for (int i = 0; i < static_cast<int>(indexRank); i++) {
if (i == dim) {
continue;
} else {
indicesMetaElementRepeatTimes *= indexShape[i];
}
}
int preDimMetaElementRepeatTimes = 1;
int postDimMetaElementRepeatTimes = 1;
if (dim != static_cast<int>(indexShape.size()) - 1) {
// Create one dim indices for index except for last dim
// Create indices raw vector.
// torch.stack(torch.meshgrid)
// dim0: indicesVec = [0 0 0 0 0 0 0 0]
// dim0: indicesVec = [0 0 1 1 2 2 3 3]
for (size_t elementId = 0; elementId < indicesMetaElement.size();
elementId++) {
for (int i = 0; i < indicesMetaElementRepeatTimes; i++) {
indicesVec.push_back(indicesMetaElement[elementId]);
}
}
} else { // Create the one dim indices for last dim of index
// Create indices raw vector
// dim2: indicesVec= [0 1 0 1 0 1 0 1]
// Caution: indicesVec != [0 0 0 0 1 1 1 1]
for (int i = 0; i < indicesMetaElementRepeatTimes; i++) {
// Compute total number of times meta element range should repeat
// = product(indexShape[0:dim])
// dim0: preDimMetaElementRepeatTimes = 1
// dim1: preDimMetaElementRepeatTimes = 1
// dim2: preDimMetaElementRepeatTimes = 1 x 4 = 4
for (int i = 0; i < dim; i++)
preDimMetaElementRepeatTimes *= indexShape[i];
// Compute total number of times each meta element should repeat
// = product(indexShape[dim+1:indexRank])
// dim0: postDimMetaElementRepeatTimes = 4 x 2 = 8
// dim1: postDimMetaElementRepeatTimes = 2
// dim2: postDimMetaElementRepeatTimes = 1
for (int i = dim + 1; i < static_cast<int>(indexRank); i++)
postDimMetaElementRepeatTimes *= indexShape[i];
// Example using dim1:
// preDimMetaElementRepeatTimes = 1
// postDimMetaElementRepeatTimes = 2
// Using postDimMetaElementRepeatTimes, we get the meta element range:
// [0 0 1 1 2 2 3 3]
// Using preDimMetaElementRepeatTimes, we get the full one dim indices:
// [0 0 1 1 2 2 3 3]
//
// Let's use a clearer example:
// indexShape = [3, 4, 2]
// Target dim = 1
// => preDimMetaElementRepeatTimes = 3
// postDimMetaElementRepeatTimes = 2
// Using postDimMetaElementRepeatTimes, we get the meta element range:
// [0 0 1 1 2 2 3 3]
// Using preDimMetaElementRepeatTimes, we get the full one dim indices:
// [0 0 1 1 2 2 3 3 0 0 1 1 2 2 3 3 0 0 1 1 2 2 3 3]
for (int i = 0; i < preDimMetaElementRepeatTimes; i++) {
for (size_t elementId = 0; elementId < indicesMetaElement.size();
elementId++) {
for (int j = 0; j < postDimMetaElementRepeatTimes; j++) {
indicesVec.push_back(indicesMetaElement[elementId]);
}
}

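The rewritten createOneDimTfIndices builds each coordinate column by repeating an arange over the target dim: each value repeats postDimMetaElementRepeatTimes times back-to-back, and the whole range repeats preDimMetaElementRepeatTimes times. A small Python model of that loop structure (illustrative only; the function name and shapes here are just for the worked example):

# Python model of the fixed index-generation loops (illustrative only).
import numpy as np

def one_dim_tf_indices(index_shape, dim):
    pre = int(np.prod(index_shape[:dim], dtype=np.int64))       # repeats of the whole range
    post = int(np.prod(index_shape[dim + 1:], dtype=np.int64))  # repeats of each element
    vec = [e
           for _ in range(pre)
           for e in range(index_shape[dim])
           for _ in range(post)]
    return np.array(vec, dtype=np.int32).reshape(index_shape)

# indexShape = [1, 4, 2], dim = 1 -> [0 0 1 1 2 2 3 3], matching the comments above.
print(one_dim_tf_indices([1, 4, 2], 1).flatten())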

@@ -1663,6 +1663,17 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"AtenLinalgCrossBroadcast_basic",
"AtenLinalgCrossCustomDim_basic",
"AtenLinalgCrossFloat_basic",
"AtenLinalgCrossInt_basic",
"AtenLinalgCrossNegativeDim_basic",
"BinaryCrossEntropyWithLogitsStaticModule_basic",
"IndexSelectNegativeDimModule_basic",
"IndexSelectSingleIdxModule_basic",
"IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"DiagonalWithStaticShapeModule_basic",
"EinsumStaticDiagonalDimensionModule_basic",
"ElementwiseAtenFloorDivideBroadcastModule_basic",
@@ -2342,6 +2353,13 @@ MAKE_FX_TOSA_PASS_SET = (
}
) - {
### Test failing in make_fx_tosa but not in tosa
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
# Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d",
"MatmulStaticBroadcast_basic",
@@ -3205,6 +3223,17 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
}
FX_IMPORTER_TOSA_XFAIL_SET = {
"ChunkListUnpackDynamic_Module_basic",
"ChunkListUnpackUnevenDynamic_Module_basic",
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
"ElementwiseCreateComplexModule_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"AtenPolarDoubleModule_basic",
"AtenPolarFloatModule_basic",
@@ -3302,12 +3331,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"AtenIntTensorCharDtypeModule_basic",
"AtenItemFpOpModule_basic",
"AtenItemIntOpModule_basic",
"AtenLinalgCrossBroadcast_basic",
"AtenLinalgCrossCustomDim_basic",
"AtenLinalgCrossDynamic_basic",
"AtenLinalgCrossFloat_basic",
"AtenLinalgCrossInt_basic",
"AtenLinalgCrossNegativeDim_basic",
"AtenMatmulQMixedSigni8Transpose_basic",
"AtenMatmulQMixedSigni8_basic",
"AtenMatmulQint8MV_basic",
@@ -3551,15 +3574,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic",
"IndexSelectDynamicIndexSizeModule_basic",
"IndexSelectDynamicInputSizeModule_basic",
"IndexSelectDynamicModulebasic",
"IndexSelectNegativeDimModule_basic",
"IndexSelectRank0IdxModule_basic",
"IndexSelectSingleIdxModule_basic",
"IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexTensorNegativeIndexModule_basic",
"InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest",
@@ -3848,6 +3863,8 @@ ONNX_TOSA_CRASHING_SET = {
}
ONNX_TOSA_XFAIL_SET = {
"ElementwiseCreateComplexModule_basic",
"ReduceAllDimFloatModule_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"HstackBasicComplexModule_basic",
@@ -4269,7 +4286,6 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseWhereSelfModule_basic",
"EmbeddingModule1DIndices_basic",
"EmbeddingModuleF16_basic",
"EmbeddingModuleI32Static_basic",
"EmbeddingModuleI32_basic",
"EmbeddingModuleI64_basic",
"EmptyLikeMemoryFormatModule_basic",
@@ -4363,12 +4379,6 @@ ONNX_TOSA_XFAIL_SET = {
"IndexSelectDynamicIndexSizeModule_basic",
"IndexSelectDynamicInputSizeModule_basic",
"IndexSelectDynamicModulebasic",
"IndexSelectNegativeDimModule_basic",
"IndexSelectRank0IdxModule_basic",
"IndexSelectSingleIdxModule_basic",
"IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
"IndexTensorHackedTwinModule3dInput_basic",
@@ -4386,10 +4396,8 @@ ONNX_TOSA_XFAIL_SET = {
"IndexTensorMultiInputOneDim_basic",
"IndexTensorMultiInputThreeIndexers_basic",
"IndexTensorMultiInput_basic",
"IndexTensorNegativeIndexModule_basic",
"IndexTensorSelectDimModule_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
"InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest",
@@ -4688,7 +4696,6 @@ ONNX_TOSA_XFAIL_SET = {
"ScatterValueFloatModule_basic",
"ScatterValueIntModule_basic",
"SelectIntModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SelectScattertModule_basic",
"SelectScattertStaticModule_basic",
"SignAndLogarithmOfDeterminantModule_F32",


@@ -1885,3 +1885,35 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) ->
%0 = torch.aten.diagonal %arg0, %offset, %dim1, %dim2 : !torch.vtensor<[3,4,5,6],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[5,6,2],si32>
return %0 : !torch.vtensor<[5,6,2],si32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.index_select(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5,6],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2],si64> -> tensor<2xi64>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
// CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32>
// CHECK: }
func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
%int2 = torch.constant.int 2
%0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32>
return %0 : !torch.vtensor<[4,5,2],f32>
}