mirror of https://github.com/llvm/torch-mlir
[RefineTypes] Fix knowledge dtype for `aten.embedding` op
-- The dtype of the result of `aten.embedding` should match that of the `weight` operand (operand[0]) instead of being hardcoded to f32. -- This commit fixes that by taking the result dtype from operand[0]. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com> pull/1680/head
parent
577e38da58
commit
66d7a412cb
|
@ -155,6 +155,7 @@ MHLO_PASS_SET = {
|
|||
"EmbeddingModuleI32Static_basic",
|
||||
"EmbeddingModuleI32_basic",
|
||||
"EmbeddingModuleI64_basic",
|
||||
"EmbeddingModuleF16_basic",
|
||||
"ExpandAsIntModule_basic",
|
||||
"ExpandModule_basic",
|
||||
"FullLikeModuleDefaultDtype_basic",
|
||||
|
|
|
@ -1088,7 +1088,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
if (auto embedding = dyn_cast<AtenEmbeddingOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = Float32Type::get(op->getContext());
|
||||
knowledge.dtype = operands[0]->getValue().dtype;
|
||||
incorporateKnowledge(embedding.getResult(), knowledge);
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -807,6 +807,31 @@ class EmbeddingModuleI32(torch.nn.Module):
|
|||
def EmbeddingModuleI32_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 3, high=100).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class EmbeddingModuleF16(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
torch.manual_seed(0)
|
||||
self.embed = torch.nn.Embedding(num_embeddings=100,
|
||||
embedding_dim=50,
|
||||
padding_idx=4).to(torch.half)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, indices):
|
||||
return self.embed.forward(indices)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: EmbeddingModuleF16())
|
||||
def EmbeddingModuleF16_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 3, high=100).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
Loading…
Reference in New Issue