[RefineTypes] Fix knowledge dtype for `aten.embedding` op

-- The dtype of the result of `aten.embedding` should match that of
   the `weight` operand (operand[0]) instead of being hardcoded to f32.
-- This commit fixes the dtype propagation accordingly; a PyTorch-level
   illustration of the expected behavior follows below.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
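
For context, eager PyTorch already derives the embedding result's dtype from the weight table, which is the behavior the refined type should mirror. A minimal illustration (not part of this patch, plain PyTorch only):

import torch

# Half-precision embedding table and int32 indices, mirroring the new e2e test setup.
weight_f16 = torch.randn(100, 50, dtype=torch.half)
indices = torch.randint(0, 100, (3, 3), dtype=torch.int32)

out = torch.nn.functional.embedding(indices, weight_f16)
print(out.dtype)   # torch.float16 -- follows the weight, not a hardcoded f32
print(out.shape)   # torch.Size([3, 3, 50])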
pull/1680/head
Abhishek Varma 2022-11-30 11:14:48 +00:00 committed by Prashant Kumar
parent 577e38da58
commit 66d7a412cb
3 changed files with 27 additions and 1 deletion


@@ -155,6 +155,7 @@ MHLO_PASS_SET = {
     "EmbeddingModuleI32Static_basic",
     "EmbeddingModuleI32_basic",
     "EmbeddingModuleI64_basic",
+    "EmbeddingModuleF16_basic",
     "ExpandAsIntModule_basic",
     "ExpandModule_basic",
     "FullLikeModuleDefaultDtype_basic",


@@ -1088,7 +1088,7 @@ void TypeAnalysis::visitOperation(Operation *op,
   if (auto embedding = dyn_cast<AtenEmbeddingOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
-    knowledge.dtype = Float32Type::get(op->getContext());
+    knowledge.dtype = operands[0]->getValue().dtype;
     incorporateKnowledge(embedding.getResult(), knowledge);
     return;
   }


@@ -807,6 +807,31 @@ class EmbeddingModuleI32(torch.nn.Module):
 def EmbeddingModuleI32_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 3, high=100).to(torch.int32))
 
+
+# ==============================================================================
+
+
+class EmbeddingModuleF16(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        torch.manual_seed(0)
+        self.embed = torch.nn.Embedding(num_embeddings=100,
+                                        embedding_dim=50,
+                                        padding_idx=4).to(torch.half)
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, indices):
+        return self.embed.forward(indices)
+
+
+@register_test_case(module_factory=lambda: EmbeddingModuleF16())
+def EmbeddingModuleF16_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 3, high=100).to(torch.int32))
+
+
 # ==============================================================================
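
As a sanity check outside the e2e harness, the same configuration as EmbeddingModuleF16 can be exercised directly in PyTorch (illustrative sketch, not part of the patch); the assertions state what the refined result type now has to agree with:

import torch

torch.manual_seed(0)
# Same setup as EmbeddingModuleF16: a 100 x 50 table cast to half, padding_idx=4.
embed = torch.nn.Embedding(num_embeddings=100, embedding_dim=50,
                           padding_idx=4).to(torch.half)
indices = torch.randint(0, 100, (3, 3), dtype=torch.int32)

out = embed(indices)
assert out.dtype == torch.float16   # result dtype follows the weight table
assert out.shape == (3, 3, 50)      # [batch, seq, embedding_dim]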