mirror of https://github.com/llvm/torch-mlir
[torchdynamo] Move to aot_autograd instead of raw make_fx
As [@ezyang suggested](https://github.com/pytorch/pytorch/issues/90276#issuecomment-1339791275), use `torch._dynamo.optimizations.training.aot_autograd` instead of raw `make_fx`. This is more future proof and gives us the backward pass and functionalization. We don't currently get functionalization because of https://github.com/pytorch/pytorch/issues/90759 This also incidentally fixes the source location handling, which makes `lockstep_basic.py` give an accurate source location!pull/1724/head snapshot-20221215.688
parent
64f9a0e978
commit
af9e8a5e63
|
@ -102,34 +102,6 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"UniformModule_basic",
|
||||
# error: failed to materialize conversion for result #0 of operation 'torch.aten.t' that remained live after conversion
|
||||
"TModuleRank1_basic",
|
||||
# error:
|
||||
"BatchMlpLayerModule_basic",
|
||||
"BatchNorm1DModule_basic",
|
||||
"BatchNorm1DWith2DInputModule_basic",
|
||||
"BatchNorm2DModule_basic",
|
||||
"BatchNorm3DModule_basic",
|
||||
"Conv2dBiasNoPaddingModule_basic",
|
||||
"Conv2dNoPaddingModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
||||
"Conv2dWithPaddingModule_basic",
|
||||
"EmbeddingModule1DIndices_basic",
|
||||
"EmbeddingModuleI32Static_basic",
|
||||
"EmbeddingModuleI32_basic",
|
||||
"EmbeddingModuleI64_basic",
|
||||
"HBC_basic",
|
||||
"LayerNormLastDimModule_basic",
|
||||
"LayerNormModule_basic",
|
||||
"LayerNormNormalizeOverAllDimsModule_basic",
|
||||
"Mlp1LayerModule_basic",
|
||||
"Mlp2LayerModuleNoBias_basic",
|
||||
"Mlp2LayerModule_basic",
|
||||
"MobilenetV3Module_basic",
|
||||
"ResNet18Module_basic",
|
||||
"ResNet18StaticModule_basic",
|
||||
"SliceEndSleStartModule_basic",
|
||||
"SliceOutOfUpperBoundIndexModule_basic",
|
||||
"SliceStartEqEndModule_basic",
|
||||
}
|
||||
|
||||
MHLO_PASS_SET = {
|
||||
|
|
|
@ -30,7 +30,9 @@ def miscompile_div_as_mul_backend(gm: torch.fx.GraphModule,
|
|||
|
||||
# TODO: As we get smarter about making this output more readable, we should
|
||||
# have more focused tests rather than this "check the exact output" test.
|
||||
# CHECK: User result tensor([ 4., 10., 18.]) is not close to golden result tensor([0.2500, 0.4000, 0.5000]) for node div at None
|
||||
# CHECK: User result tensor([ 4., 10., 18.]) is not close to golden result tensor([0.2500, 0.4000, 0.5000]) for node div at Module stack: {}
|
||||
# CHECK-NEXT: File "{{.*}}python/test/debug/lockstep_basic.py", line {{.*}}, in f
|
||||
# CHECK-NEXT: c = x / y
|
||||
@dynamo.optimize(miscompile_div_as_mul_backend)
|
||||
def f(x, y):
|
||||
a = x * y
|
||||
|
|
|
@ -6,9 +6,10 @@
|
|||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._functorch.compile_utils import strip_overloads
|
||||
from torch._decomp import get_decompositions
|
||||
from torch._dynamo.optimizations.training import aot_autograd
|
||||
import functorch
|
||||
|
||||
import warnings
|
||||
# https://github.com/pytorch/pytorch/issues/89064
|
||||
|
@ -38,31 +39,62 @@ def _get_decomposition_table():
|
|||
aten._adaptive_avg_pool2d,
|
||||
aten.std.correction,
|
||||
aten.dot,
|
||||
# TODO: Backends probably want to support this directly without
|
||||
# decomposition.
|
||||
# Our current situation with batch norm is a bit of a mess.
|
||||
# aten.batch_norm has direct backend lowerings,
|
||||
# aten.native_batch_norm gets decomposed into elementwise/reductions
|
||||
# by DecomposeComplexOps (no backend marks it as backend-legal).
|
||||
# Neither appears to support the "training" mode
|
||||
# (the upstream decomposition we use here does), even though we have
|
||||
# support for aten.native_batch_norm_backward.
|
||||
aten._native_batch_norm_legit_functional,
|
||||
])
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""Canonicalize single-element tuple returns to just the element.
|
||||
def _adjust_calling_convention(gm: torch.fx.GraphModule) -> bool:
|
||||
"""Canonicalize the calling convention to the one that Torch-MLIR supports.
|
||||
|
||||
The MLIR codebase currently supports importing functions that have either
|
||||
a None return value, a single return value or a non-singleton tuple of
|
||||
return values. But various situations create functions with single-element
|
||||
tuples, or lists instead of tuples. This function adjusts the calling
|
||||
conventions to match, and returns the information needed for the calling
|
||||
code to reconstruct the original calling convention.
|
||||
|
||||
Returns:
|
||||
True if unwrapping took place, and false otherwise.
|
||||
Two booleans
|
||||
- The first indicates if a single-element tuple/list return
|
||||
was converted to a return of the element itself.
|
||||
- The second indicates if a list return was converted to a tuple.
|
||||
"""
|
||||
did_unwrap = False
|
||||
for node in fx_g.graph.nodes:
|
||||
did_unwrap_single_element = False
|
||||
did_convert_list_to_tuple = False
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert len(
|
||||
node.args) == 1, "Output node must have a single argument"
|
||||
assert len(node.args) == 1, \
|
||||
"Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
did_unwrap = True
|
||||
did_unwrap_single_element = True
|
||||
break
|
||||
if isinstance(node_arg, list):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
did_unwrap_single_element = True
|
||||
did_convert_list_to_tuple = True
|
||||
break
|
||||
else:
|
||||
node.args= (tuple(node_arg),)
|
||||
did_convert_list_to_tuple = True
|
||||
break
|
||||
|
||||
if did_unwrap:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return did_unwrap
|
||||
if did_unwrap_single_element:
|
||||
gm.graph.lint()
|
||||
gm.recompile()
|
||||
return did_unwrap_single_element, did_convert_list_to_tuple
|
||||
|
||||
|
||||
def make_simple_dynamo_backend(user_backend):
|
||||
|
@ -78,16 +110,24 @@ def make_simple_dynamo_backend(user_backend):
|
|||
Returns:
|
||||
A function with the signature used by TorchDynamo backends.
|
||||
"""
|
||||
def wrapper_backend(fx_graph: torch.fx.GraphModule,
|
||||
def wrapper_backend(gm: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor]):
|
||||
did_unwrap = _unwrap_single_tuple_return(fx_graph)
|
||||
dispatcher_ops = make_fx(
|
||||
fx_graph, decomposition_table=_get_decomposition_table())(*example_inputs)
|
||||
strip_overloads(dispatcher_ops)
|
||||
user_callable = user_backend(dispatcher_ops, example_inputs)
|
||||
did_unwrap_single_element, did_convert_list_to_tuple = \
|
||||
_adjust_calling_convention(gm)
|
||||
strip_overloads(gm)
|
||||
user_callable = user_backend(gm, example_inputs)
|
||||
|
||||
# TODO: Have a consistent story about the boxed calling convention.
|
||||
# (for more details on this remove this decorator and look at the warning)
|
||||
# See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
|
||||
@functorch.compile.make_boxed_func
|
||||
def dynamo_callable(*inputs):
|
||||
result = user_callable(*inputs)
|
||||
return (result,) if did_unwrap else result
|
||||
if did_unwrap_single_element:
|
||||
result = (result,)
|
||||
if did_convert_list_to_tuple:
|
||||
result = list(result)
|
||||
return result
|
||||
return dynamo_callable
|
||||
return wrapper_backend
|
||||
return aot_autograd(fw_compiler=wrapper_backend,
|
||||
decompositions=_get_decomposition_table)
|
||||
|
|
Loading…
Reference in New Issue