diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 92d02dbfc..8b09187d1 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -1017,6 +1017,7 @@ TOSA_PASS_SET = { "NumpyTRankNStaticModule_basic", "NumpyTRankNDynamicModule_basic", "EmbeddingModuleI32Static_basic", + "EmbeddingModule1DIndices_basic", "TModuleRank2_basic", "TransposeIntModule_basic", "TransposeIntNegDimsModule_basic", diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 40dd8e24b..fd676deab 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2961,9 +2961,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Indices must be of integer tensor type"); - if (indicesType.getRank() != 2) - return rewriter.notifyMatchFailure(op, "indices must be of rank 2"); - auto weightType = weight.getType().cast(); if (weightType.getRank() != 2) return op.emitError("weight must be of rank 2");