[LINALG] Fix typo in conversion pattern of `aten.embedding` (#942)

pull/792/merge
Ramiro Leal-Cavazos 2022-06-15 11:45:10 -05:00 committed by GitHub
parent 2a0c5de363
commit 246c2df65a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 1 deletions

View File

@ -133,7 +133,7 @@ public:
sizes.push_back(embeddingDim); sizes.push_back(embeddingDim);
int64_t resultRank = sizes.size(); int64_t resultRank = sizes.size();
auto indicesTy = weight.getType().cast<RankedTensorType>(); auto indicesTy = indices.getType().cast<RankedTensorType>();
int64_t indicesRank = indicesTy.getRank(); int64_t indicesRank = indicesTy.getRank();
SmallVector<AffineExpr> indicesExprs; SmallVector<AffineExpr> indicesExprs;
for (int i = 0; i < indicesRank; i++) for (int i = 0; i < indicesRank; i++)

View File

@ -666,6 +666,32 @@ def EmbeddingModuleI32_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class EmbeddingModule1DIndices(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.embed = torch.nn.Embedding(num_embeddings=100,
embedding_dim=50,
padding_idx=4)
@export
@annotate_args([
None,
([-1], torch.int32, True),
])
def forward(self, indices):
return self.embed.forward(indices)
@register_test_case(module_factory=lambda: EmbeddingModule1DIndices())
def EmbeddingModule1DIndices_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (3,)).to(torch.int32))
# ==============================================================================
class SoftmaxIntModule(torch.nn.Module): class SoftmaxIntModule(torch.nn.Module):
def __init__(self): def __init__(self):