[torchdynamo] Move to aot_autograd instead of raw make_fx

As [@ezyang suggested](https://github.com/pytorch/pytorch/issues/90276#issuecomment-1339791275),
use `torch._dynamo.optimizations.training.aot_autograd` instead of raw
`make_fx`. This is more future proof and gives us the backward pass and
functionalization. We don't currently get functionalization because of
https://github.com/pytorch/pytorch/issues/90759

This also incidentally fixes the source location handling, which makes
`lockstep_basic.py` give an accurate source location!
pull/1724/head snapshot-20221215.688
Sean Silva 2022-12-12 12:51:37 +00:00
parent 64f9a0e978
commit af9e8a5e63
3 changed files with 64 additions and 50 deletions

View File

@ -102,34 +102,6 @@ TORCHDYNAMO_XFAIL_SET = {
"UniformModule_basic",
# error: failed to materialize conversion for result #0 of operation 'torch.aten.t' that remained live after conversion
"TModuleRank1_basic",
# error:
"BatchMlpLayerModule_basic",
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
"Conv2dBiasNoPaddingModule_basic",
"Conv2dNoPaddingModule_basic",
"Conv2dWithPaddingDilationStrideModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingModule_basic",
"EmbeddingModule1DIndices_basic",
"EmbeddingModuleI32Static_basic",
"EmbeddingModuleI32_basic",
"EmbeddingModuleI64_basic",
"HBC_basic",
"LayerNormLastDimModule_basic",
"LayerNormModule_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"Mlp1LayerModule_basic",
"Mlp2LayerModuleNoBias_basic",
"Mlp2LayerModule_basic",
"MobilenetV3Module_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
"SliceEndSleStartModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
"SliceStartEqEndModule_basic",
}
MHLO_PASS_SET = {

View File

@ -30,7 +30,9 @@ def miscompile_div_as_mul_backend(gm: torch.fx.GraphModule,
# TODO: As we get smarter about making this output more readable, we should
# have more focused tests rather than this "check the exact output" test.
# CHECK: User result tensor([ 4., 10., 18.]) is not close to golden result tensor([0.2500, 0.4000, 0.5000]) for node div at None
# CHECK: User result tensor([ 4., 10., 18.]) is not close to golden result tensor([0.2500, 0.4000, 0.5000]) for node div at Module stack: {}
# CHECK-NEXT: File "{{.*}}python/test/debug/lockstep_basic.py", line {{.*}}, in f
# CHECK-NEXT: c = x / y
@dynamo.optimize(miscompile_div_as_mul_backend)
def f(x, y):
a = x * y

View File

@ -6,9 +6,10 @@
from typing import List
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import strip_overloads
from torch._decomp import get_decompositions
from torch._dynamo.optimizations.training import aot_autograd
import functorch
import warnings
# https://github.com/pytorch/pytorch/issues/89064
@ -38,31 +39,62 @@ def _get_decomposition_table():
aten._adaptive_avg_pool2d,
aten.std.correction,
aten.dot,
# TODO: Backends probably want to support this directly without
# decomposition.
# Our current situation with batch norm is a bit of a mess.
# aten.batch_norm has direct backend lowerings,
# aten.native_batch_norm gets decomposed into elementwise/reductions
# by DecomposeComplexOps (no backend marks it as backend-legal).
# Neither appears to support the "training" mode
# (the upstream decomposition we use here does), even though we have
# support for aten.native_batch_norm_backward.
aten._native_batch_norm_legit_functional,
])
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""Canonicalize single-element tuple returns to just the element.
def _adjust_calling_convention(gm: torch.fx.GraphModule) -> bool:
"""Canonicalize the calling convention to the one that Torch-MLIR supports.
The MLIR codebase currently supports importing functions that have either
a None return value, a single return value or a non-singleton tuple of
return values. But various situations create functions with single-element
tuples, or lists instead of tuples. This function adjusts the calling
conventions to match, and returns the information needed for the calling
code to reconstruct the original calling convention.
Returns:
True if unwrapping took place, and false otherwise.
Two booleans
- The first indicates if a single-element tuple/list return
was converted to a return of the element itself.
- The second indicates if a list return was converted to a tuple.
"""
did_unwrap = False
for node in fx_g.graph.nodes:
did_unwrap_single_element = False
did_convert_list_to_tuple = False
for node in gm.graph.nodes:
if node.op == "output":
assert len(
node.args) == 1, "Output node must have a single argument"
assert len(node.args) == 1, \
"Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
did_unwrap = True
did_unwrap_single_element = True
break
if isinstance(node_arg, list):
if len(node_arg) == 1:
node.args = (node_arg[0],)
did_unwrap_single_element = True
did_convert_list_to_tuple = True
break
else:
node.args= (tuple(node_arg),)
did_convert_list_to_tuple = True
break
if did_unwrap:
fx_g.graph.lint()
fx_g.recompile()
return did_unwrap
if did_unwrap_single_element:
gm.graph.lint()
gm.recompile()
return did_unwrap_single_element, did_convert_list_to_tuple
def make_simple_dynamo_backend(user_backend):
@ -78,16 +110,24 @@ def make_simple_dynamo_backend(user_backend):
Returns:
A function with the signature used by TorchDynamo backends.
"""
def wrapper_backend(fx_graph: torch.fx.GraphModule,
def wrapper_backend(gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor]):
did_unwrap = _unwrap_single_tuple_return(fx_graph)
dispatcher_ops = make_fx(
fx_graph, decomposition_table=_get_decomposition_table())(*example_inputs)
strip_overloads(dispatcher_ops)
user_callable = user_backend(dispatcher_ops, example_inputs)
did_unwrap_single_element, did_convert_list_to_tuple = \
_adjust_calling_convention(gm)
strip_overloads(gm)
user_callable = user_backend(gm, example_inputs)
# TODO: Have a consistent story about the boxed calling convention.
# (for more details on this remove this decorator and look at the warning)
# See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
@functorch.compile.make_boxed_func
def dynamo_callable(*inputs):
result = user_callable(*inputs)
return (result,) if did_unwrap else result
if did_unwrap_single_element:
result = (result,)
if did_convert_list_to_tuple:
result = list(result)
return result
return dynamo_callable
return wrapper_backend
return aot_autograd(fw_compiler=wrapper_backend,
decompositions=_get_decomposition_table)