mirror of https://github.com/llvm/torch-mlir
[FxImporter] Fix embedding bag (#3387)
parent
e0a5adb1db
commit
a5d3b546f8
|
@ -336,8 +336,6 @@ FX_IMPORTER_XFAIL_SET = {
|
|||
"AnyBoolFalseModule_basic",
|
||||
"AnyBoolTrueModule_basic",
|
||||
"ArangeStartOutViewModule_basic",
|
||||
"AtenEmbeddingBagStaticModule_basic",
|
||||
"AtenEmbeddingBagSumExample_basic",
|
||||
"AtenFloatScalarModule_basic",
|
||||
"AtenIntBoolOpConstFalseModule_basic",
|
||||
"AtenIntBoolOpConstTrueModule_basic",
|
||||
|
|
|
@ -1446,6 +1446,21 @@ class GraphNodeImporter:
|
|||
return
|
||||
elif target == torch.ops.aten._unsafe_index_put.default:
|
||||
node.target = target = torch.ops.aten._unsafe_index_put.hacked_twin
|
||||
elif target == torch.ops.aten._embedding_bag_forward_only.default:
|
||||
node.target = target = torch.ops.aten.embedding_bag.padding_idx
|
||||
embedding_bag_args = [
|
||||
("scale_grad_by_freq", False),
|
||||
("mode", 0),
|
||||
("sparse", False),
|
||||
("per_sample_weights", None),
|
||||
("include_last_offset", False),
|
||||
("padding_idx", None),
|
||||
]
|
||||
node_kwargs = dict(node.kwargs)
|
||||
for k, v in embedding_bag_args[len(node.args) - 3 :]:
|
||||
if k not in node_kwargs:
|
||||
node_kwargs[k] = v
|
||||
node.kwargs = node_kwargs
|
||||
|
||||
schema = target._schema
|
||||
assert isinstance(schema, FunctionSchema)
|
||||
|
|
Loading…
Reference in New Issue