torch-mlir/python/torch_mlir/dynamo.py

94 lines
3.5 KiB
Python
Raw Normal View History

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from typing import List
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch._src.compile_utils import strip_overloads
from torch._decomp import get_decompositions
import warnings
# Ignore warnings attributed to modules matching `torch.jit._check`.
# NOTE: this runs when this file is imported and modifies the process-wide
# warning filters, not just warnings triggered from this module.
# Background: https://github.com/pytorch/pytorch/issues/89064
warnings.filterwarnings("ignore", module="torch.jit._check")
def _get_decomposition_table():
    """Build the decomposition table used when tracing for Torch-MLIR.

    The ops TorchDynamo traces can differ slightly from the ones TorchScript
    captures. Torch-MLIR's op coverage has historically followed what
    TorchScript produces, so the ops listed here are decomposed into other
    ops that are already supported.

    This is not a highly principled selection. Torch-MLIR supports a somewhat
    random set of ops that grew on demand over time, through direct backend
    support as well as decompositions internal to Torch-MLIR. Per the
    [long-term roadmap](https://github.com/llvm/torch-mlir/blob/main/docs/long_term_roadmap.md),
    the expectation is that this eventually becomes much more principled by
    aligning with how Torch-MLIR would have looked had the newer upstream
    PyTorch infrastructure -- notably the new decomposition infra and
    PrimTorch -- been available from the start.

    Returns:
        The result of `get_decompositions` for the ops listed below.
    """
    ops_to_decompose = [
        torch.ops.aten._adaptive_avg_pool2d,
        torch.ops.aten.std.correction,
        torch.ops.aten.dot,
    ]
    return get_decompositions(ops_to_decompose)
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
    """Rewrite a one-element tuple return so the graph returns the bare element.

    Only the first node whose op is "output" is inspected. When its single
    argument is a tuple of length one, the node is rewritten in place to
    return that element directly, after which the graph is linted and the
    module is recompiled.

    Args:
        fx_g: The graph module to canonicalize; it is modified in place.

    Returns:
        True if the output was unwrapped, False if the graph was left as-is.
    """
    output_node = next(
        (candidate for candidate in fx_g.graph.nodes
         if candidate.op == "output"),
        None,
    )
    if output_node is None:
        return False

    assert len(output_node.args) == 1, "Output node must have a single argument"
    returned = output_node.args[0]
    # Anything other than a 1-tuple (bare value, longer tuple, ...) is kept.
    if not (isinstance(returned, tuple) and len(returned) == 1):
        return False

    output_node.args = (returned[0],)
    # The graph changed, so validate it and regenerate the module's code.
    fx_g.graph.lint()
    fx_g.recompile()
    return True
def make_simple_dynamo_backend(user_backend):
    """Adapt a backend function so it can be handed to TorchDynamo.

    The returned backend takes care of a few steps needed to make TorchDynamo
    work with Torch-MLIR before delegating to `user_backend`.

    Args:
        user_backend: A function with the signature used by ordinary
            TorchDynamo backends. The torch.fx.GraphModule it receives has
            been normalized for consumption by `torch_mlir.compile`.

    Returns:
        A function with the signature used by TorchDynamo backends.
    """
    def wrapper_backend(fx_graph: torch.fx.GraphModule,
                        example_inputs: List[torch.Tensor]):
        # Remember whether the single-element tuple return was unwrapped so
        # the tuple can be restored around the compiled callable's result.
        needs_rewrap = _unwrap_single_tuple_return(fx_graph)

        # Re-trace the graph on the example inputs with the Torch-MLIR
        # decomposition table applied.
        trace = make_fx(fx_graph,
                        decomposition_table=_get_decomposition_table())
        normalized_graph = trace(*example_inputs)
        # Called for its effect on the graph; the return value is not used.
        strip_overloads(normalized_graph)

        compiled = user_backend(normalized_graph, example_inputs)

        def dynamo_callable(*inputs):
            outputs = compiled(*inputs)
            if needs_rewrap:
                return (outputs,)
            return outputs

        return dynamo_callable

    return wrapper_backend