mirror of https://github.com/llvm/torch-mlir
[FxImporter] Fix embedding bag (#3387)
parent
e0a5adb1db
commit
a5d3b546f8
|
@ -336,8 +336,6 @@ FX_IMPORTER_XFAIL_SET = {
|
||||||
"AnyBoolFalseModule_basic",
|
"AnyBoolFalseModule_basic",
|
||||||
"AnyBoolTrueModule_basic",
|
"AnyBoolTrueModule_basic",
|
||||||
"ArangeStartOutViewModule_basic",
|
"ArangeStartOutViewModule_basic",
|
||||||
"AtenEmbeddingBagStaticModule_basic",
|
|
||||||
"AtenEmbeddingBagSumExample_basic",
|
|
||||||
"AtenFloatScalarModule_basic",
|
"AtenFloatScalarModule_basic",
|
||||||
"AtenIntBoolOpConstFalseModule_basic",
|
"AtenIntBoolOpConstFalseModule_basic",
|
||||||
"AtenIntBoolOpConstTrueModule_basic",
|
"AtenIntBoolOpConstTrueModule_basic",
|
||||||
|
|
|
@ -1446,6 +1446,21 @@ class GraphNodeImporter:
|
||||||
return
|
return
|
||||||
elif target == torch.ops.aten._unsafe_index_put.default:
|
elif target == torch.ops.aten._unsafe_index_put.default:
|
||||||
node.target = target = torch.ops.aten._unsafe_index_put.hacked_twin
|
node.target = target = torch.ops.aten._unsafe_index_put.hacked_twin
|
||||||
|
elif target == torch.ops.aten._embedding_bag_forward_only.default:
|
||||||
|
node.target = target = torch.ops.aten.embedding_bag.padding_idx
|
||||||
|
embedding_bag_args = [
|
||||||
|
("scale_grad_by_freq", False),
|
||||||
|
("mode", 0),
|
||||||
|
("sparse", False),
|
||||||
|
("per_sample_weights", None),
|
||||||
|
("include_last_offset", False),
|
||||||
|
("padding_idx", None),
|
||||||
|
]
|
||||||
|
node_kwargs = dict(node.kwargs)
|
||||||
|
for k, v in embedding_bag_args[len(node.args) - 3 :]:
|
||||||
|
if k not in node_kwargs:
|
||||||
|
node_kwargs[k] = v
|
||||||
|
node.kwargs = node_kwargs
|
||||||
|
|
||||||
schema = target._schema
|
schema = target._schema
|
||||||
assert isinstance(schema, FunctionSchema)
|
assert isinstance(schema, FunctionSchema)
|
||||||
|
|
Loading…
Reference in New Issue