[Tosa]: Add support for negative indices in index.Tensor and index.Tensor_hacked_twin for TorchToTosa lowering. (#3790)

1. Negative indices for tensor indexing are handled by wrapping the index
values around at run time: each index is compared against zero, and the
dimension size is added to any negative value before the gather. Without
the fix, indexing with a negative value caused a runtime out-of-bounds
error. (A scalar sketch of the wrapping arithmetic follows below.)
2. Added a lit test to lock down the behavior.
3. Updated the xfail sets for the `fx_importer_tosa` config to lock down
the behavior with an e2e test as well.
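
The fix is branch-free at the graph level: rather than guarding the gather
with control flow, the lowering always computes `index + dimSize` and then
selects between the shifted and original values based on the sign of the
index. A minimal scalar sketch of that arithmetic (illustrative C++ only,
not code from this patch; `wrapIndex` is a hypothetical name):

#include <cassert>

// Map a possibly-negative index into [0, dimSize), following PyTorch's
// convention that -1 refers to the last element. This is the scalar
// analogue of the tosa.add / tosa.greater / tosa.select sequence the
// lowering emits.
int wrapIndex(int index, int dimSize) {
  int wrapped = index + dimSize;        // tosa.add: shifted candidate
  bool isNegative = 0 > index;          // tosa.greater(0, index)
  return isNegative ? wrapped : index;  // tosa.select
}

int main() {
  assert(wrapIndex(-1, 4) == 3);  // last element of a size-4 dim
  assert(wrapIndex(2, 4) == 2);   // non-negative indices pass through
  return 0;
}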

"THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY."
Sayan Saha 2024-10-25 18:37:19 -04:00 committed by GitHub
parent 54d9e24013
commit 2b01f8b7f3
3 changed files with 82 additions and 39 deletions

@@ -4093,6 +4093,25 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
   return success();
 }
 
+Value wrapNegativeIndices(Value index, int maxIndex, Operation *op,
+                          ConversionPatternRewriter &rewriter) {
+  auto zeroValue = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
+  auto maxIndexValue =
+      tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
+
+  auto indexType = dyn_cast<RankedTensorType>(index.getType());
+  auto wrappedIndicesOp = tosa::CreateOpAndInfer<tosa::AddOp>(
+      rewriter, op->getLoc(), indexType, maxIndexValue, index);
+
+  auto boolType = indexType.clone(rewriter.getIntegerType(1));
+  auto isNegativeIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
+      rewriter, op->getLoc(), boolType, zeroValue, index);
+
+  return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
+                                                indexType, isNegativeIndices,
+                                                wrappedIndicesOp, index);
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
     AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
@@ -4124,6 +4143,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
   auto outType = getTypeConverter()->convertType(op.getType());
 
+  Operation *indicesTf;
+
   // Support for multiple indexes
   if (indexTensors.size() > 1) {
     // t[i, i]
@@ -4157,6 +4178,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
             index);
       }
 
+      index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op,
+                                  rewriter);
       // Expand last dim of index to tf indices [2,3] -> [2,3,1]
       SmallVector<int64_t> indiceShapeOneDim;
       for (auto shape : indexShape) {
@@ -4299,57 +4322,47 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
     auto indicesShapeConcat = indexesShape[0];
     uint64_t lastDim = indexesRank[0];
     indicesShapeConcat.push_back(indicesTfConcatTensors.size());
-    auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
+    indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
         rewriter, op->getLoc(),
         GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
         indicesTfConcatTensors, lastDim);
+  } else {
 
-    if (!indicesTf) {
-      return rewriter.notifyMatchFailure(
-          op, "Convert TorchIndex To TfIndices fail.");
-    }
-    // do the tf gathernp algorithm with tf style indices as input.
-    auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
-                                          indicesTf.getResult());
-
-    if (!result) {
-      return rewriter.notifyMatchFailure(
-          op, "Convert GatherNdOp fail for index tensor.");
-    }
-    rewriter.replaceOp(op, {result.value()});
-
-    return success();
-  }
-
-  // Support for multiple index
-  auto index = indexTensors[0];
-  auto indexType = dyn_cast<RankedTensorType>(index.getType());
-  auto indexShape = indexType.getShape();
-  // index i64 to i32 for tosa compatible
-  if (indexType.getElementType() != rewriter.getIntegerType(32)) {
-    index = rewriter.create<tosa::CastOp>(
-        op->getLoc(),
-        RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
-  }
+    // Single index
+    auto index = indexTensors[0];
+    auto indexType = dyn_cast<RankedTensorType>(index.getType());
+    auto indexShape = indexType.getShape();
+    // index i64 to i32 for tosa compatible
+    if (indexType.getElementType() != rewriter.getIntegerType(32)) {
+      index = rewriter.create<tosa::CastOp>(
+          op->getLoc(),
+          RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
+          index);
+    }
 
-  // Expand last dim of index to tf indices [2,3] -> [2,3,1]
-  SmallVector<int64_t> indicesShape;
-  for (auto shape : indexShape) {
-    indicesShape.push_back(shape);
-  }
-  indicesShape.push_back(1);
-  auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
-      rewriter, op->getLoc(),
-      RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
-      rewriter.getDenseI64ArrayAttr(indicesShape));
+    index =
+        wrapNegativeIndices(index, inputTensorType.getShape()[0], op, rewriter);
+
+    // Expand last dim of index to tf indices [2,3] -> [2,3,1]
+    SmallVector<int64_t> indicesShape;
+    for (auto shape : indexShape) {
+      indicesShape.push_back(shape);
+    }
+    indicesShape.push_back(1);
+    indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
+        rewriter, op->getLoc(),
+        RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
+        rewriter.getDenseI64ArrayAttr(indicesShape));
+  }
 
   if (!indicesTf) {
     return rewriter.notifyMatchFailure(op,
                                        "Convert TorchIndex To TfIndices fail.");
   }
   // do the tf gathernp algorithm with tf style indices as input.
   auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
-                                        indicesTf.getResult());
+                                        indicesTf->getResult(0));
 
   if (!result) {
     return rewriter.notifyMatchFailure(

@@ -1698,7 +1698,6 @@ TOSA_CRASHING_SET = {
     "ArangeStartOutModule_basic",
     "ScatterSrcStaticModule_basic",
     # Runtime op verification: Out of bounds access
-    "IndexTensorNegativeIndexModule_basic",
     "ReduceAllDimEmpty_basic",
 }
@@ -1706,7 +1705,6 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
     "ScatterSrcModule_basic",
     "ScatterSrcStaticModule_basic",
     "HBC_basic",
-    "IndexTensorNegativeIndexModule_basic",
     "InterpolateDynamicModule_scales_recompute_bilinear",
     "InterpolateDynamicModule_sizes_bilinear",
     "InterpolateDynamicModule_sizes_nearest",
@@ -2162,6 +2160,7 @@ TOSA_PASS_SET = {
     "HardswishRandomModule_basic",
     "HardtanhBackward_basic",
     "IndexTensorMultiIndexStaticModule_basic",
+    "IndexTensorNegativeIndexModule_basic",
     "IndexTensorStaticModule_basic",
     "IscloseStaticModuleTrue_basic",
     "IscloseStaticModule_basic",
@@ -3635,7 +3634,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
     "IndexSelectRank0IdxModule_basic",
-    "IndexTensorNegativeIndexModule_basic",
     "InterpolateDynamicModule_sizes_bilinear",
     "InterpolateDynamicModule_sizes_nearest",
     "InterpolateStaticModule_scales_bilinear_align_corners",

@@ -2131,3 +2131,35 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t
   %0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32>
   return %0 : !torch.vtensor<[2,3,4,4],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,2],si64>,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
+// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64>
+// CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[],si64> -> tensor<i64>
+// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<i64>) -> tensor<i32>
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_4]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i1>
+// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
+// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: 1, 2, 8>} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64>
+// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> tensor<1x1xi32>
+// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_11]], %[[VAL_12]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32>
+// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32>
+// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 1, 1>} : (tensor<1x1xi32>) -> tensor<1x1xi32>
+// CHECK: %[[VAL_16:.*]] = tosa.gather %[[VAL_10]], %[[VAL_15]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64>
+// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 4, 2>} : (tensor<1x1x8xi64>) -> tensor<4x2xi64>
+// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64>
+// CHECK: return %[[RESULT]] : !torch.vtensor<[4,2],si64>
+func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
+  %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
+  %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list<vtensor> -> !torch.vtensor<[4,2],si64>
+  return %1 : !torch.vtensor<[4,2],si64>
+}
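
Reading the CHECK lines above against that arithmetic: dim 0 of the input
has size 2 (the dense<2> constant), so for an index value of -1 the emitted
ops compute add(-1, 2) = 1 and greater(0, -1) = true, and the select yields
1; the gather then reads row 1 of the 1x2x8 reshaped input. A standalone
check of those numbers (illustrative C++ only, not part of the patch):

#include <cassert>

int main() {
  // Values taken from the lit test above: dim-0 size 2, index -1.
  int index = -1, dimSize = 2;
  int wrapped = (0 > index) ? index + dimSize : index;  // add/greater/select
  assert(wrapped == 1);  // the gather reads row 1 of the 1x2x8 reshape
  return 0;
}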