mirror of https://github.com/llvm/torch-mlir
[Tosa] : Add support for negative indices in index.tensor and index.Tensor_hacked_twin for TorchToTosa lowering. (#3790)
1. Negative indices for tensor indexing is handled by wrapping around the index values by checking their values at run time. Without the fix, there was a runtime error. 2. Added a lit test to lock down the behavior. 3. Updated the `xfails_set` for `fx_importer_tosa` config to lockdown the behavior with e2e test as well. "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY."pull/3835/head
parent
54d9e24013
commit
2b01f8b7f3
|
@ -4093,6 +4093,25 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
Value wrapNegativeIndices(Value index, int maxIndex, Operation *op,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
|
||||
auto zeroValue = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
|
||||
auto maxIndexValue =
|
||||
tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
|
||||
|
||||
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||
|
||||
auto wrappedIndicesOp = tosa::CreateOpAndInfer<tosa::AddOp>(
|
||||
rewriter, op->getLoc(), indexType, maxIndexValue, index);
|
||||
auto boolType = indexType.clone(rewriter.getIntegerType(1));
|
||||
auto isNegativeIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
|
||||
rewriter, op->getLoc(), boolType, zeroValue, index);
|
||||
return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
|
||||
indexType, isNegativeIndices,
|
||||
wrappedIndicesOp, index);
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
||||
AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
|
||||
|
@ -4124,6 +4143,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
|||
|
||||
auto outType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
Operation *indicesTf;
|
||||
|
||||
// Support for multiple indexes
|
||||
if (indexTensors.size() > 1) {
|
||||
// t[i, i]
|
||||
|
@ -4157,6 +4178,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
|||
index);
|
||||
}
|
||||
|
||||
index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op,
|
||||
rewriter);
|
||||
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
|
||||
SmallVector<int64_t> indiceShapeOneDim;
|
||||
for (auto shape : indexShape) {
|
||||
|
@ -4299,57 +4322,47 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
|||
auto indicesShapeConcat = indexesShape[0];
|
||||
uint64_t lastDim = indexesRank[0];
|
||||
indicesShapeConcat.push_back(indicesTfConcatTensors.size());
|
||||
auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
|
||||
indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
|
||||
rewriter, op->getLoc(),
|
||||
GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
|
||||
indicesTfConcatTensors, lastDim);
|
||||
|
||||
if (!indicesTf) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Convert TorchIndex To TfIndices fail.");
|
||||
} else {
|
||||
|
||||
// Single index
|
||||
auto index = indexTensors[0];
|
||||
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||
auto indexShape = indexType.getShape();
|
||||
// index i64 to i32 for tosa compatible
|
||||
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
||||
index = rewriter.create<tosa::CastOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
|
||||
index);
|
||||
}
|
||||
// do the tf gathernp algorithm with tf style indices as input.
|
||||
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
|
||||
indicesTf.getResult());
|
||||
|
||||
if (!result) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Convert GatherNdOp fail for index tensor.");
|
||||
index =
|
||||
wrapNegativeIndices(index, inputTensorType.getShape()[0], op, rewriter);
|
||||
|
||||
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
|
||||
SmallVector<int64_t> indicesShape;
|
||||
for (auto shape : indexShape) {
|
||||
indicesShape.push_back(shape);
|
||||
}
|
||||
rewriter.replaceOp(op, {result.value()});
|
||||
|
||||
return success();
|
||||
indicesShape.push_back(1);
|
||||
indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
|
||||
rewriter, op->getLoc(),
|
||||
RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
|
||||
rewriter.getDenseI64ArrayAttr(indicesShape));
|
||||
}
|
||||
|
||||
// Support for multiple index
|
||||
auto index = indexTensors[0];
|
||||
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||
auto indexShape = indexType.getShape();
|
||||
// index i64 to i32 for tosa compatible
|
||||
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
||||
index = rewriter.create<tosa::CastOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
|
||||
}
|
||||
|
||||
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
|
||||
SmallVector<int64_t> indicesShape;
|
||||
for (auto shape : indexShape) {
|
||||
indicesShape.push_back(shape);
|
||||
}
|
||||
indicesShape.push_back(1);
|
||||
auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
|
||||
rewriter, op->getLoc(),
|
||||
RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
|
||||
rewriter.getDenseI64ArrayAttr(indicesShape));
|
||||
|
||||
if (!indicesTf) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Convert TorchIndex To TfIndices fail.");
|
||||
}
|
||||
// do the tf gathernp algorithm with tf style indices as input.
|
||||
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
|
||||
indicesTf.getResult());
|
||||
indicesTf->getResult(0));
|
||||
|
||||
if (!result) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
|
|
@ -1698,7 +1698,6 @@ TOSA_CRASHING_SET = {
|
|||
"ArangeStartOutModule_basic",
|
||||
"ScatterSrcStaticModule_basic",
|
||||
# Runtime op verification: Out of bounds access
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"ReduceAllDimEmpty_basic",
|
||||
}
|
||||
|
||||
|
@ -1706,7 +1705,6 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
|||
"ScatterSrcModule_basic",
|
||||
"ScatterSrcStaticModule_basic",
|
||||
"HBC_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"InterpolateDynamicModule_scales_recompute_bilinear",
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
|
@ -2162,6 +2160,7 @@ TOSA_PASS_SET = {
|
|||
"HardswishRandomModule_basic",
|
||||
"HardtanhBackward_basic",
|
||||
"IndexTensorMultiIndexStaticModule_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"IndexTensorStaticModule_basic",
|
||||
"IscloseStaticModuleTrue_basic",
|
||||
"IscloseStaticModule_basic",
|
||||
|
@ -3635,7 +3634,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImplIndexWithNoneModule_basic",
|
||||
"IndexSelectRank0IdxModule_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
"InterpolateStaticModule_scales_bilinear_align_corners",
|
||||
|
|
|
@ -2131,3 +2131,35 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t
|
|||
%0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32>
|
||||
return %0 : !torch.vtensor<[2,3,4,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,2],si64>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
|
||||
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64>
|
||||
// CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[],si64> -> tensor<i64>
|
||||
// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<i64>) -> tensor<i32>
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
|
||||
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_4]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: 1, 2, 8>} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64>
|
||||
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> tensor<1x1xi32>
|
||||
// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
|
||||
// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_11]], %[[VAL_12]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32>
|
||||
// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32>
|
||||
// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 1, 1>} : (tensor<1x1xi32>) -> tensor<1x1xi32>
|
||||
// CHECK: %[[VAL_16:.*]] = tosa.gather %[[VAL_10]], %[[VAL_15]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64>
|
||||
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 4, 2>} : (tensor<1x1x8xi64>) -> tensor<4x2xi64>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[4,2],si64>
|
||||
|
||||
func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
|
||||
%0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
|
||||
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list<vtensor> -> !torch.vtensor<[4,2],si64>
|
||||
return %1 : !torch.vtensor<[4,2],si64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue