[FxImporter] Fix failed e2e case (#3365)

pull/3334/head
penguin_wwy 2024-05-22 00:20:54 +08:00 committed by GitHub
parent b870729efe
commit c2c1c2cfa4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 3 additions and 3 deletions

View File

@ -455,9 +455,6 @@ FX_IMPORTER_XFAIL_SET = {
"ThresholdBackward2dMixedModule_basic", "ThresholdBackward2dMixedModule_basic",
"TorchPrimLoopForLikeModule_basic", "TorchPrimLoopForLikeModule_basic",
"TorchPrimLoopWhileLikeModule_basic", "TorchPrimLoopWhileLikeModule_basic",
"UnbindIntGetItem_Module_basic",
"UnbindIntListUnpack_Module_basic",
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"UpSampleNearest2dDynamicFactor_basic", "UpSampleNearest2dDynamicFactor_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic",

View File

@ -47,6 +47,7 @@ DEFAULT_DECOMPOSITIONS = [
torch.ops.aten.linspace.default, torch.ops.aten.linspace.default,
torch.ops.aten.triu.default, torch.ops.aten.triu.default,
torch.ops.aten.nan_to_num.default, torch.ops.aten.nan_to_num.default,
torch.ops.aten.unbind,
] ]

View File

@ -1428,6 +1428,8 @@ class GraphNodeImporter:
elif target == torch.ops.aten._assert_async.msg: elif target == torch.ops.aten._assert_async.msg:
# TODO: A more suitable op to replace it? # TODO: A more suitable op to replace it?
return return
elif target == torch.ops.aten._unsafe_index_put.default:
node.target = target = torch.ops.aten._unsafe_index_put.hacked_twin
schema = target._schema schema = target._schema
assert isinstance(schema, FunctionSchema) assert isinstance(schema, FunctionSchema)