mirror of https://github.com/llvm/torch-mlir
[LINALG] Fix typo in conversion pattern of `aten.embedding` (#942)
parent
2a0c5de363
commit
246c2df65a
|
@ -133,7 +133,7 @@ public:
|
||||||
sizes.push_back(embeddingDim);
|
sizes.push_back(embeddingDim);
|
||||||
int64_t resultRank = sizes.size();
|
int64_t resultRank = sizes.size();
|
||||||
|
|
||||||
auto indicesTy = weight.getType().cast<RankedTensorType>();
|
auto indicesTy = indices.getType().cast<RankedTensorType>();
|
||||||
int64_t indicesRank = indicesTy.getRank();
|
int64_t indicesRank = indicesTy.getRank();
|
||||||
SmallVector<AffineExpr> indicesExprs;
|
SmallVector<AffineExpr> indicesExprs;
|
||||||
for (int i = 0; i < indicesRank; i++)
|
for (int i = 0; i < indicesRank; i++)
|
||||||
|
|
|
@ -666,6 +666,32 @@ def EmbeddingModuleI32_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingModule1DIndices(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
torch.manual_seed(0)
|
||||||
|
self.embed = torch.nn.Embedding(num_embeddings=100,
|
||||||
|
embedding_dim=50,
|
||||||
|
padding_idx=4)
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, indices):
|
||||||
|
return self.embed.forward(indices)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: EmbeddingModule1DIndices())
|
||||||
|
def EmbeddingModule1DIndices_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(100, (3,)).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class SoftmaxIntModule(torch.nn.Module):
|
class SoftmaxIntModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
Loading…
Reference in New Issue