2022-11-18 20:21:19 +08:00
|
|
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
|
|
|
|
import warnings
from typing import List, Tuple

import torch
from torch._functorch.compile_utils import strip_overloads
from torch._decomp import get_decompositions
from torch._dynamo.optimizations.training import aot_autograd
import functorch
|
|
|
|
# https://github.com/pytorch/pytorch/issues/89064
|
|
|
|
warnings.filterwarnings("ignore", module="torch.jit._check")
|
|
|
|
|
|
|
|
|
2022-11-29 22:01:42 +08:00
|
|
|
def _get_decomposition_table():
|
|
|
|
"""Get a decomposition table suitable for Torch-MLIR.
|
|
|
|
|
|
|
|
Sometimes TorchDynamo traces slightly different ops than what TorchScript
|
|
|
|
captures. Historically we have been driven by the ops captured by
|
|
|
|
TorchScript, so we try to decompose the ops captured by TorchDynamo into
|
|
|
|
other ops that we already support.
|
|
|
|
|
|
|
|
There isn't a highly principled solution here. Torch-MLIR currently supports
|
|
|
|
a somewhat random set of ops, added in a demand-driven way over time,
|
|
|
|
including direct backend support and decompositions internal to Torch-MLIR.
|
|
|
|
As described in the
|
|
|
|
[long-term roadmap](https://github.com/llvm/torch-mlir/blob/main/docs/long_term_roadmap.md),
|
|
|
|
eventually this situation is expected to be made a lot more principled
|
|
|
|
by aligning more with how Torch-MLIR would have looked if some of the new
|
|
|
|
upstream PyTorch infra had been available at the beginning -- in particular
|
|
|
|
the new decomposition infra and PrimTorch.
|
|
|
|
"""
|
|
|
|
aten = torch.ops.aten
|
|
|
|
return get_decompositions([
|
|
|
|
aten._adaptive_avg_pool2d,
|
|
|
|
aten.std.correction,
|
|
|
|
aten.dot,
|
2022-12-12 20:51:37 +08:00
|
|
|
# TODO: Backends probably want to support this directly without
|
|
|
|
# decomposition.
|
|
|
|
# Our current situation with batch norm is a bit of a mess.
|
|
|
|
# aten.batch_norm has direct backend lowerings,
|
|
|
|
# aten.native_batch_norm gets decomposed into elementwise/reductions
|
|
|
|
# by DecomposeComplexOps (no backend marks it as backend-legal).
|
|
|
|
# Neither appears to support the "training" mode
|
|
|
|
# (the upstream decomposition we use here does), even though we have
|
|
|
|
# support for aten.native_batch_norm_backward.
|
|
|
|
aten._native_batch_norm_legit_functional,
|
2022-11-29 22:01:42 +08:00
|
|
|
])
|
|
|
|
|
|
|
|
|
2022-12-12 20:51:37 +08:00
|
|
|
def _adjust_calling_convention(gm: torch.fx.GraphModule) -> bool:
|
|
|
|
"""Canonicalize the calling convention to the one that Torch-MLIR supports.
|
|
|
|
|
|
|
|
The MLIR codebase currently supports importing functions that have either
|
|
|
|
a None return value, a single return value or a non-singleton tuple of
|
|
|
|
return values. But various situations create functions with single-element
|
|
|
|
tuples, or lists instead of tuples. This function adjusts the calling
|
|
|
|
conventions to match, and returns the information needed for the calling
|
|
|
|
code to reconstruct the original calling convention.
|
2022-11-18 20:21:19 +08:00
|
|
|
|
|
|
|
Returns:
|
2022-12-12 20:51:37 +08:00
|
|
|
Two booleans
|
|
|
|
- The first indicates if a single-element tuple/list return
|
|
|
|
was converted to a return of the element itself.
|
|
|
|
- The second indicates if a list return was converted to a tuple.
|
2022-11-18 20:21:19 +08:00
|
|
|
"""
|
2022-12-12 20:51:37 +08:00
|
|
|
did_unwrap_single_element = False
|
|
|
|
did_convert_list_to_tuple = False
|
|
|
|
for node in gm.graph.nodes:
|
2022-11-18 20:21:19 +08:00
|
|
|
if node.op == "output":
|
2022-12-12 20:51:37 +08:00
|
|
|
assert len(node.args) == 1, \
|
|
|
|
"Output node must have a single argument"
|
2022-11-18 20:21:19 +08:00
|
|
|
node_arg = node.args[0]
|
|
|
|
if isinstance(node_arg, tuple):
|
|
|
|
if len(node_arg) == 1:
|
|
|
|
node.args = (node_arg[0],)
|
2022-12-12 20:51:37 +08:00
|
|
|
did_unwrap_single_element = True
|
|
|
|
break
|
|
|
|
if isinstance(node_arg, list):
|
|
|
|
if len(node_arg) == 1:
|
|
|
|
node.args = (node_arg[0],)
|
|
|
|
did_unwrap_single_element = True
|
|
|
|
did_convert_list_to_tuple = True
|
|
|
|
break
|
|
|
|
else:
|
|
|
|
node.args= (tuple(node_arg),)
|
|
|
|
did_convert_list_to_tuple = True
|
2022-11-18 20:21:19 +08:00
|
|
|
break
|
|
|
|
|
2022-12-12 20:51:37 +08:00
|
|
|
if did_unwrap_single_element:
|
|
|
|
gm.graph.lint()
|
|
|
|
gm.recompile()
|
|
|
|
return did_unwrap_single_element, did_convert_list_to_tuple
|
2022-11-18 20:21:19 +08:00
|
|
|
|
|
|
|
|
|
|
|
def make_simple_dynamo_backend(user_backend):
    """Adapt an ordinary TorchDynamo backend function for use with Torch-MLIR.

    The returned backend wraps `user_backend` as the forward compiler of
    `aot_autograd`. Before `user_backend` sees a graph, the returned backend
    normalizes the graph's return calling convention and strips op overloads,
    so the graph is ready for consumption by `torch_mlir.compile`. After
    compilation, it restores the original return container shape on every
    call.

    Args:
        user_backend: A callable with the ordinary TorchDynamo backend
            signature ``(gm, example_inputs) -> callable``.

    Returns:
        A callable with the signature used by TorchDynamo backends.
    """
    def wrapper_backend(gm: torch.fx.GraphModule,
                        example_inputs: List[torch.Tensor]):
        unwrapped_single, converted_list = _adjust_calling_convention(gm)
        strip_overloads(gm)
        compiled = user_backend(gm, example_inputs)

        def _run_and_restore(*args):
            output = compiled(*args)
            # Undo the normalization done by _adjust_calling_convention so
            # callers see the container type the original graph returned.
            if unwrapped_single:
                output = (output,)
            if converted_list:
                output = list(output)
            return output

        # TODO: Have a consistent story about the boxed calling convention.
        # (for more details on this remove this decorator and look at the warning)
        # See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
        dynamo_callable = functorch.compile.make_boxed_func(_run_and_restore)
        return dynamo_callable

    # NOTE(review): the decomposition function itself (not its result) is
    # passed; presumably aot_autograd invokes callables lazily -- confirm for
    # the pinned torch version.
    return aot_autograd(fw_compiler=wrapper_backend,
                        decompositions=_get_decomposition_table)
|