mirror of https://github.com/llvm/torch-mlir
TorchToTosa: aten.embedding: Allow indices with any rank (#2327)
It's actually fine to not check the rank of the indices, because the conversion anyways flattens the index tensor to be (1, numElements) before applying tosa::gather, and then anyways reshapes the output tensor to the output shape of the aten.embedding.pull/2328/head snapshot-20230721.906
parent
1e468e8294
commit
3ca35b4f3c
|
@ -1017,6 +1017,7 @@ TOSA_PASS_SET = {
|
|||
"NumpyTRankNStaticModule_basic",
|
||||
"NumpyTRankNDynamicModule_basic",
|
||||
"EmbeddingModuleI32Static_basic",
|
||||
"EmbeddingModule1DIndices_basic",
|
||||
"TModuleRank2_basic",
|
||||
"TransposeIntModule_basic",
|
||||
"TransposeIntNegDimsModule_basic",
|
||||
|
|
|
@ -2961,9 +2961,6 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "Indices must be of integer tensor type");
|
||||
|
||||
if (indicesType.getRank() != 2)
|
||||
return rewriter.notifyMatchFailure(op, "indices must be of rank 2");
|
||||
|
||||
auto weightType = weight.getType().cast<RankedTensorType>();
|
||||
if (weightType.getRank() != 2)
|
||||
return op.emitError("weight must be of rank 2");
|
||||
|
|
Loading…
Reference in New Issue