TorchToTosa: aten.embedding: Allow indices with any rank (#2327)

It's fine not to check the rank of the indices: the conversion flattens the index tensor to shape (1, numElements) before applying tosa::gather, and then reshapes the result to the output shape of the aten.embedding anyway.
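The flatten/gather/reshape sequence described above can be sketched in NumPy terms. This is an illustrative model of the lowering's semantics, not the actual TorchToTosa implementation; the function name and variable names are made up here:

```python
import numpy as np

def embedding_via_gather(weight, indices):
    """Model of the aten.embedding -> tosa.gather lowering.

    weight:  (numEmbeddings, embeddingDim) embedding table
    indices: integer tensor of any rank -- the point of this change
    """
    num_elements = indices.size
    # Flatten the index tensor to (1, numElements), as the conversion does.
    flat = indices.reshape(1, num_elements)
    # tosa.gather operates on batched values (N, K, C); model the batch dim N=1.
    values = weight[np.newaxis, ...]      # (1, numEmbeddings, embeddingDim)
    gathered = values[0, flat[0]]         # (numElements, embeddingDim)
    # Reshape to the aten.embedding output shape: indices.shape + (embeddingDim,)
    return gathered.reshape(*indices.shape, weight.shape[1])
```

Because the output is reshaped from the flattened gather result back to `indices.shape + (embeddingDim,)`, the rank of the incoming index tensor never matters, which is why the rank-2 check below can be dropped.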
Matthias Gehre 2023-07-21 08:54:19 +02:00 committed by GitHub
parent 1e468e8294
commit 3ca35b4f3c
2 changed files with 1 addition and 3 deletions


@@ -1017,6 +1017,7 @@ TOSA_PASS_SET = {
     "NumpyTRankNStaticModule_basic",
     "NumpyTRankNDynamicModule_basic",
     "EmbeddingModuleI32Static_basic",
+    "EmbeddingModule1DIndices_basic",
     "TModuleRank2_basic",
     "TransposeIntModule_basic",
     "TransposeIntNegDimsModule_basic",


@@ -2961,9 +2961,6 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Indices must be of integer tensor type");
-  if (indicesType.getRank() != 2)
-    return rewriter.notifyMatchFailure(op, "indices must be of rank 2");
   auto weightType = weight.getType().cast<RankedTensorType>();
   if (weightType.getRank() != 2)
     return op.emitError("weight must be of rank 2");