mirror of https://github.com/llvm/torch-mlir
[torchdynamo] Move to aot_autograd instead of raw make_fx

As [@ezyang suggested](https://github.com/pytorch/pytorch/issues/90276#issuecomment-1339791275), use `torch._dynamo.optimizations.training.aot_autograd` instead of raw `make_fx`. This is more future-proof and gives us the backward pass and functionalization. We don't currently get functionalization because of https://github.com/pytorch/pytorch/issues/90759. This also incidentally fixes the source location handling, which makes `lockstep_basic.py` give an accurate source location!

pull/1724/head snapshot-20221215.688
parent 64f9a0e978
commit af9e8a5e63
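A minimal sketch of the wiring change this commit describes, assuming the late-2022 PyTorch nightly APIs it targets. `my_compiler`, `backend_via_make_fx`, and `backend_via_aot_autograd` are illustrative names, not code from the repository; a real Torch-MLIR backend would compile the graph instead of printing it.

```python
import torch
import torch._dynamo as dynamo
from torch.fx.experimental.proxy_tensor import make_fx            # old path
from torch._dynamo.optimizations.training import aot_autograd     # new path

def my_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Stand-in for Torch-MLIR's real import/compile step: print the ATen-level
    # graph and fall back to running it eagerly.
    print(gm.graph)
    # Returning an un-boxed callable triggers the boxed-calling-convention
    # warning that the patch's make_boxed_func decorator addresses.
    return gm.forward

# Before this patch: the Dynamo backend re-traced the captured graph itself.
def backend_via_make_fx(gm, example_inputs):
    aten_gm = make_fx(gm)(*example_inputs)
    return my_compiler(aten_gm, example_inputs)

# After this patch: aot_autograd owns tracing (and can also produce a backward
# graph); the compiler callback only ever sees the forward ATen graph.
backend_via_aot_autograd = aot_autograd(fw_compiler=my_compiler)

@dynamo.optimize(backend_via_aot_autograd)
def f(x):
    return torch.tanh(x) + 1

f(torch.randn(3))
```

The key difference is that the backend no longer calls `make_fx` itself: `aot_autograd` performs the tracing and hands the compiler an ATen-level forward graph.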
@@ -102,34 +102,6 @@ TORCHDYNAMO_XFAIL_SET = {
     "UniformModule_basic",
     # error: failed to materialize conversion for result #0 of operation 'torch.aten.t' that remained live after conversion
     "TModuleRank1_basic",
-    # error:
-    "BatchMlpLayerModule_basic",
-    "BatchNorm1DModule_basic",
-    "BatchNorm1DWith2DInputModule_basic",
-    "BatchNorm2DModule_basic",
-    "BatchNorm3DModule_basic",
-    "Conv2dBiasNoPaddingModule_basic",
-    "Conv2dNoPaddingModule_basic",
-    "Conv2dWithPaddingDilationStrideModule_basic",
-    "Conv2dWithPaddingDilationStrideStaticModule_basic",
-    "Conv2dWithPaddingModule_basic",
-    "EmbeddingModule1DIndices_basic",
-    "EmbeddingModuleI32Static_basic",
-    "EmbeddingModuleI32_basic",
-    "EmbeddingModuleI64_basic",
-    "HBC_basic",
-    "LayerNormLastDimModule_basic",
-    "LayerNormModule_basic",
-    "LayerNormNormalizeOverAllDimsModule_basic",
-    "Mlp1LayerModule_basic",
-    "Mlp2LayerModuleNoBias_basic",
-    "Mlp2LayerModule_basic",
-    "MobilenetV3Module_basic",
-    "ResNet18Module_basic",
-    "ResNet18StaticModule_basic",
-    "SliceEndSleStartModule_basic",
-    "SliceOutOfUpperBoundIndexModule_basic",
-    "SliceStartEqEndModule_basic",
 }

 MHLO_PASS_SET = {
@@ -30,7 +30,9 @@ def miscompile_div_as_mul_backend(gm: torch.fx.GraphModule,

 # TODO: As we get smarter about making this output more readable, we should
 # have more focused tests rather than this "check the exact output" test.
-# CHECK: User result tensor([ 4., 10., 18.]) is not close to golden result tensor([0.2500, 0.4000, 0.5000]) for node div at None
+# CHECK: User result tensor([ 4., 10., 18.]) is not close to golden result tensor([0.2500, 0.4000, 0.5000]) for node div at Module stack: {}
+# CHECK-NEXT: File "{{.*}}python/test/debug/lockstep_basic.py", line {{.*}}, in f
+# CHECK-NEXT: c = x / y
 @dynamo.optimize(miscompile_div_as_mul_backend)
 def f(x, y):
     a = x * y
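For readers unfamiliar with the lockstep mechanism this FileCheck test exercises, here is a rough, hypothetical sketch of the idea (not Torch-MLIR's actual implementation): interpret the golden graph node by node, recompute selected nodes the way a buggy backend would, and report mismatches along with the node's recorded stack trace. Accurate per-node stack traces are exactly what the new `File ".../lockstep_basic.py"` CHECK lines rely on.

```python
import torch
import torch.fx
from torch.fx.experimental.proxy_tensor import make_fx

class LockstepDivAsMul(torch.fx.Interpreter):
    """Run the golden graph, but recompute aten.div the way the deliberately
    broken miscompile_div_as_mul_backend would, and report mismatches."""

    def run_node(self, node):
        golden = super().run_node(node)
        if node.op == "call_function" and node.target is torch.ops.aten.div.Tensor:
            args, kwargs = self.fetch_args_kwargs_from_env(node)
            user = torch.ops.aten.mul.Tensor(*args, **kwargs)
            if not torch.allclose(user, golden):
                print(f"User result {user} is not close to golden result "
                      f"{golden} for node {node.name}")
                # With Dynamo/aot_autograd in the loop, node.stack_trace points
                # back at the user's source line, e.g. "c = x / y".
                print(node.stack_trace or "<no stack trace recorded>")
        return golden

gm = make_fx(lambda x, y: x / y)(torch.ones(3), torch.ones(3))
LockstepDivAsMul(gm).run(torch.tensor([1., 2., 3.]), torch.tensor([4., 5., 6.]))
```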
@@ -6,9 +6,10 @@
 from typing import List

 import torch
-from torch.fx.experimental.proxy_tensor import make_fx
 from torch._functorch.compile_utils import strip_overloads
 from torch._decomp import get_decompositions
+from torch._dynamo.optimizations.training import aot_autograd
+import functorch

 import warnings
 # https://github.com/pytorch/pytorch/issues/89064
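As background on how a decomposition table such as `_get_decomposition_table()` is consumed: the old path handed it to `make_fx` via `decomposition_table=`, while the new path passes it to `aot_autograd` via `decompositions=`. A small hedged sketch (the traced function is invented for the example) using one entry from the patch's table:

```python
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

aten = torch.ops.aten

# aten.dot is one of the entries in the patch's table; request just it here.
decomp_table = get_decompositions([aten.dot])

def f(x, y):
    return torch.dot(x, y)

gm = make_fx(f, decomposition_table=decomp_table)(torch.randn(3), torch.randn(3))
# The printed graph contains the primitive ATen ops that aten.dot decomposes
# into, rather than a single aten.dot node.
print(gm.graph)
```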
@@ -38,31 +39,62 @@ def _get_decomposition_table():
         aten._adaptive_avg_pool2d,
         aten.std.correction,
         aten.dot,
+        # TODO: Backends probably want to support this directly without
+        # decomposition.
+        # Our current situation with batch norm is a bit of a mess.
+        # aten.batch_norm has direct backend lowerings,
+        # aten.native_batch_norm gets decomposed into elementwise/reductions
+        # by DecomposeComplexOps (no backend marks it as backend-legal).
+        # Neither appears to support the "training" mode
+        # (the upstream decomposition we use here does), even though we have
+        # support for aten.native_batch_norm_backward.
+        aten._native_batch_norm_legit_functional,
     ])


-def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
-    """Canonicalize single-element tuple returns to just the element.
+def _adjust_calling_convention(gm: torch.fx.GraphModule) -> bool:
+    """Canonicalize the calling convention to the one that Torch-MLIR supports.

+    The MLIR codebase currently supports importing functions that have either
+    a None return value, a single return value or a non-singleton tuple of
+    return values. But various situations create functions with single-element
+    tuples, or lists instead of tuples. This function adjusts the calling
+    conventions to match, and returns the information needed for the calling
+    code to reconstruct the original calling convention.
+
     Returns:
-        True if unwrapping took place, and false otherwise.
+        Two booleans
+        - The first indicates if a single-element tuple/list return
+          was converted to a return of the element itself.
+        - The second indicates if a list return was converted to a tuple.
     """
-    did_unwrap = False
-    for node in fx_g.graph.nodes:
+    did_unwrap_single_element = False
+    did_convert_list_to_tuple = False
+    for node in gm.graph.nodes:
         if node.op == "output":
-            assert len(
-                node.args) == 1, "Output node must have a single argument"
+            assert len(node.args) == 1, \
+                "Output node must have a single argument"
             node_arg = node.args[0]
             if isinstance(node_arg, tuple):
                 if len(node_arg) == 1:
                     node.args = (node_arg[0],)
-                    did_unwrap = True
+                    did_unwrap_single_element = True
+                    break
+            if isinstance(node_arg, list):
+                if len(node_arg) == 1:
+                    node.args = (node_arg[0],)
+                    did_unwrap_single_element = True
+                    did_convert_list_to_tuple = True
+                    break
+                else:
+                    node.args= (tuple(node_arg),)
+                    did_convert_list_to_tuple = True
             break

-    if did_unwrap:
-        fx_g.graph.lint()
-        fx_g.recompile()
-    return did_unwrap
+    if did_unwrap_single_element:
+        gm.graph.lint()
+        gm.recompile()
+    return did_unwrap_single_element, did_convert_list_to_tuple


 def make_simple_dynamo_backend(user_backend):
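A toy, hedged illustration of the single-element-list case that `_adjust_calling_convention` above handles (the module is invented for the example): the FX output node's argument is rewritten to the bare element, and the caller-side wrapper later re-wraps the result to restore the original convention.

```python
import torch
import torch.fx

class ReturnsList(torch.nn.Module):
    def forward(self, x):
        return [x + 1]  # a single-element list return

gm = torch.fx.symbolic_trace(ReturnsList())
output_node = next(n for n in gm.graph.nodes if n.op == "output")
print(output_node.args)  # ([add],): the output argument is a one-element list

# The _adjust_calling_convention-style rewrite: return the element directly.
output_node.args = (output_node.args[0][0],)
gm.graph.lint()
gm.recompile()
print(gm(torch.zeros(3)))  # tensor([1., 1., 1.]) instead of [tensor([1., 1., 1.])]
```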
@@ -78,16 +110,24 @@ def make_simple_dynamo_backend(user_backend):
     Returns:
         A function with the signature used by TorchDynamo backends.
     """
-    def wrapper_backend(fx_graph: torch.fx.GraphModule,
+    def wrapper_backend(gm: torch.fx.GraphModule,
                         example_inputs: List[torch.Tensor]):
-        did_unwrap = _unwrap_single_tuple_return(fx_graph)
-        dispatcher_ops = make_fx(
-            fx_graph, decomposition_table=_get_decomposition_table())(*example_inputs)
-        strip_overloads(dispatcher_ops)
-        user_callable = user_backend(dispatcher_ops, example_inputs)
+        did_unwrap_single_element, did_convert_list_to_tuple = \
+            _adjust_calling_convention(gm)
+        strip_overloads(gm)
+        user_callable = user_backend(gm, example_inputs)

+        # TODO: Have a consistent story about the boxed calling convention.
+        # (for more details on this remove this decorator and look at the warning)
+        # See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
+        @functorch.compile.make_boxed_func
         def dynamo_callable(*inputs):
             result = user_callable(*inputs)
-            return (result,) if did_unwrap else result
+            if did_unwrap_single_element:
+                result = (result,)
+            if did_convert_list_to_tuple:
+                result = list(result)
+            return result
         return dynamo_callable
-    return wrapper_backend
+    return aot_autograd(fw_compiler=wrapper_backend,
+                        decompositions=_get_decomposition_table)
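A hedged usage sketch of the resulting flow. The import path and the printing backend are assumptions for illustration (the page does not show the file name, and a real backend would compile to Torch-MLIR rather than print); the point is that a user backend still receives an ATen-level `GraphModule` plus example inputs, now produced by `aot_autograd` with the decomposition table applied and overloads stripped.

```python
import torch
import torch._dynamo as dynamo
import torch.fx
from torch_mlir.dynamo import make_simple_dynamo_backend  # assumed module path

@make_simple_dynamo_backend
def print_backend(gm: torch.fx.GraphModule, example_inputs):
    # Per the wrapper above, `gm` is the forward ATen graph produced by
    # aot_autograd, with overloads stripped.
    print(gm.graph)
    return gm.forward  # a real backend would hand this graph to Torch-MLIR

@dynamo.optimize(print_backend)
def f(x):
    return torch.relu(x) + 1

f(torch.randn(4))
```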