[TOSA] aten.index.tensor multiple indexes support (#1868)

pull/1832/head snapshot-20230214.749
Chi_Liu 2023-02-13 23:07:15 -08:00 committed by GitHub
parent 67ab708b63
commit 8a7340dfb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 120 additions and 10 deletions

View File

@ -644,6 +644,7 @@ TOSA_PASS_SET = {
"TypePromotionZeroRankHigherCategoryModule_basic",
"GatherStaticModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"LiftFreshCopyModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ReduceSumDimIntListFloatModule_basic",
@ -757,6 +758,7 @@ LTC_XFAIL_SET = {
"IndexTensorModule3dInput_basic",
"IndexTensorModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputNonContiguous_basic",
"IndexTensorMultiInputOneDim_basic",

View File

@ -3328,17 +3328,106 @@ LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
if (!getListConstructElements(tensorList, tensorsTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
auto tensors = getTypeConvertedValues(rewriter, op->getLoc(),
getTypeConverter(), tensorsTorchType);
auto indexTensors = getTypeConvertedValues(
rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType);
// TODO add support for multiple index
if ( tensors.size() > 1){
return op.emitError(
"unimplemented: the index tensor list from list construct > 1");
auto outType = getTypeConverter()->convertType(op.getType());
// Support for multiple indexes
if (indexTensors.size() > 1) {
// t[i, i]
// = torch.ops.aten.index(t,(i,i))
// = tensor([[ t[1,1], t[2,2], t[3,3]],
// [ t[3,3], t[2,2], t[1,1]]])
// = tensor([[ 7, 13, 19], [19, 13, 7]])
// = tf.gather_nd(t,tf.ii_expand)
// ii_expand
// = tf.concat((i_expand,i_expand), dim=2)
// = tf.constant([[[1,1],[2,2],[3,3]],
// [[3,3],[2,2],[1,1]]]) # 2*3*2
SmallVector<Value> indicesTfConcatTensors;
SmallVector<int64_t> indexesRank;
SmallVector<SmallVector<int64_t>> indexesShape;
// concat index tensor into to indices tensor for concat
for (size_t i = 0; i < indexTensors.size(); i++) {
auto index = indexTensors[i];
auto indexTorch = tensorsTorchType[i];
// TODO add support for none index input like torch.ops.aten.index(x,
// (None, index1, index2, None))
if (indexTorch.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "Only list ranked tensor types index are supported");
auto indexType = index.getType().dyn_cast<RankedTensorType>();
auto indexShape = indexType.getShape();
indexesShape.push_back(makeShapeTorchCompatible(indexShape));
indexesRank.push_back(indexType.getRank());
// index i64 to i32 for tosa compatible
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
index);
}
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
SmallVector<int64_t> indiceShapeOneDim;
for (auto shape : indexShape) {
indiceShapeOneDim.push_back(shape);
}
indiceShapeOneDim.push_back(1);
auto indicesTfOneDim = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)),
index, rewriter.getDenseI64ArrayAttr(indiceShapeOneDim));
// create concat tensor for indicesTf
indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
}
// Right now only support multiple indexes with same shape
// TODO for different shape multiple indexes, add broadcast_to for small
// shape
for (auto indexShapeOneDim : indexesShape) {
if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: Only support multi indexes with same shape");
}
}
// concat each indices into indicesTf: shape [2,3,1],[2,3,1] -> [2,3,2]
auto indicesShapeConcat = indexesShape[0];
uint64_t lastDim = indexesRank[0];
indicesShapeConcat.push_back(indicesTfConcatTensors.size());
auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
indicesTfConcatTensors, lastDim);
if (!indicesTf) {
return rewriter.notifyMatchFailure(
op, "Convert TorchIndex To TfIndices fail.");
}
// do the tf gathernp algorithm with tf style indices as input.
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
indicesTf.getResult());
if (!result) {
return rewriter.notifyMatchFailure(
op, "Convert GatherNdOp fail for index tensor.");
}
rewriter.replaceOp(op, {result.value()});
return success();
}
auto index = tensors[0];
// Support for multiple index
auto index = indexTensors[0];
auto indexTorch = tensorsTorchType[0];
// TODO add support for none index input like torch.ops.aten.index(x, (None, index1, index2, None))
if (!index.getImpl())
if (indexTorch.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "Only list ranked tensor types index are supported");
auto indexType = index.getType().dyn_cast<RankedTensorType>();
@ -3350,8 +3439,6 @@ LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
}
auto outType = getTypeConverter()->convertType(op.getType());
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
SmallVector<int64_t> indicesShape;
for (auto shape : indexShape) {

View File

@ -1826,6 +1826,27 @@ class IndexTensorStaticModule(torch.nn.Module):
def IndexTensorStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5), tu.randint(2, 3, high=4))
# ==============================================================================
class IndexTensorMultiIndexStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([4, 5], torch.float32, True),
([2, 3], torch.int64, True),
([2, 3], torch.int64, True),
])
def forward(self, x, index1, index2):
return torch.ops.aten.index(x, (index1, index2))
@register_test_case(module_factory=lambda: IndexTensorMultiIndexStaticModule())
def IndexTensorMultiIndexStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5), tu.randint(2, 3, high=4), tu.randint(2, 3, high=4))
# ==============================================================================