From 3ca35b4f3cf10b08d6c3f2e19285899f440b3ac2 Mon Sep 17 00:00:00 2001 From: Matthias Gehre <93204396+mgehre-amd@users.noreply.github.com> Date: Fri, 21 Jul 2023 08:54:19 +0200 Subject: [PATCH] TorchToTosa: aten.embedding: Allow indices with any rank (#2327) It's actually fine to not check the rank of the indices, because the conversion anyways flattens the index tensor to be (1, numElements) before applying tosa::gather, and then anyways reshapes the output tensor to the output shape of the aten.embedding. --- e2e_testing/xfail_sets.py | 1 + lib/Conversion/TorchToTosa/TorchToTosa.cpp | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 92d02dbfc..8b09187d1 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -1017,6 +1017,7 @@ TOSA_PASS_SET = { "NumpyTRankNStaticModule_basic", "NumpyTRankNDynamicModule_basic", "EmbeddingModuleI32Static_basic", + "EmbeddingModule1DIndices_basic", "TModuleRank2_basic", "TransposeIntModule_basic", "TransposeIntNegDimsModule_basic", diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 40dd8e24b..fd676deab 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2961,9 +2961,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Indices must be of integer tensor type"); - if (indicesType.getRank() != 2) - return rewriter.notifyMatchFailure(op, "indices must be of rank 2"); - auto weightType = weight.getType().cast(); if (weightType.getRank() != 2) return op.emitError("weight must be of rank 2");