torch-mlir/python/torch_mlir/__init__.py

171 lines
6.4 KiB
Python
Raw Normal View History

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from enum import Enum
from typing import List, Optional, Sequence, Union

import torch

from torch_mlir.passmanager import PassManager
from .compiler_utils import run_pipeline_with_repro_report
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
class OutputType(Enum):
    """Selects the kind of MLIR that `torch_mlir.compile` returns.

    In MLIR terms, each member names the mix of dialects present in the
    module produced by the conversion.

    Members:
        TORCH: `torch` dialect ops, converted as far as possible to value
            semantics, decomposed, and with shapes inferred.
        TOSA: `tosa` dialect ops -- the `TORCH` output lowered to TOSA.
        LINALG_ON_TENSORS: `linalg`-on-tensors ops together with `scf` and
            `arith` (plus `math` and `tm_tensor`) -- the `TORCH` output
            lowered so that tensor computations use `linalg`-on-tensors.
        RAW: The unprocessed output of the JIT IR importer. Mainly useful
            for development or when reporting bugs, not for end-users.
    """

    TORCH = 0
    TOSA = 1
    LINALG_ON_TENSORS = 2
    RAW = 3
class TensorPlaceholder:
    """A formal parameter described only by its shape and dtype.

    Construct one explicitly from a shape and dtype:

    ```python
    placeholder = TensorPlaceholder([3, 4], torch.float32)
    ```

    or derive one from a concrete `torch.Tensor` that is already known to be
    a valid input to the function, optionally marking some axes as dynamic:

    ```python
    placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1])
    # Equivalent to `TensorPlaceholder([3, -1], torch.float32)`
    ```
    """

    def __init__(self, shape: List[int], dtype: torch.dtype):
        """Create a placeholder with shape `shape` and dtype `dtype`.

        Args:
            shape: The shape of the tensor. A size of `-1` indicates that the
                dimension has an unknown size.
            dtype: The dtype of the tensor.
        """
        self.shape = shape
        self.dtype = dtype

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.shape!r}, {self.dtype!r})"

    @staticmethod
    def like(tensor: torch.Tensor, dynamic_axes: Optional[List[int]] = None):
        """Create a placeholder with the same shape and dtype as `tensor`.

        Args:
            tensor: The tensor to create a placeholder for.
            dynamic_axes: Axes whose size is replaced by `-1`, so the compiled
                module will allow those axes to be any size at runtime.
                Negative values count from the last axis, following the usual
                PyTorch dim convention. Defaults to no dynamic axes.

        Returns:
            A `TensorPlaceholder` with `tensor`'s dtype and `tensor`'s shape,
            except that every axis listed in `dynamic_axes` has size `-1`.

        Raises:
            ValueError: If an entry of `dynamic_axes` lies outside
                `[-rank, rank)` for the rank of `tensor`.
        """
        if dynamic_axes is None:
            dynamic_axes = []
        rank = tensor.dim()
        # Normalize to non-negative indices up front. Otherwise a negative
        # or out-of-range axis would silently fail to match any dimension.
        dynamic = set()
        for axis in dynamic_axes:
            if not -rank <= axis < rank:
                raise ValueError(
                    f"dynamic axis {axis} is out of range for a tensor of "
                    f"rank {rank}")
            dynamic.add(axis % rank)
        shape = [-1 if i in dynamic else dim
                 for i, dim in enumerate(tensor.shape)]
        return TensorPlaceholder(shape, tensor.dtype)
# A single entry of `compile`'s `example_args`: a concrete tensor or a
# placeholder describing one.
_example_arg = Union[TensorPlaceholder, torch.Tensor]

# For each output type that needs lowering beyond Torch backend IR: the pass
# pipeline to run and the description handed to
# `run_pipeline_with_repro_report`.
_BACKEND_LOWERING_PIPELINES = {
    OutputType.TOSA: (
        "torch-backend-to-tosa-backend-pipeline",
        "Lowering Torch Backend IR -> TOSA Backend IR",
    ),
    OutputType.LINALG_ON_TENSORS: (
        "torch-backend-to-linalg-on-tensors-backend-pipeline",
        "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
    ),
}


def _check_output_type(output_type):
    """Raise `ValueError` unless `compile` knows how to produce `output_type`."""
    # The isinstance check runs first so that arbitrary (possibly unhashable)
    # values never reach the dict membership test.
    supported = isinstance(output_type, OutputType) and (
        output_type in (OutputType.RAW, OutputType.TORCH)
        or output_type in _BACKEND_LOWERING_PIPELINES)
    if not supported:
        raise ValueError(f"Unknown OutputType: {output_type}")


def _to_placeholders(example_args):
    """Return one `TensorPlaceholder` per example argument, validating types."""
    placeholders = []
    for index, arg in enumerate(example_args):
        if isinstance(arg, TensorPlaceholder):
            placeholders.append(arg)
        elif isinstance(arg, torch.Tensor):
            placeholders.append(TensorPlaceholder.like(arg))
        else:
            raise TypeError(
                f"example_args[{index}] must be a torch.Tensor or a "
                f"TensorPlaceholder, got {type(arg).__name__}")
    return placeholders


def _to_torchscript(model, example_args, use_tracing):
    """Convert `model` to TorchScript via `torch.jit.trace` or `torch.jit.script`."""
    if not use_tracing:
        return torch.jit.script(model)
    # `torch.jit.trace` runs the model on the example inputs, so it needs real
    # tensors. A placeholder only describes a shape and dtype.
    if any(isinstance(arg, TensorPlaceholder) for arg in example_args):
        raise TypeError(
            "use_tracing=True requires every example argument to be a "
            "concrete torch.Tensor, not a TensorPlaceholder")
    return torch.jit.trace(model, tuple(example_args))


def _annotate_forward(scripted, placeholders):
    """Build a `ClassAnnotator` exporting only `forward`, typed by `placeholders`."""
    module_type = scripted._c._type()
    class_annotator = ClassAnnotator()
    class_annotator.exportNone(module_type)
    class_annotator.exportPath(module_type, ["forward"])
    # The leading `None` is the unannotated slot for `self`; one entry follows
    # per example argument. Assume that all tensors have value semantics for
    # now (the trailing `True`).
    forward_annotation = [None] + [
        (placeholder.shape, placeholder.dtype, True)
        for placeholder in placeholders
    ]
    class_annotator.annotateArgs(module_type, ["forward"], forward_annotation)
    return class_annotator


def compile(model: torch.nn.Module,
            example_args: Union[_example_arg, Sequence[_example_arg]],
            output_type: OutputType = OutputType.TORCH,
            use_tracing=False):
    """Convert a PyTorch model to MLIR.

    Args:
        model: The PyTorch model to convert.
        example_args: A list of example arguments to use when inferring the
            shapes of the arguments to the `forward` method of the model.
            A single tensor is treated as a list of a single tensor.
            A TensorPlaceholder object is also allowed in the place of any
            Tensor (but not when `use_tracing` is True).
        output_type: The kind of output to produce. See `OutputType` for more
            details.
        use_tracing: If True, use `torch.jit.trace` to convert the model to
            JIT IR rather than `torch.jit.script`.

    Returns:
        An MLIR module that contains the converted model in the specified
        output type.

    Raises:
        ValueError: If `output_type` is not a supported `OutputType`. This is
            checked before any conversion work is done.
        TypeError: If an example argument is neither a `torch.Tensor` nor a
            `TensorPlaceholder`, or if `use_tracing` is True and any example
            argument is a `TensorPlaceholder`.
    """
    _check_output_type(output_type)

    # Special case -- many models have just one input, so canonicalize a single
    # tensor to a list of a single tensor to make the API more ergonomic.
    if isinstance(example_args, (torch.Tensor, TensorPlaceholder)):
        example_args = (example_args,)
    # Materialize once: the arguments are consumed both by tracing and by the
    # placeholder conversion, so a one-shot iterable must not be read twice.
    example_args = tuple(example_args)
    arg_placeholders = _to_placeholders(example_args)

    # TODO: Don't hardcode "forward". See `torch.onnx.export` and
    # `torch.jit.trace_module` for API inspiration.
    scripted = _to_torchscript(model, example_args, use_tracing)

    mb = ModuleBuilder()
    mb.import_module(scripted._c, _annotate_forward(scripted, arg_placeholders))
    if output_type == OutputType.RAW:
        return mb.module

    run_pipeline_with_repro_report(mb.module,
                                   "torchscript-module-to-torch-backend-pipeline",
                                   "Lowering TorchScript IR -> Torch Backend IR")
    if output_type == OutputType.TORCH:
        return mb.module

    # Every remaining output type was validated by `_check_output_type`.
    pipeline, description = _BACKEND_LOWERING_PIPELINES[output_type]
    run_pipeline_with_repro_report(mb.module, pipeline, description)
    return mb.module