From 8a7340dfb5d40a7e52dbb4fb67797d704c7b2dde Mon Sep 17 00:00:00 2001
From: Chi_Liu
Date: Mon, 13 Feb 2023 23:07:15 -0800
Subject: [PATCH] [TOSA] Add multiple-index support for aten.index.tensor (#1868)

---
 e2e_testing/xfail_sets.py                   |   2 +
 lib/Conversion/TorchToTosa/TorchToTosa.cpp  | 107 ++++++++++++++++--
 .../torch_mlir_e2e_test/test_suite/basic.py |  21 ++++
 3 files changed, 120 insertions(+), 10 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 92661628e..ee39b1603 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -644,6 +644,7 @@ TOSA_PASS_SET = {
     "TypePromotionZeroRankHigherCategoryModule_basic",
     "GatherStaticModule_basic",
     "IndexTensorStaticModule_basic",
+    "IndexTensorMultiIndexStaticModule_basic",
     "LiftFreshCopyModule_basic",
     "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
     "ReduceSumDimIntListFloatModule_basic",
@@ -757,6 +758,7 @@ LTC_XFAIL_SET = {
     "IndexTensorModule3dInput_basic",
     "IndexTensorModule_basic",
     "IndexTensorStaticModule_basic",
+    "IndexTensorMultiIndexStaticModule_basic",
     "IndexTensorMultiInputContiguousCenter_basic",
     "IndexTensorMultiInputNonContiguous_basic",
     "IndexTensorMultiInputOneDim_basic",
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 7b9acb86e..57dfa0c14 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3328,17 +3328,106 @@ LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
   if (!getListConstructElements(tensorList, tensorsTorchType))
     return op.emitError(
         "unimplemented: the tensor list is not from list construct");
-  auto tensors = getTypeConvertedValues(rewriter, op->getLoc(),
-                                        getTypeConverter(), tensorsTorchType);
+  auto indexTensors = getTypeConvertedValues(
+      rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType);
 
-  // TODO add support for multiple index
-  if ( tensors.size() > 1){
-    return op.emitError(
-        "unimplemented: the index tensor list from list construct > 1");
+  auto outType = getTypeConverter()->convertType(op.getType());
+
+  // Support for multiple indices
+  if (indexTensors.size() > 1) {
+    // t[i, i]
+    // = torch.ops.aten.index(t, (i, i))
+    // = tensor([[t[1, 1], t[2, 2], t[3, 3]],
+    //           [t[3, 3], t[2, 2], t[1, 1]]])
+    // = tensor([[7, 13, 19], [19, 13, 7]])
+    // = tf.gather_nd(t, ii_expand)
+    // where ii_expand
+    // = tf.concat((i_expand, i_expand), dim=2)
+    // = tf.constant([[[1, 1], [2, 2], [3, 3]],
+    //                [[3, 3], [2, 2], [1, 1]]])  # shape 2*3*2
+    SmallVector<Value> indicesTfConcatTensors;
+    SmallVector<int64_t> indexesRank;
+    SmallVector<SmallVector<int64_t>> indexesShape;
+
+    // Convert each index tensor into TF-style indices and collect for concat.
+    for (size_t i = 0; i < indexTensors.size(); i++) {
+      auto index = indexTensors[i];
+      auto indexTorch = tensorsTorchType[i];
+      // TODO: add support for None index input such as
+      // torch.ops.aten.index(x, (None, index1, index2, None))
+      if (indexTorch.getType().isa<Torch::NoneType>())
+        return rewriter.notifyMatchFailure(
+            op, "only ranked tensor indices are supported");
+
+      auto indexType = index.getType().dyn_cast<RankedTensorType>();
+      auto indexShape = indexType.getShape();
+      indexesShape.push_back(makeShapeTorchCompatible(indexShape));
+      indexesRank.push_back(indexType.getRank());
+
+      // Cast index from i64 to i32 for TOSA compatibility.
+      if (indexType.getElementType() != rewriter.getIntegerType(32)) {
+        index = rewriter.create<tosa::CastOp>(
+            op->getLoc(),
+            RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
+            index);
+      }
+
+      // Expand the last dim of index to TF indices: [2, 3] -> [2, 3, 1]
+      SmallVector<int64_t> indiceShapeOneDim;
+      for (auto shape : indexShape) {
+        indiceShapeOneDim.push_back(shape);
+      }
+      indiceShapeOneDim.push_back(1);
+      auto indicesTfOneDim = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
+          rewriter, op->getLoc(),
+          RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)),
+          index, rewriter.getDenseI64ArrayAttr(indiceShapeOneDim));
+
+      // Collect the expanded index tensor for the final concat.
+      indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
+    }
+
+    // Right now only multiple indices with the same shape are supported.
+    // TODO: support differently shaped indices by broadcasting the smaller
+    // shapes.
+    for (auto indexShapeOneDim : indexesShape) {
+      if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only multiple indices with same shape supported");
+      }
+    }
+
+    // Concat the indices into indicesTf: shapes [2,3,1],[2,3,1] -> [2,3,2]
+    auto indicesShapeConcat = indexesShape[0];
+    uint64_t lastDim = indexesRank[0];
+    indicesShapeConcat.push_back(indicesTfConcatTensors.size());
+    auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
+        rewriter, op->getLoc(),
+        GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
+        indicesTfConcatTensors, lastDim);
+
+    if (!indicesTf) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to convert Torch indices to TF-style indices");
+    }
+    // Run the tf.gather_nd algorithm with the TF-style indices as input.
+    auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
+                                          indicesTf.getResult());
+
+    if (!result) {
+      return rewriter.notifyMatchFailure(
+          op, "convertGatherNdOp failed for index tensors");
+    }
+    rewriter.replaceOp(op, {result.value()});
+
+    return success();
   }
-  auto index = tensors[0];
+
+  // Single index case
+  auto index = indexTensors[0];
+  auto indexTorch = tensorsTorchType[0];
   // TODO add support for none index input like torch.ops.aten.index(x, (None, index1, index2, None))
-  if (!index.getImpl())
+  if (indexTorch.getType().isa<Torch::NoneType>())
     return rewriter.notifyMatchFailure(
         op, "Only list ranked tensor types index are supported");
   auto indexType = index.getType().dyn_cast<RankedTensorType>();
@@ -3350,8 +3439,6 @@ LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
         RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
         index);
   }
-  auto outType = getTypeConverter()->convertType(op.getType());
-
   // Expand last dim of index to tf indices [2,3] -> [2,3,1]
   SmallVector<int64_t> indicesShape;
   for (auto shape : indexShape) {
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index e10a06f4d..f3dfad58a 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1826,6 +1826,27 @@ class IndexTensorStaticModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: IndexTensorStaticModule())
 def IndexTensorStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 5), tu.randint(2, 3, high=4))
 
+
+# ==============================================================================
+class IndexTensorMultiIndexStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 5], torch.float32, True),
+        ([2, 3], torch.int64, True),
+        ([2, 3], torch.int64, True),
+    ])
+    def forward(self, x, index1, index2):
+        return torch.ops.aten.index(x, (index1, index2))
+
+
+@register_test_case(module_factory=lambda: IndexTensorMultiIndexStaticModule())
+def IndexTensorMultiIndexStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5), tu.randint(2, 3, high=4), tu.randint(2, 3, high=4))
 
 # ==============================================================================
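
A quick way to sanity-check the new multi-index lowering outside the e2e suite
is to replay the worked example from the comments in plain PyTorch. The snippet
below is a minimal sketch and not part of the patch: the tensor names (t, i,
ii) are illustrative, ii plays the role of the TF-style indices tensor that the
conversion builds via reshape + concat, and plain advanced indexing stands in
for tf.gather_nd / tosa::convertGatherNdOp.

    import torch

    # Matches the worked example above: t[1,1] = 7, t[2,2] = 13, t[3,3] = 19.
    t = torch.arange(1, 26, dtype=torch.float32).reshape(5, 5)
    i = torch.tensor([[1, 2, 3], [3, 2, 1]])  # index tensor, shape [2, 3]

    # Expand each index to [2, 3, 1] and concat along the last dim, as the
    # TOSA path does; the result ii has shape [2, 3, 2].
    ii = torch.cat((i.unsqueeze(-1), i.unsqueeze(-1)), dim=-1)

    # gather_nd on a rank-2 tensor is advanced indexing with the two columns.
    gathered = t[ii[..., 0], ii[..., 1]]

    # Must agree with the op this patch lowers:
    # tensor([[ 7., 13., 19.], [19., 13.,  7.]])
    assert torch.equal(gathered, torch.ops.aten.index(t, (i, i)))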