[FxImporter] Fix embedding bag (#3387)

pull/3386/head
penguin_wwy 2024-05-29 14:46:21 +08:00 committed by GitHub
parent e0a5adb1db
commit a5d3b546f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 2 deletions

View File

@@ -336,8 +336,6 @@ FX_IMPORTER_XFAIL_SET = {
     "AnyBoolFalseModule_basic",
     "AnyBoolTrueModule_basic",
     "ArangeStartOutViewModule_basic",
-    "AtenEmbeddingBagStaticModule_basic",
-    "AtenEmbeddingBagSumExample_basic",
     "AtenFloatScalarModule_basic",
     "AtenIntBoolOpConstFalseModule_basic",
     "AtenIntBoolOpConstTrueModule_basic",

View File

@@ -1446,6 +1446,21 @@ class GraphNodeImporter:
                 return
             elif target == torch.ops.aten._unsafe_index_put.default:
                 node.target = target = torch.ops.aten._unsafe_index_put.hacked_twin
+            elif target == torch.ops.aten._embedding_bag_forward_only.default:
+                node.target = target = torch.ops.aten.embedding_bag.padding_idx
+                embedding_bag_args = [
+                    ("scale_grad_by_freq", False),
+                    ("mode", 0),
+                    ("sparse", False),
+                    ("per_sample_weights", None),
+                    ("include_last_offset", False),
+                    ("padding_idx", None),
+                ]
+                node_kwargs = dict(node.kwargs)
+                for k, v in embedding_bag_args[len(node.args) - 3 :]:
+                    if k not in node_kwargs:
+                        node_kwargs[k] = v
+                node.kwargs = node_kwargs
             schema = target._schema
             assert isinstance(schema, FunctionSchema)