# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from typing import List, Tuple

import torch
from torch._functorch.compile_utils import strip_overloads
from torch._decomp import get_decompositions
try:
    from torch._dynamo.optimizations.training import aot_autograd
except ImportError:
    # Fallback for PyTorch versions where the helper lives under
    # torch._dynamo.backends.common instead.
    from torch._dynamo.backends.common import aot_autograd
import functorch

import warnings
# https://github.com/pytorch/pytorch/issues/89064
warnings.filterwarnings("ignore", module="torch.jit._check")


def _get_decomposition_table():
    """Get a decomposition table suitable for Torch-MLIR.

    Sometimes TorchDynamo traces slightly different ops than what TorchScript
    captures. Historically we have been driven by the ops captured by
    TorchScript, so we try to decompose the ops captured by TorchDynamo into
    other ops that we already support.

    There isn't a highly principled solution here. Torch-MLIR currently supports
    a somewhat random set of ops, added in a demand-driven way over time,
    including direct backend support and decompositions internal to Torch-MLIR.
    As described in the
    [long-term roadmap](https://github.com/llvm/torch-mlir/blob/main/docs/long_term_roadmap.md),
    eventually this situation is expected to be made a lot more principled
    by aligning more with how Torch-MLIR would have looked if some of the new
    upstream PyTorch infra had been available at the beginning -- in particular
    the new decomposition infra and PrimTorch.

    Returns:
        The result of `torch._decomp.get_decompositions` for the ops listed
        in the body of this function.
    """
    aten = torch.ops.aten
    return get_decompositions([
        aten._adaptive_avg_pool2d,
        aten.std.correction,
        aten.dot,
        # TODO: Backends probably want to support this directly without
        # decomposition.
        # Our current situation with batch norm is a bit of a mess.
        # aten.batch_norm has direct backend lowerings,
        # aten.native_batch_norm gets decomposed into elementwise/reductions
        # by DecomposeComplexOps (no backend marks it as backend-legal).
        # Neither appears to support the "training" mode
        # (the upstream decomposition we use here does), even though we have
        # support for aten.native_batch_norm_backward.
        aten._native_batch_norm_legit_functional,
        aten.native_group_norm,
        aten.split.Tensor,
        aten.split_with_sizes,
    ])


def _adjust_calling_convention(gm: torch.fx.GraphModule) -> Tuple[bool, bool]:
    """Canonicalize the calling convention to the one that Torch-MLIR supports.

    The MLIR codebase currently supports importing functions that have either
    a None return value, a single return value or a non-singleton tuple of
    return values. But various situations create functions with single-element
    tuples, or lists instead of tuples. This function adjusts the calling
    conventions to match, and returns the information needed for the calling
    code to reconstruct the original calling convention.

    The graph of `gm` is modified in place, and `gm` is linted and recompiled
    whenever its output node was edited.

    Args:
        gm: The graph module whose output node should be canonicalized.

    Returns:
        Two booleans
        - The first indicates if a single-element tuple/list return
          was converted to a return of the element itself.
        - The second indicates if a list return was converted to a tuple.
    """
    did_unwrap_single_element = False
    did_convert_list_to_tuple = False
    for node in gm.graph.nodes:
        if node.op != "output":
            continue
        assert len(node.args) == 1, \
            "Output node must have a single argument"
        node_arg = node.args[0]
        if isinstance(node_arg, tuple):
            if len(node_arg) == 1:
                node.args = (node_arg[0],)
                did_unwrap_single_element = True
                break
        if isinstance(node_arg, list):
            # Any list return loses its "list-ness", whether it is unwrapped
            # to its sole element or turned into a tuple.
            did_convert_list_to_tuple = True
            if len(node_arg) == 1:
                node.args = (node_arg[0],)
                did_unwrap_single_element = True
            else:
                node.args = (tuple(node_arg),)
            break

    # Regenerate the module's code after *any* edit to the output node, so
    # that a list -> tuple conversion is reflected in `gm.forward` as well,
    # without relying on a later recompile by the caller.
    if did_unwrap_single_element or did_convert_list_to_tuple:
        gm.graph.lint()
        gm.recompile()
    return did_unwrap_single_element, did_convert_list_to_tuple


def make_simple_dynamo_backend(user_backend):
    """Wrapper for functions intended to be used as TorchDynamo backends.

    This function simplifies a few of the steps that are required to make
    TorchDynamo work with Torch-MLIR.

    Args:
        user_backend: A function with the signature used by ordinary
            TorchDynamo backends. But the torch.fx.GraphModule passed to it
            will be normalized for consumption by `torch_mlir.compile`.
    Returns:
        A function with the signature used by TorchDynamo backends.
    """
    def wrapper_backend(gm: torch.fx.GraphModule,
                        example_inputs: List[torch.Tensor]):
        # Normalize the return convention first and remember what was changed
        # so the original convention can be restored around the user callable.
        did_unwrap_single_element, did_convert_list_to_tuple = \
            _adjust_calling_convention(gm)
        strip_overloads(gm)
        user_callable = user_backend(gm, example_inputs)

        # TODO: Have a consistent story about the boxed calling convention.
        # (for more details on this remove this decorator and look at the warning)
        # See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
        @functorch.compile.make_boxed_func
        def dynamo_callable(*inputs):
            result = user_callable(*inputs)
            # Undo the adjustments in reverse: re-wrap the single element,
            # then turn the tuple back into a list if the original was one.
            if did_unwrap_single_element:
                result = (result,)
            if did_convert_list_to_tuple:
                result = list(result)
            return result
        return dynamo_callable
    # NOTE(review): the decomposition table is passed as an uncalled function;
    # presumably aot_autograd resolves callables lazily -- confirm against the
    # installed PyTorch version.
    return aot_autograd(fw_compiler=wrapper_backend,
                        decompositions=_get_decomposition_table)