diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 428c19788..810db7268 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -215,10 +215,6 @@ TORCHDYNAMO_XFAIL_SET = {
     'ConstantBoolParameterModule_basic',
 
     # START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
-    "AddCDivModule_basic",
-    "ElementwiseMulScalarModule_basic",
-    "ElementwiseMulScalarModule_float",
-    "NativeGroupNormBackwardModule_basic",
     "UpSampleNearest2dDynamicSize_basic",
     "UpSampleNearest2dStaticFactor_basic",
     "UpSampleNearest2dStaticSize_basic",
@@ -226,23 +222,7 @@ TORCHDYNAMO_XFAIL_SET = {
     # END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
 
     # START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
-    "AtenInstanceNormModule_basic",
-    "BatchNorm1DModule_basic",
-    "BatchNorm1DWith2DInputModule_basic",
-    "BatchNorm2DModule_basic",
-    "BatchNorm3DModule_basic",
-    "BatchNorm1DStaticShapeModule_basic",
     "ElementwiseAddScalarFloatModule_basic",
-    "ElementwiseAddScalarInt64Module_basic",
-    "ElementwiseAddScalarIntModule_basic",
-    "MobilenetV3Module_basic",
-    "NativeBatchNorm1DModule_basic",
-    "NativeBatchNorm2DModule_basic",
-    "NativeBatchNorm3DModule_basic",
-    "NativeBatchNormNoneWeightModule_basic",
-    "NativeGroupNormModule_basic",
-    "ResNet18Module_basic",
-    "ResNet18StaticModule_basic",
     # END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
 
     # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
@@ -255,9 +235,6 @@ TORCHDYNAMO_XFAIL_SET = {
     # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
     "ElementwiseAtenDivIntScalarModule_basic",
 
-    # ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
-    "ElementwiseMulScalarModule_int",
-
     # ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
     "ElementwiseSubScalarFloatModule_basic",
     "ElementwiseSubScalarIntModule_basic",
@@ -315,10 +292,6 @@ TORCHDYNAMO_XFAIL_SET = {
     # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
     "ArangeStartOutViewModule_basic",
 
-    # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
-    "GroupNormModule_basic",
-    "GroupNormNoWeightAndBiasModule_basic",
-
     # Dynamo does not support tracing quantized tensors
     "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
@@ -376,6 +349,9 @@ TORCHDYNAMO_CRASHING_SET = {
     "MaxPool3dModule_basic",
     "MaxPool3dStaticCeilModeTrueModule_basic",
     "MaxPool3dStaticModule_basic",
+
+    # Looks like incorrect fx graph conversion
+    "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
 }
 
 STABLEHLO_PASS_SET = {
diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
index e5c2475c7..bdc410741 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
@@ -53,6 +53,40 @@ def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool:
                 return False
     return True
 
+# Replaces torch.aten.add.Tensor/torch.aten.mul.Tensor with
+# torch.aten.add.Scalar/torch.aten.mul.Scalar when the second argument is a Scalar.
+# This cannot be done at an earlier stage, e.g. in _FXGraphImporter, as it
+# needs to check the argument types, which are not yet determined there.
+# Perhaps the schema or target should be changed instead, but that is decided in
+# _dynamo eval_frame on the PyTorch side. Also, the Python schema does not match
+# the MLIR schema - see include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td.
+# In general this covers some of the overload cases that are handled automatically
+# on the Python side, e.g. converting Scalar -> Tensor and vice versa.
+def scalarize_tensor_ops_on_scalars(gm: torch.fx.GraphModule):
+    # Modify gm.graph in place.
+    for node in gm.graph.nodes:
+        # Check whether we are calling a function (e.g.
+        # torch.add).
+        if node.op == 'call_function':
+            # The target attribute is the function
+            # that call_function calls, e.g.
+            # call_function[target=torch.ops.aten.add.Tensor](args = (%arg64_1, 1), kwargs = {})
+            if node.target == torch.ops.aten.add.Tensor:
+                if len(node.args) != 2 or node.kwargs != {}:
+                    continue
+                elif not isinstance(node.args[1], torch.fx.node.Node):
+                    node.target = torch.ops.aten.add.Scalar
+            if node.target == torch.ops.aten.mul.Tensor:
+                if len(node.args) != 2 or node.kwargs != {}:
+                    continue
+                elif not isinstance(node.args[1], torch.fx.node.Node):
+                    node.target = torch.ops.aten.mul.Scalar
+
+    gm.graph.lint()  # Does some checks to make sure the Graph is well-formed.
+
+    # Recompile the forward() method of `gm` from its Graph.
+    gm.recompile()
+
 
 def jit(
     model: torch.nn.Module,
@@ -87,6 +121,8 @@ def jit(
         # way of differentiating between the two.
         assert not _returns_empty_tuple(gm), "encountered graph that does not return anything"
 
+        scalarize_tensor_ops_on_scalars(gm)
+
         nonlocal mlir_module
         *_, model_name, nth_graph = get_aot_compilation_context()
         mlir_module = import_fx_graph_as_func(gm.graph, model_name)
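
Note (illustration only, not part of the patch): a minimal sketch of what scalarize_tensor_ops_on_scalars does, assuming a hand-built toy FX graph; everything below besides that function is standard torch.fx API, and the toy graph is a made-up example.

# Illustrative sketch: build a tiny FX graph whose add node takes a Python
# scalar as its second argument, then apply the pass from the patch above.
import torch
import torch.fx

graph = torch.fx.Graph()
x = graph.placeholder("x")
# Recorded as call_function[target=torch.ops.aten.add.Tensor](args = (%x, 1), kwargs = {})
y = graph.call_function(torch.ops.aten.add.Tensor, (x, 1))
graph.output(y)
gm = torch.fx.GraphModule(torch.nn.Module(), graph)

# args[1] is a plain int rather than an fx Node, so the node is retargeted to
# the Scalar overload and the module is recompiled.
scalarize_tensor_ops_on_scalars(gm)

print(gm.graph)            # the call_function target is now aten.add.Scalar
print(gm(torch.zeros(3)))  # tensor([1., 1., 1.])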