[torchdynamo] Move to aot_autograd instead of raw make_fx

As [@ezyang suggested](https://github.com/pytorch/pytorch/issues/90276#issuecomment-1339791275),
use `torch._dynamo.optimizations.training.aot_autograd` instead of raw
`make_fx`. This is more future proof and gives us the backward pass and
functionalization. We don't currently get functionalization because of
https://github.com/pytorch/pytorch/issues/90759

This also incidentally fixes the source location handling, which makes
`lockstep_basic.py` give an accurate source location!
pull/1724/head snapshot-20221215.688
Sean Silva 2022-12-12 12:51:37 +00:00
parent 64f9a0e978
commit af9e8a5e63
3 changed files with 64 additions and 50 deletions

View File

@ -102,34 +102,6 @@ TORCHDYNAMO_XFAIL_SET = {
"UniformModule_basic", "UniformModule_basic",
# error: failed to materialize conversion for result #0 of operation 'torch.aten.t' that remained live after conversion # error: failed to materialize conversion for result #0 of operation 'torch.aten.t' that remained live after conversion
"TModuleRank1_basic", "TModuleRank1_basic",
# error:
"BatchMlpLayerModule_basic",
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
"Conv2dBiasNoPaddingModule_basic",
"Conv2dNoPaddingModule_basic",
"Conv2dWithPaddingDilationStrideModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingModule_basic",
"EmbeddingModule1DIndices_basic",
"EmbeddingModuleI32Static_basic",
"EmbeddingModuleI32_basic",
"EmbeddingModuleI64_basic",
"HBC_basic",
"LayerNormLastDimModule_basic",
"LayerNormModule_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"Mlp1LayerModule_basic",
"Mlp2LayerModuleNoBias_basic",
"Mlp2LayerModule_basic",
"MobilenetV3Module_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
"SliceEndSleStartModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
"SliceStartEqEndModule_basic",
} }
MHLO_PASS_SET = { MHLO_PASS_SET = {

View File

@ -30,7 +30,9 @@ def miscompile_div_as_mul_backend(gm: torch.fx.GraphModule,
# TODO: As we get smarter about making this output more readable, we should # TODO: As we get smarter about making this output more readable, we should
# have more focused tests rather than this "check the exact output" test. # have more focused tests rather than this "check the exact output" test.
# CHECK: User result tensor([ 4., 10., 18.]) is not close to golden result tensor([0.2500, 0.4000, 0.5000]) for node div at None # CHECK: User result tensor([ 4., 10., 18.]) is not close to golden result tensor([0.2500, 0.4000, 0.5000]) for node div at Module stack: {}
# CHECK-NEXT: File "{{.*}}python/test/debug/lockstep_basic.py", line {{.*}}, in f
# CHECK-NEXT: c = x / y
@dynamo.optimize(miscompile_div_as_mul_backend) @dynamo.optimize(miscompile_div_as_mul_backend)
def f(x, y): def f(x, y):
a = x * y a = x * y

View File

@ -6,9 +6,10 @@
from typing import List from typing import List
import torch import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import strip_overloads from torch._functorch.compile_utils import strip_overloads
from torch._decomp import get_decompositions from torch._decomp import get_decompositions
from torch._dynamo.optimizations.training import aot_autograd
import functorch
import warnings import warnings
# https://github.com/pytorch/pytorch/issues/89064 # https://github.com/pytorch/pytorch/issues/89064
@ -38,31 +39,62 @@ def _get_decomposition_table():
aten._adaptive_avg_pool2d, aten._adaptive_avg_pool2d,
aten.std.correction, aten.std.correction,
aten.dot, aten.dot,
# TODO: Backends probably want to support this directly without
# decomposition.
# Our current situation with batch norm is a bit of a mess.
# aten.batch_norm has direct backend lowerings,
# aten.native_batch_norm gets decomposed into elementwise/reductions
# by DecomposeComplexOps (no backend marks it as backend-legal).
# Neither appears to support the "training" mode
# (the upstream decomposition we use here does), even though we have
# support for aten.native_batch_norm_backward.
aten._native_batch_norm_legit_functional,
]) ])
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool: def _adjust_calling_convention(gm: torch.fx.GraphModule) -> bool:
"""Canonicalize single-element tuple returns to just the element. """Canonicalize the calling convention to the one that Torch-MLIR supports.
The MLIR codebase currently supports importing functions that have either
a None return value, a single return value or a non-singleton tuple of
return values. But various situations create functions with single-element
tuples, or lists instead of tuples. This function adjusts the calling
conventions to match, and returns the information needed for the calling
code to reconstruct the original calling convention.
Returns: Returns:
True if unwrapping took place, and false otherwise. Two booleans
- The first indicates if a single-element tuple/list return
was converted to a return of the element itself.
- The second indicates if a list return was converted to a tuple.
""" """
did_unwrap = False did_unwrap_single_element = False
for node in fx_g.graph.nodes: did_convert_list_to_tuple = False
for node in gm.graph.nodes:
if node.op == "output": if node.op == "output":
assert len( assert len(node.args) == 1, \
node.args) == 1, "Output node must have a single argument" "Output node must have a single argument"
node_arg = node.args[0] node_arg = node.args[0]
if isinstance(node_arg, tuple): if isinstance(node_arg, tuple):
if len(node_arg) == 1: if len(node_arg) == 1:
node.args = (node_arg[0],) node.args = (node_arg[0],)
did_unwrap = True did_unwrap_single_element = True
break
if isinstance(node_arg, list):
if len(node_arg) == 1:
node.args = (node_arg[0],)
did_unwrap_single_element = True
did_convert_list_to_tuple = True
break
else:
node.args= (tuple(node_arg),)
did_convert_list_to_tuple = True
break break
if did_unwrap: if did_unwrap_single_element:
fx_g.graph.lint() gm.graph.lint()
fx_g.recompile() gm.recompile()
return did_unwrap return did_unwrap_single_element, did_convert_list_to_tuple
def make_simple_dynamo_backend(user_backend): def make_simple_dynamo_backend(user_backend):
@ -78,16 +110,24 @@ def make_simple_dynamo_backend(user_backend):
Returns: Returns:
A function with the signature used by TorchDynamo backends. A function with the signature used by TorchDynamo backends.
""" """
def wrapper_backend(fx_graph: torch.fx.GraphModule, def wrapper_backend(gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor]): example_inputs: List[torch.Tensor]):
did_unwrap = _unwrap_single_tuple_return(fx_graph) did_unwrap_single_element, did_convert_list_to_tuple = \
dispatcher_ops = make_fx( _adjust_calling_convention(gm)
fx_graph, decomposition_table=_get_decomposition_table())(*example_inputs) strip_overloads(gm)
strip_overloads(dispatcher_ops) user_callable = user_backend(gm, example_inputs)
user_callable = user_backend(dispatcher_ops, example_inputs)
# TODO: Have a consistent story about the boxed calling convention.
# (for more details on this remove this decorator and look at the warning)
# See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
@functorch.compile.make_boxed_func
def dynamo_callable(*inputs): def dynamo_callable(*inputs):
result = user_callable(*inputs) result = user_callable(*inputs)
return (result,) if did_unwrap else result if did_unwrap_single_element:
result = (result,)
if did_convert_list_to_tuple:
result = list(result)
return result
return dynamo_callable return dynamo_callable
return wrapper_backend return aot_autograd(fw_compiler=wrapper_backend,
decompositions=_get_decomposition_table)