diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 29108dbaa..f535184ba 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -102,34 +102,6 @@ TORCHDYNAMO_XFAIL_SET = {
     "UniformModule_basic",
     # error: failed to materialize conversion for result #0 of operation 'torch.aten.t' that remained live after conversion
     "TModuleRank1_basic",
-    # error:
-    "BatchMlpLayerModule_basic",
-    "BatchNorm1DModule_basic",
-    "BatchNorm1DWith2DInputModule_basic",
-    "BatchNorm2DModule_basic",
-    "BatchNorm3DModule_basic",
-    "Conv2dBiasNoPaddingModule_basic",
-    "Conv2dNoPaddingModule_basic",
-    "Conv2dWithPaddingDilationStrideModule_basic",
-    "Conv2dWithPaddingDilationStrideStaticModule_basic",
-    "Conv2dWithPaddingModule_basic",
-    "EmbeddingModule1DIndices_basic",
-    "EmbeddingModuleI32Static_basic",
-    "EmbeddingModuleI32_basic",
-    "EmbeddingModuleI64_basic",
-    "HBC_basic",
-    "LayerNormLastDimModule_basic",
-    "LayerNormModule_basic",
-    "LayerNormNormalizeOverAllDimsModule_basic",
-    "Mlp1LayerModule_basic",
-    "Mlp2LayerModuleNoBias_basic",
-    "Mlp2LayerModule_basic",
-    "MobilenetV3Module_basic",
-    "ResNet18Module_basic",
-    "ResNet18StaticModule_basic",
-    "SliceEndSleStartModule_basic",
-    "SliceOutOfUpperBoundIndexModule_basic",
-    "SliceStartEqEndModule_basic",
 }
 
 MHLO_PASS_SET = {
diff --git a/python/test/debug/lockstep_basic.py b/python/test/debug/lockstep_basic.py
index 62eb14a31..f653a7549 100644
--- a/python/test/debug/lockstep_basic.py
+++ b/python/test/debug/lockstep_basic.py
@@ -30,7 +30,9 @@ def miscompile_div_as_mul_backend(gm: torch.fx.GraphModule,
 
 # TODO: As we get smarter about making this output more readable, we should
 # have more focused tests rather than this "check the exact output" test.
-# CHECK: User result tensor([ 4., 10., 18.]) is not close to golden result tensor([0.2500, 0.4000, 0.5000]) for node div at None
+# CHECK: User result tensor([ 4., 10., 18.]) is not close to golden result tensor([0.2500, 0.4000, 0.5000]) for node div at Module stack: {}
+# CHECK-NEXT: File "{{.*}}python/test/debug/lockstep_basic.py", line {{.*}}, in f
+# CHECK-NEXT: c = x / y
 @dynamo.optimize(miscompile_div_as_mul_backend)
 def f(x, y):
     a = x * y
diff --git a/python/torch_mlir/dynamo.py b/python/torch_mlir/dynamo.py
index f41cd8e97..5b580ca57 100644
--- a/python/torch_mlir/dynamo.py
+++ b/python/torch_mlir/dynamo.py
@@ -6,9 +6,10 @@
-from typing import List
+from typing import List, Tuple
 
 import torch
-from torch.fx.experimental.proxy_tensor import make_fx
 from torch._functorch.compile_utils import strip_overloads
 from torch._decomp import get_decompositions
+from torch._dynamo.optimizations.training import aot_autograd
+import functorch
 import warnings
 
 # https://github.com/pytorch/pytorch/issues/89064
@@ -38,31 +39,62 @@ def _get_decomposition_table():
         aten._adaptive_avg_pool2d,
         aten.std.correction,
         aten.dot,
+        # TODO: Backends probably want to support this directly without
+        # decomposition.
+        # Our current situation with batch norm is a bit of a mess:
+        # aten.batch_norm has direct backend lowerings, while
+        # aten.native_batch_norm gets decomposed into elementwise/reductions
+        # by DecomposeComplexOps (no backend marks it as backend-legal).
+        # Neither appears to support the "training" mode (the upstream
+        # decomposition we use here does), even though we have support for
+        # aten.native_batch_norm_backward.
+        aten._native_batch_norm_legit_functional,
     ])
 
 
-def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
-    """Canonicalize single-element tuple returns to just the element.
+def _adjust_calling_convention(gm: torch.fx.GraphModule) -> Tuple[bool, bool]:
+    """Canonicalize the calling convention to the one that Torch-MLIR supports.
+
+    The MLIR codebase currently supports importing functions that have either
+    a None return value, a single return value, or a non-singleton tuple of
+    return values. But various situations create functions with single-element
+    tuples, or lists instead of tuples. This function adjusts the calling
+    convention to match, and returns the information needed for the calling
+    code to reconstruct the original calling convention.
 
     Returns:
-        True if unwrapping took place, and false otherwise.
+        Two booleans:
+        - The first indicates if a single-element tuple/list return
+          was converted to a return of the element itself.
+        - The second indicates if a list return was converted to a tuple.
     """
-    did_unwrap = False
-    for node in fx_g.graph.nodes:
+    did_unwrap_single_element = False
+    did_convert_list_to_tuple = False
+    for node in gm.graph.nodes:
         if node.op == "output":
-            assert len(
-                node.args) == 1, "Output node must have a single argument"
+            assert len(node.args) == 1, \
+                "Output node must have a single argument"
             node_arg = node.args[0]
             if isinstance(node_arg, tuple):
                 if len(node_arg) == 1:
                     node.args = (node_arg[0],)
-                    did_unwrap = True
+                    did_unwrap_single_element = True
+                    break
+            if isinstance(node_arg, list):
+                if len(node_arg) == 1:
+                    node.args = (node_arg[0],)
+                    did_unwrap_single_element = True
+                    did_convert_list_to_tuple = True
+                    break
+                else:
+                    node.args = (tuple(node_arg),)
+                    did_convert_list_to_tuple = True
                     break
-    if did_unwrap:
-        fx_g.graph.lint()
-        fx_g.recompile()
-    return did_unwrap
+    if did_unwrap_single_element or did_convert_list_to_tuple:
+        gm.graph.lint()
+        gm.recompile()
+    return did_unwrap_single_element, did_convert_list_to_tuple
 
 
 def make_simple_dynamo_backend(user_backend):
@@ -78,16 +110,24 @@ def make_simple_dynamo_backend(user_backend):
     Returns:
         A function with the signature used by TorchDynamo backends.
     """
-    def wrapper_backend(fx_graph: torch.fx.GraphModule,
+    def wrapper_backend(gm: torch.fx.GraphModule,
                         example_inputs: List[torch.Tensor]):
-        did_unwrap = _unwrap_single_tuple_return(fx_graph)
-        dispatcher_ops = make_fx(
-            fx_graph, decomposition_table=_get_decomposition_table())(*example_inputs)
-        strip_overloads(dispatcher_ops)
-        user_callable = user_backend(dispatcher_ops, example_inputs)
+        did_unwrap_single_element, did_convert_list_to_tuple = \
+            _adjust_calling_convention(gm)
+        strip_overloads(gm)
+        user_callable = user_backend(gm, example_inputs)
 
+        # TODO: Have a consistent story about the boxed calling convention.
+        # (For more details, remove this decorator and look at the warning.)
+        # See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
+        @functorch.compile.make_boxed_func
         def dynamo_callable(*inputs):
             result = user_callable(*inputs)
-            return (result,) if did_unwrap else result
+            if did_unwrap_single_element:
+                result = (result,)
+            if did_convert_list_to_tuple:
+                result = list(result)
+            return result
         return dynamo_callable
-    return wrapper_backend
+    return aot_autograd(fw_compiler=wrapper_backend,
+                        decompositions=_get_decomposition_table)
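
Usage sketch (not part of the patch): with make_simple_dynamo_backend now routing through aot_autograd, a user backend receives an FX graph that has already been traced, decomposed with the table above, calling-convention adjusted, and overload-stripped, and it just returns a callable. The backend name echo_backend and the toy function f below are made up for illustration; the backend simply runs the graph eagerly, whereas a real Torch-MLIR backend would lower gm instead of returning gm.forward.

    import torch
    import torch._dynamo as dynamo
    from torch_mlir.dynamo import make_simple_dynamo_backend

    @make_simple_dynamo_backend
    def echo_backend(gm: torch.fx.GraphModule, example_inputs):
        # At this point wrapper_backend has already called
        # _adjust_calling_convention and strip_overloads on gm, so the
        # graph can be executed (or lowered) directly.
        return gm.forward

    @dynamo.optimize(echo_backend)
    def f(x, y):
        return torch.tanh(x) + y

    print(f(torch.randn(3), torch.randn(3)))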