mirror of https://github.com/llvm/torch-mlir
[RefineTypes] Fix knowledge dtype for `aten.embedding` op
-- The dtype of the result of `aten.embedding` should match that of the `weight` operand's (operand[0]) instead of hardcoding to f32. -- This commit aims to provide a fix for the same. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>pull/1680/head
parent
577e38da58
commit
66d7a412cb
|
@ -155,6 +155,7 @@ MHLO_PASS_SET = {
|
||||||
"EmbeddingModuleI32Static_basic",
|
"EmbeddingModuleI32Static_basic",
|
||||||
"EmbeddingModuleI32_basic",
|
"EmbeddingModuleI32_basic",
|
||||||
"EmbeddingModuleI64_basic",
|
"EmbeddingModuleI64_basic",
|
||||||
|
"EmbeddingModuleF16_basic",
|
||||||
"ExpandAsIntModule_basic",
|
"ExpandAsIntModule_basic",
|
||||||
"ExpandModule_basic",
|
"ExpandModule_basic",
|
||||||
"FullLikeModuleDefaultDtype_basic",
|
"FullLikeModuleDefaultDtype_basic",
|
||||||
|
|
|
@ -1088,7 +1088,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
if (auto embedding = dyn_cast<AtenEmbeddingOp>(op)) {
|
if (auto embedding = dyn_cast<AtenEmbeddingOp>(op)) {
|
||||||
auto knowledge =
|
auto knowledge =
|
||||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||||
knowledge.dtype = Float32Type::get(op->getContext());
|
knowledge.dtype = operands[0]->getValue().dtype;
|
||||||
incorporateKnowledge(embedding.getResult(), knowledge);
|
incorporateKnowledge(embedding.getResult(), knowledge);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -807,6 +807,31 @@ class EmbeddingModuleI32(torch.nn.Module):
|
||||||
def EmbeddingModuleI32_basic(module, tu: TestUtils):
|
def EmbeddingModuleI32_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 3, high=100).to(torch.int32))
|
module.forward(tu.randint(3, 3, high=100).to(torch.int32))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingModuleF16(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
torch.manual_seed(0)
|
||||||
|
self.embed = torch.nn.Embedding(num_embeddings=100,
|
||||||
|
embedding_dim=50,
|
||||||
|
padding_idx=4).to(torch.half)
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int32, True),
|
||||||
|
])
|
||||||
|
def forward(self, indices):
|
||||||
|
return self.embed.forward(indices)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: EmbeddingModuleF16())
|
||||||
|
def EmbeddingModuleF16_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 3, high=100).to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue