torch-mlir/python/torch_mlir/__init__.py

229 lines
8.4 KiB
Python
Raw Normal View History

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from enum import Enum
from typing import List, Optional, Sequence, Union

import torch

from torch_mlir.passmanager import PassManager
from .compiler_utils import run_pipeline_with_repro_report
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
class OutputType(Enum):
    """The kind of output that `torch_mlir.compile` can produce.

    In MLIR terminology, this describes the mix of dialects that will be
    produced by the conversion process.

    In user-facing APIs, this type can always be passed interchangeably with
    an appropriate string specifying the output type. The allowed strings are
    the names of the enum values, matched case-insensitively and with `-`
    allowed in place of `_`. The `OutputType.get` static method can be used to
    convert from a string to an `OutputType` instance.
    """

    # This output type consists of `torch` dialect ops that have been converted
    # maximally to value semantics, decomposed, and shapes have been inferred.
    TORCH = 0

    # This output type consists of `tosa` dialect ops. It can be thought of
    # as taking the `TORCH` output type and lowering it to TOSA.
    TOSA = 1

    # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
    # `arith` ops (and also `math` and `tm_tensor`). It can be thought of
    # as taking the `TORCH` output type and lowering it so that tensor
    # computations are done with `linalg`-on-tensors ops.
    LINALG_ON_TENSORS = 2

    # Raw output of the JIT IR importer. This is not expected to be useful
    # for end-users, but can be convenient for development or reporting bugs.
    RAW = 3

    # This output type consists of `mhlo` dialect ops. It can be thought of
    # as taking the `TORCH` output type and lowering it to MHLO.
    MHLO = 4

    @staticmethod
    def get(spec: Union[str, "OutputType"]) -> "OutputType":
        """Get an `OutputType` from any of the allowed ways to specify one.

        Args:
            spec: An `OutputType` instance, or the case-insensitive name of
                one of the enum values (`-` may be used in place of `_`).

        Returns:
            An `OutputType` instance.

        Raises:
            TypeError: If `spec` is neither a string nor an `OutputType`.
            ValueError: If `spec` is a string that does not name one of the
                enum values.
        """
        if isinstance(spec, OutputType):
            return spec
        # Reject non-strings explicitly; otherwise the `.upper()` call below
        # would fail with an unhelpful AttributeError.
        if not isinstance(spec, str):
            raise TypeError(
                f"For output_type= argument, expected a str or OutputType, "
                f"but got: {type(spec).__name__}")
        name = spec.upper().replace("-", "_")
        try:
            return OutputType[name]
        except KeyError:
            raise ValueError(f"For output_type= argument, expected one of: "
                             f"{', '.join(OutputType.__members__)}") from None
class TensorPlaceholder:
    """A class that represents a formal parameter of a given shape and dtype.

    This class can be constructed explicitly from a shape and dtype:
    ```python
    placeholder = TensorPlaceholder([3, 4], torch.float32)
    ```

    This class can also be constructed from a `torch.Tensor` which is already
    known to be a valid input to the function. In this case, a set of
    dynamic axes are allowed to be specified.
    ```python
    placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1])
    # Equivalent to `TensorPlaceholder([3, -1], torch.float32)`
    ```
    """

    def __init__(self, shape: List[int], dtype: torch.dtype):
        """Create a tensor with shape `shape` and dtype `dtype`.

        Args:
            shape: The shape of the tensor. A size of `-1` indicates that the
                dimension has an unknown size.
            dtype: The dtype of the tensor.
        """
        self.shape = shape
        self.dtype = dtype

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.shape!r}, {self.dtype})"

    @staticmethod
    def like(tensor: torch.Tensor, dynamic_axes: Optional[List[int]] = None):
        """Create a tensor placeholder that is like the given tensor.

        Args:
            tensor: The tensor to create a placeholder for.
            dynamic_axes: A list of dynamic axes. If specified, the compiled
                module will allow those axes to be any size at runtime.
                Negative axes count from the last dimension, so `-1` refers
                to the last axis. Axes outside the tensor's rank are ignored.

        Returns:
            A `TensorPlaceholder` with the dtype of `tensor` and its shape,
            except that every dynamic axis has size `-1`.
        """
        rank = len(tensor.shape)
        # Normalize negative axes so that e.g. `-1` marks the last dimension
        # instead of being silently ignored.
        dynamic = {
            axis + rank if axis < 0 else axis
            for axis in (dynamic_axes or ())
        }
        shape = [
            -1 if i in dynamic else dim
            for i, dim in enumerate(tensor.shape)
        ]
        return TensorPlaceholder(shape, tensor.dtype)
# Type alias for a single example argument accepted by `compile`: either a
# concrete `torch.Tensor` or a `TensorPlaceholder` standing in for one.
_example_arg = Union[TensorPlaceholder, torch.Tensor]
def compile(model: torch.nn.Module,
            example_args: Union[_example_arg, Sequence[_example_arg]],
            output_type: Union[str, "OutputType"] = OutputType.TORCH,
            use_tracing: bool = False,
            verbose: bool = False):
    """Convert a PyTorch model to MLIR.

    Args:
        model: The PyTorch model to convert.
        example_args: A list of example arguments to use when inferring the
            shapes of the arguments to `forward` method of the model.
            A single tensor is treated as a list of a single tensor.
            A TensorPlaceholder object is also allowed in the place of any
            Tensor.
        output_type: The kind of output to produce. See `OutputType` for more
            details.
        use_tracing: If True, use `torch.jit.trace` to convert the model to
            JIT IR rather than `torch.jit.script`.
        verbose: If True, print the IR to stdout after each lowering step
            that is run.

    Returns:
        An MLIR module that contains the converted model in the specified
        output type.

    Raises:
        ValueError: If `output_type` is a string that does not name an
            `OutputType`, or if `use_tracing` is True and `example_args`
            contains a `TensorPlaceholder` (tracing needs concrete tensors).
        TypeError: If `example_args` contains anything other than
            `torch.Tensor` or `TensorPlaceholder` objects.
    """
    output_type = OutputType.get(output_type)

    # Special case -- many models have just one input, so canonicalize a single
    # tensor to a list of a single tensor to make the API more ergonomic.
    if isinstance(example_args, (torch.Tensor, TensorPlaceholder)):
        example_args = (example_args,)
    example_args = tuple(example_args)

    # Convert all concrete inputs to TensorPlaceholder's, for consistency.
    # This is validated with explicit exceptions (not `assert`) so that the
    # checks still run under `python -O`.
    arg_placeholders = []
    for arg in example_args:
        if isinstance(arg, TensorPlaceholder):
            if use_tracing:
                raise ValueError(
                    "TensorPlaceholder example args cannot be used with "
                    "use_tracing=True; pass concrete torch.Tensor objects "
                    "instead")
            arg_placeholders.append(arg)
        elif isinstance(arg, torch.Tensor):
            arg_placeholders.append(TensorPlaceholder.like(arg))
        else:
            raise TypeError(
                "example_args must contain only torch.Tensor or "
                f"TensorPlaceholder objects, but got: {type(arg).__name__}")

    # TODO: Don't hardcode "forward". See `torch.onnx.export` and
    # `torch.jit.trace_module` for API inspiration.
    if use_tracing:
        scripted = torch.jit.trace(model, example_args)
    else:
        scripted = torch.jit.script(model)

    class_annotator = ClassAnnotator()
    # The annotation list starts with a `None` entry followed by one
    # (shape, dtype, has_value_semantics) tuple per example argument.
    # NOTE(review): the leading `None` is presumably the slot for `self` --
    # confirm against `ClassAnnotator.annotateArgs`.
    forward_annotation = [None]
    # Assume that all tensors have value semantics for now.
    forward_annotation.extend(
        (placeholder.shape, placeholder.dtype, True)
        for placeholder in arg_placeholders)
    class_annotator.exportNone(scripted._c._type())
    class_annotator.exportPath(scripted._c._type(), ["forward"])
    class_annotator.annotateArgs(
        scripted._c._type(), ["forward"], forward_annotation)

    mb = ModuleBuilder()
    mb.import_module(scripted._c, class_annotator)

    if output_type == OutputType.RAW:
        return mb.module

    run_pipeline_with_repro_report(
        mb.module,
        "torchscript-module-to-torch-backend-pipeline",
        "Lowering TorchScript IR -> Torch Backend IR")
    if verbose:
        _print_ir("Torch Backend IR", mb.module)

    if output_type == OutputType.TORCH:
        return mb.module

    # Each remaining output type is one more pipeline applied on top of the
    # Torch backend IR: (pipeline name, repro-report description, banner
    # printed in verbose mode).
    backend_lowerings = {
        OutputType.TOSA: (
            "torch-backend-to-tosa-backend-pipeline",
            "Lowering Torch Backend IR -> TOSA Backend IR",
            "TOSA Backend IR"),
        OutputType.LINALG_ON_TENSORS: (
            "torch-backend-to-linalg-on-tensors-backend-pipeline",
            "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
            "LINALG Backend IR"),
        OutputType.MHLO: (
            "torch-backend-to-mhlo-backend-pipeline",
            "Lowering Torch Backend IR -> MHLO Backend IR",
            "MHLO Backend IR"),
    }
    if output_type not in backend_lowerings:
        raise Exception(f"Unknown OutputType: {output_type}")

    pipeline, description, banner = backend_lowerings[output_type]
    run_pipeline_with_repro_report(mb.module, pipeline, description)
    if verbose:
        _print_ir(banner, mb.module)
    return mb.module


def _print_ir(title, module):
    """Print `module` to stdout under a separator line and `title`."""
    print("\n====================")
    print(title)
    print(module)