diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index f49bc4e8b..53d468b17 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -387,7 +387,6 @@ TORCHDYNAMO_CRASHING_SET = {
 }
 
 FX_IMPORT_XFAIL_SET = {
-    'AddIntModule_basic',
     'AllBoolFalseModule_basic',
     'AllBoolTrueModule_basic',
     'AnyBoolFalseModule_basic',
@@ -399,10 +398,7 @@ FX_IMPORT_XFAIL_SET = {
     'AtenIntBoolOpConstFalseModule_basic',
     'AtenIntBoolOpConstTrueModule_basic',
     'AtenIntBoolOpModule_basic',
-    'AtenIntTensorByteDtypeModule_basic',
-    'AtenIntTensorCharDtypeModule_basic',
     'AtenItemFpOpModule_basic',
-    'AtenItemIntOpModule_basic',
     'AtenMatmulQMixedSigni8Transpose_basic',
     'AtenMatmulQMixedSigni8_basic',
     'AtenMatmulQint8MV_basic',
@@ -465,7 +461,6 @@ FX_IMPORT_XFAIL_SET = {
     'MaxPool3dStaticCeilModeTrueModule_basic',
     'MaxPool3dStaticModule_basic',
     'MulFloatModule_basic',
-    'MulIntModule_basic',
     'NativeGroupNormBackwardModule_basic',
     'NeFloatIntModule_basic',
     'NeIntModule_basic',
@@ -496,9 +491,6 @@ FX_IMPORT_XFAIL_SET = {
     'RsubInt0d_NumToTensor_Module_basic',
     'ScalarConstantTupleModule_basic',
     'ScalarImplicitFloatModule_basic',
-    'ScalarImplicitIntModule_basic',
-    'ScatterValueFloatModule_basic',
-    'ScatterValueIntModule_basic',
     'SortIntListReverse_basic',
     'SortIntList_basic',
     'SplitDimDynamicModule_basic',
@@ -506,14 +498,11 @@ FX_IMPORT_XFAIL_SET = {
     'SqrtIntConstantModule_basic',
     'SqrtIntModule_basic',
     'SubFloatModule_basic',
-    'SubIntModule_basic',
     'TModuleRank0_basic',
     'TensorToBoolZeroRank_basic',
     'TensorToBool_basic',
     'TensorToFloatZeroRank_basic',
     'TensorToFloat_basic',
-    'TensorToIntZeroRank_basic',
-    'TensorToInt_basic',
     'TestMultipleTensorAndPrimitiveTypesReturn_basic',
     'ThresholdBackward2dMixedModule_basic',
     'TorchPrimLoopForLikeModule_basic',
diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 4fb1d12b0..d70c6046e 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -1397,6 +1397,7 @@ class GraphNodeImporter:
     def _import_torch_op_overload(
         self, loc: Location, node: torch_fx.Node, target: TorchOpOverload
     ):
+        # TODO: Convert this cascade of ifs to a table-driven dispatch.
         # replace lift_fresh_copy with clone op
         if target == torch.ops.aten.lift_fresh_copy.default:
            node.target = target = torch.ops.aten.clone.default
@@ -1420,6 +1421,16 @@ class GraphNodeImporter:
                 for key_node in node.users:
                     if key_node.target == torch.ops.aten.baddbmm.default:
                         node.target = target = torch.ops.aten.zeros.default
+        elif target == torch.ops.aten._local_scalar_dense.default:
+            input_type = node.args[0].meta["tensor_meta"].dtype
+            if input_type.is_floating_point:
+                node.target = target = torch.ops.aten.Float.Tensor
+            else:
+                node.target = target = torch.ops.aten.Int.Tensor
+            node.args = (node.args[0],)
+        elif target == torch.ops.aten._assert_async.msg:
+            # TODO: A more suitable op to replace it?
+            return
 
         schema = target._schema
         assert isinstance(schema, FunctionSchema)
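
A minimal sketch of the kind of graph the new `_local_scalar_dense` branch handles. The module name `AddScalarItem` is hypothetical and only mirrors the pattern exercised by the e2e tests removed from `FX_IMPORT_XFAIL_SET` above (e.g. `AtenItemIntOpModule_basic`, `TensorToInt_basic`); it assumes a `torch.export` version that captures scalar outputs, where `.item()` appears in the exported graph as `aten._local_scalar_dense`.

```python
import torch


class AddScalarItem(torch.nn.Module):
    """Hypothetical example mirroring the pattern in the e2e tests
    removed from FX_IMPORT_XFAIL_SET (e.g. AtenItemIntOpModule_basic)."""

    def forward(self, x, scalar):
        # .item() on the zero-rank integer tensor traces to
        # aten._local_scalar_dense; the new elif branch retargets that
        # node to aten.Int.Tensor (or aten.Float.Tensor for float dtypes).
        return x + scalar.item()


ep = torch.export.export(
    AddScalarItem(), (torch.ones(3), torch.tensor(2, dtype=torch.int64))
)
print(ep.graph)  # contains a call to torch.ops.aten._local_scalar_dense
```

The `_assert_async.msg` case, by contrast, returns before any op is emitted, so runtime assertions are simply dropped on import rather than lowered.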