mirror of https://github.com/llvm/torch-mlir
[LINALG] Fix typo in conversion pattern of `aten.embedding` (#942)
parent
2a0c5de363
commit
246c2df65a
|
@ -133,7 +133,7 @@ public:
|
|||
sizes.push_back(embeddingDim);
|
||||
int64_t resultRank = sizes.size();
|
||||
|
||||
auto indicesTy = weight.getType().cast<RankedTensorType>();
|
||||
auto indicesTy = indices.getType().cast<RankedTensorType>();
|
||||
int64_t indicesRank = indicesTy.getRank();
|
||||
SmallVector<AffineExpr> indicesExprs;
|
||||
for (int i = 0; i < indicesRank; i++)
|
||||
|
|
|
@ -666,6 +666,32 @@ def EmbeddingModuleI32_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class EmbeddingModule1DIndices(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
torch.manual_seed(0)
|
||||
self.embed = torch.nn.Embedding(num_embeddings=100,
|
||||
embedding_dim=50,
|
||||
padding_idx=4)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int32, True),
|
||||
])
|
||||
def forward(self, indices):
|
||||
return self.embed.forward(indices)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: EmbeddingModule1DIndices())
|
||||
def EmbeddingModule1DIndices_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3,)).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SoftmaxIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue