2022-04-20 08:30:09 +08:00
|
|
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
|
2023-03-31 00:20:19 +08:00
|
|
|
from typing import Optional, Sequence, Union, List, Dict, Tuple, Callable, Iterable
|
2022-04-20 08:30:09 +08:00
|
|
|
from enum import Enum
|
|
|
|
|
2022-08-19 08:01:54 +08:00
|
|
|
import sys
|
|
|
|
from io import StringIO
|
2023-03-31 00:20:19 +08:00
|
|
|
import tempfile
|
2024-04-24 11:34:02 +08:00
|
|
|
import os
|
2022-08-19 08:01:54 +08:00
|
|
|
|
2022-12-05 23:32:24 +08:00
|
|
|
from torch._functorch.compile_utils import strip_overloads
|
2022-04-20 08:30:09 +08:00
|
|
|
import torch
|
2023-05-12 13:46:33 +08:00
|
|
|
import torch.fx
|
2023-07-13 21:07:54 +08:00
|
|
|
from torch_mlir.dynamo import _get_decomposition_table
|
|
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
2022-04-20 08:30:09 +08:00
|
|
|
|
2024-04-27 18:27:37 +08:00
|
|
|
from torch_mlir.compiler_utils import (
|
|
|
|
run_pipeline_with_repro_report,
|
|
|
|
OutputType,
|
2024-04-28 05:16:31 +08:00
|
|
|
lower_mlir_module,
|
2024-08-27 23:31:28 +08:00
|
|
|
TensorPlaceholder,
|
2024-04-27 18:27:37 +08:00
|
|
|
)
|
2023-11-20 04:10:19 +08:00
|
|
|
from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
|
|
|
|
from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library
|
2022-04-20 08:30:09 +08:00
|
|
|
|
|
|
|
|
2022-11-09 19:15:54 +08:00
|
|
|
# A single example argument: either a concrete tensor or a placeholder that
# only describes a tensor's shape and dtype.
_example_arg = Union[TensorPlaceholder, torch.Tensor]
# The example arguments for one method: a single arg or a sequence of args.
_example_args_for_one_method = Union[_example_arg, Sequence[_example_arg]]
# Everything accepted where example args for a whole module are expected.
# "ExampleArgs" is a forward reference to the class defined later in this file.
_example_args = Union[_example_args_for_one_method, "ExampleArgs"]
|
|
|
|
|
|
|
|
|
|
|
|
class ExampleArgs:
    """A class representing the example arguments to an nn.Module.

    In general, an nn.Module may have multiple methods that need to be compiled.
    This requires example args for each method. This class is a lightweight
    wrapper around a dictionary that maps method names to example arguments.

    In user-facing API's, this type can always be passed interchangeably with a
    single arg or list of args, which normalizes to an ExampleArgs for just
    the `forward` method via the `ExampleArgs.get` static method.
    """

    def __init__(self):
        # Maps each method name to the canonicalized (tuple) example args for it.
        self._example_args = {}

    def add_method(self, method_name: str, example_args: _example_args_for_one_method):
        """Adds example args for a method.

        Args:
            method_name: The name of the method. Must have not already been
                added previously as a method.
            example_args: The example args for the method.
        Returns:
            self, for chaining.
        Raises:
            ValueError: If example args for `method_name` were already added.
            Exception: If any example arg is not a Tensor or TensorPlaceholder.
        """
        # An explicit check (instead of `assert`) so that the duplicate-name
        # validation is not stripped when Python runs with `-O`.
        if method_name in self._example_args:
            raise ValueError(
                f"Example args for method '{method_name}' were already added."
            )
        self._example_args[method_name] = ExampleArgs._canonicalize_args(example_args)
        return self

    @staticmethod
    def get(example_args: _example_args) -> "ExampleArgs":
        """Gets an ExampleArgs from one of the permissible ways to specify one.

        Args:
            example_args: An ExampleArgs instance or a single arg or list of args.
        Returns:
            An ExampleArgs instance. A non-ExampleArgs input is registered as the
            example args of the `forward` method.
        """
        if isinstance(example_args, ExampleArgs):
            return example_args
        return ExampleArgs().add_method("forward", example_args)

    @staticmethod
    def _canonicalize_args(example_args: _example_args_for_one_method):
        """Canonicalize the args for one method into a tuple.

        A non-Sequence input is treated as a single arg.

        Raises:
            Exception: If any arg is not a TensorPlaceholder or torch.Tensor.
        """
        if not isinstance(example_args, Sequence):
            example_args = [example_args]
        for arg in example_args:
            if not isinstance(arg, (TensorPlaceholder, torch.Tensor)):
                raise Exception(
                    "Only Tensor's, TensorPlaceholder's, or sequences of "
                    "Tensor's and TensorPlaceholder's are supported as "
                    "example args for method inputs. "
                    f"Got '{arg}'."
                )
        return tuple(example_args)

    def _get_methods(self):
        """Returns the names of all methods that have example args."""
        return self._example_args.keys()

    def _get_for_annotation(self):
        """Returns a dict mapping method name -> list of TensorPlaceholder.

        Concrete tensors are converted with `TensorPlaceholder.like`.
        """
        result = {}
        for method_name, example_args in self._example_args.items():
            placeholders = []
            for arg in example_args:
                if isinstance(arg, TensorPlaceholder):
                    placeholders.append(arg)
                else:
                    assert isinstance(arg, torch.Tensor)
                    placeholders.append(TensorPlaceholder.like(arg))
            result[method_name] = placeholders
        return result

    def _get_for_tracing(
        self,
        use_tracing: bool,
        ignore_traced_shapes: bool,
    ) -> Dict[str, Tuple[_example_arg, ...]]:
        """Returns a dict mapping method name -> tuple of args for tracing.

        When `use_tracing` is True, every TensorPlaceholder is replaced by a
        concrete `torch.ones` tensor of the placeholder's shape (dynamic `-1`
        dims become 7) and dtype. Otherwise the stored args are returned as-is.

        Raises:
            Exception: If tracing with a TensorPlaceholder while
                `ignore_traced_shapes` is False.
        """
        result = {}
        for method_name, example_args in self._example_args.items():
            # If we are tracing, then we need to convert any placeholders into
            # concrete values.
            if use_tracing:
                example_args_for_trace = []
                for arg in example_args:
                    if isinstance(arg, TensorPlaceholder):
                        if not ignore_traced_shapes:
                            # To avoid accidental footguns, we require
                            # `ignore_traced_shapes` to be true if we're using
                            # TensorPlaceholder's, as it falls into the same
                            # "hopefully the trace works for different inputs"
                            # bucket of concerns.
                            raise Exception(
                                "TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True`"
                            )
                        # For any dynamic dimensions, replace them with "7"
                        # arbitrarily. If a user is using dynamic dimensions with
                        # tracing, they are walking on thin ice already -- assume
                        # they know what they are doing and that their trace is
                        # correct for any specific concrete size.
                        shape = tuple(s if s != -1 else 7 for s in arg.shape)
                        # Passing the shape as one tuple also covers 0-d
                        # placeholders (`shape == ()`), so scalar placeholders
                        # keep their dtype instead of becoming an int64 tensor.
                        example_args_for_trace.append(
                            torch.ones(shape, dtype=arg.dtype)
                        )
                    else:
                        assert isinstance(arg, torch.Tensor)
                        example_args_for_trace.append(arg)
                example_args = tuple(example_args_for_trace)
            result[method_name] = example_args
        return result
|
|
|
|
|
|
|
|
|
2022-08-19 08:01:54 +08:00
|
|
|
# The set of ops that are considered legal for each backend.
# These are currently quite load-bearing, since different backends might be
# missing patterns for decomposed forms of certain ops.
# TODO: Tighten up the definition of these "conditionally legal for backends"
# ops in the backend contract, and move these lists somewhere deeper in the
# compiler where each backend can "own" its set of legal ops.
#
# `compile` uses the entry for the requested output type (or [] when the output
# type has no entry) whenever the caller does not pass `backend_legal_ops`
# explicitly; the list is forwarded to the TorchScript -> Torch backend pipeline.
BACKEND_LEGAL_OPS = {
    OutputType.TOSA: [
        "aten.flatten.using_ints",
        "aten.native_layer_norm",
        "aten.linear",
    ],
    OutputType.LINALG_ON_TENSORS: [
        "aten.flatten.using_ints",
        "aten.adaptive_avg_pool1d",
        "aten.adaptive_avg_pool2d",
        "aten.unflatten.int",
    ],
    OutputType.STABLEHLO: [
        "aten.amax",
        "aten.amin",
        "aten.randn.generator",
        "aten.normal_functional",
    ],
}
|
|
|
|
|
|
|
|
|
2024-04-28 05:16:31 +08:00
|
|
|
def _canon_extra_library(
|
|
|
|
extra_library, extra_library_file_name="custom_op_extra_library.mlir"
|
|
|
|
):
|
2023-05-12 13:46:33 +08:00
|
|
|
if len(extra_library) != 0:
|
|
|
|
extra_library_dict = {}
|
|
|
|
for library_func in extra_library:
|
|
|
|
extra_library_dict[library_func.__name__] = library_func
|
|
|
|
mlir_library = generate_library(extra_library_dict)
|
|
|
|
|
2024-04-28 05:16:31 +08:00
|
|
|
extra_library_file = os.path.join(
|
|
|
|
tempfile.gettempdir(), extra_library_file_name
|
|
|
|
)
|
2024-04-24 11:34:02 +08:00
|
|
|
with open(extra_library_file, "w") as f:
|
2023-05-12 13:46:33 +08:00
|
|
|
f.write(mlir_library)
|
2024-04-24 11:34:02 +08:00
|
|
|
return extra_library_file
|
|
|
|
else:
|
|
|
|
return ""
|
2023-05-12 13:46:33 +08:00
|
|
|
|
|
|
|
|
2024-04-28 05:16:31 +08:00
|
|
|
def compile(
    model: torch.nn.Module,
    example_args: _example_args,
    output_type: Union[str, "OutputType"] = OutputType.TORCH,
    use_tracing: bool = False,
    ignore_traced_shapes=False,
    backend_legal_ops: Optional[Sequence[str]] = None,
    extra_library: Iterable[Callable] = (),
    verbose: bool = False,
    use_make_fx: bool = False,
    enable_ir_printing: bool = False,
):
    """Convert a PyTorch model to MLIR.

    Args:
        model: The PyTorch model to convert.
        example_args: A list of example arguments to use when inferring the
            shapes of the arguments to `forward` method of the model.
            A single tensor is treated as a list of a single tensor.
            A TensorPlaceholder object is also allowed in the place of any
            Tensor. For models with multiple methods, an `ExampleArgs` object
            can be passed.
        output_type: The kind of output to produce. See `OutputType` for more
            details.
        use_tracing: If True, use `torch.jit.trace` to convert the model to
            JIT IR rather than `torch.jit.script`.
        ignore_traced_shapes: If True, ignore the shapes that were observed
            during tracing. This should only be used if one knows that the
            original traced program would result in the same trace (modulo
            shapes) for all shape combinations implied by any
            `TensorPlaceholder`'s used as `example_args`. Also,
            strictly-speaking, this option covers dtypes too, but we just say
            "shapes" to be succinct.
        backend_legal_ops: A list of ops that should be considered legal for
            the backend. An op that is considered legal will not be decomposed.
            This option is only valid with the `"torch"` output type.
        extra_library: List of abstract interpretation functions to splice
            into the abstract interpretation library. See
            `docs/adding_abstract_interpretation_functions.md` for more info
            on the format the functions should have.
        verbose: If true, print extra information about the conversion to
            stdout.
        use_make_fx: If true, first run `make_fx` (with torch-mlir's
            decomposition table) on the model using the example args of the
            `forward` method, with any TensorPlaceholder's replaced by concrete
            tensors, and compile the resulting FX graph instead of `model`.
        enable_ir_printing: If true, print the IR before and after each pass to
            stderr. This is equivalent to setting MLIR's `-print-ir-after-all`
            flag. Note that this can easily generate many gigabytes of text,
            so make sure to pipe stderr to a file (for example, run
            `python tinymodel.py 2> tinymodel.stderr` on Linux).

    Returns:
        An MLIR module that contains the converted model in the specified
        output type.
    """
    extra_library_file_name = _canon_extra_library(extra_library)
    output_type = OutputType.get(output_type)
    example_args = ExampleArgs.get(example_args)
    if ignore_traced_shapes and not use_tracing:
        raise Exception("`ignore_traced_shapes` requires `use_tracing`")

    # We only allow `backend_legal_ops` to be specified for the `"torch"`
    # output type because the other output types actually invoke their
    # respective backends (Linalg, TOSA, or STABLEHLO), and those backends have
    # very specific requirements about the ops which are legal.
    # See `BACKEND_LEGAL_OPS` for more details.
    if backend_legal_ops is not None:
        if output_type != OutputType.TORCH:
            raise Exception(
                "`backend_legal_ops` is only valid with the " "`torch` output type"
            )
        # Deduplicate and sort so the pipeline option string is deterministic.
        backend_legal_ops = sorted(set(backend_legal_ops))
    else:
        backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])

    if use_make_fx:
        args = example_args._get_for_tracing(
            use_tracing=True, ignore_traced_shapes=True
        )["forward"]
        model = make_fx(model, decomposition_table=_get_decomposition_table())(*args)

    # For FX-based models, automatically strip overloads.
    if isinstance(model, torch.fx.GraphModule):
        strip_overloads(model)

    # Get the model as JIT IR (TorchScript) for import.
    # TODO: Longer-term, we probably need to split `torchscript.compile`.
    # There should be an "acquisition" step that does
    # tracing/scripting/importing from FX/using torchdynamo.export/etc.
    # + any lowering to the backend contract. Then there should be a
    # "backend lowering" step that does the actual lowering to each
    # backend. This separation should be visible at the Python API level, and
    # we can implement a deliberately simplified API like `torchscript.compile`
    # on top of those building blocks.
    if isinstance(model, torch.jit.ScriptModule):
        # If the user already converted the model to JIT IR themselves, just
        # do some basic error checking, but take the model as-is.
        for method_name in example_args._get_methods():
            if not hasattr(model, method_name):
                raise Exception(
                    f"Model does not have exported method '{method_name}', "
                    f"requested in `example_args`. Consider adding "
                    f"`@torch.jit.export` to the method definition."
                )
        scripted = model
    elif use_tracing:
        scripted = torch.jit.trace_module(
            model, example_args._get_for_tracing(use_tracing, ignore_traced_shapes)
        )
    else:
        # Make sure that all the methods that the user requested get scripted.
        # By default, PyTorch only scripts the `forward` method and transitive
        # callees.
        for method_name in example_args._get_methods():
            torch.jit.export(getattr(model, method_name).__func__)
        scripted = torch.jit.script(model)

    class_annotator = ClassAnnotator()
    scripted_type = scripted._c._type()
    class_annotator.exportNone(scripted_type)
    # Use a distinct loop name so `example_args` is not rebound to a
    # per-method list of placeholders.
    for method_name, placeholders in example_args._get_for_annotation().items():
        class_annotator.exportPath(scripted_type, [method_name])
        annotation = [None]  # `None` is always the annotation for "self".
        for arg in placeholders:
            annotation.append((arg.shape, arg.dtype, True))
        class_annotator.annotateArgs(scripted_type, [method_name], annotation)

    mb = ModuleBuilder()
    import_options = ImportOptions()
    import_options.ignoreExistingTensorShapesAndDtypes = ignore_traced_shapes

    # Capture Python-level stderr output from the importer so it can be
    # reported alongside the exception if the import fails.
    original_stderr = sys.stderr
    captured_stderr = StringIO()
    sys.stderr = captured_stderr
    try:
        # Import the TorchScript module to MLIR
        mb.import_module(scripted._c, class_annotator, import_options)
    except Exception as e:
        raise Exception(
            f"""
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
### Importer C++ Exception:
{e}
### Importer Diagnostics:
{captured_stderr.getvalue()}
"""
        ) from None
    finally:
        sys.stderr = original_stderr

    if verbose:
        print("\n====================")
        print("TorchScript RAW IR")
        print(mb.module)

    if output_type == OutputType.RAW:
        return mb.module

    option_string = (
        "{backend-legal-ops="
        + ",".join(backend_legal_ops)
        + " extra-library="
        + extra_library_file_name
        + "}"
    )
    run_pipeline_with_repro_report(
        mb.module,
        f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
        "Lowering TorchScript IR -> Torch Backend IR",
        enable_ir_printing=enable_ir_printing,
    )

    return lower_mlir_module(verbose, output_type, mb.module)
|