From 246c2df65a01998b2159fec79675a0631f1f0218 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Wed, 15 Jun 2022 11:45:10 -0500 Subject: [PATCH] [LINALG] Fix typo in conversion pattern of `aten.embedding` (#942) --- .../TorchToLinalg/IndirectDataMovement.cpp | 2 +- .../torch_mlir_e2e_test/test_suite/basic.py | 26 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index 1fa8948a5..f93f3923e 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -133,7 +133,7 @@ public: sizes.push_back(embeddingDim); int64_t resultRank = sizes.size(); - auto indicesTy = weight.getType().cast(); + auto indicesTy = indices.getType().cast(); int64_t indicesRank = indicesTy.getRank(); SmallVector indicesExprs; for (int i = 0; i < indicesRank; i++) diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index c9b6b7c92..df8991ba6 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -666,6 +666,32 @@ def EmbeddingModuleI32_basic(module, tu: TestUtils): # ============================================================================== +class EmbeddingModule1DIndices(torch.nn.Module): + + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.embed = torch.nn.Embedding(num_embeddings=100, + embedding_dim=50, + padding_idx=4) + + @export + @annotate_args([ + None, + ([-1], torch.int32, True), + ]) + def forward(self, indices): + return self.embed.forward(indices) + + +@register_test_case(module_factory=lambda: EmbeddingModule1DIndices()) +def EmbeddingModule1DIndices_basic(module, tu: TestUtils): + module.forward(torch.randint(100, (3,)).to(torch.int32)) + + +# ============================================================================== + + class SoftmaxIntModule(torch.nn.Module): def __init__(self):