diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 578af98d1..14eaf3f5d 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -336,8 +336,6 @@ FX_IMPORTER_XFAIL_SET = { "AnyBoolFalseModule_basic", "AnyBoolTrueModule_basic", "ArangeStartOutViewModule_basic", - "AtenEmbeddingBagStaticModule_basic", - "AtenEmbeddingBagSumExample_basic", "AtenFloatScalarModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 870cb8612..9981ed30e 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1446,6 +1446,21 @@ class GraphNodeImporter: return elif target == torch.ops.aten._unsafe_index_put.default: node.target = target = torch.ops.aten._unsafe_index_put.hacked_twin + elif target == torch.ops.aten._embedding_bag_forward_only.default: + node.target = target = torch.ops.aten.embedding_bag.padding_idx + embedding_bag_args = [ + ("scale_grad_by_freq", False), + ("mode", 0), + ("sparse", False), + ("per_sample_weights", None), + ("include_last_offset", False), + ("padding_idx", None), + ] + node_kwargs = dict(node.kwargs) + for k, v in embedding_bag_args[len(node.args) - 3 :]: + if k not in node_kwargs: + node_kwargs[k] = v + node.kwargs = node_kwargs schema = target._schema assert isinstance(schema, FunctionSchema)