mirror of https://github.com/llvm/torch-mlir
171 lines
6.3 KiB
Python
171 lines
6.3 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
from typing import Union, List
|
|
from enum import Enum
|
|
|
|
import torch
|
|
|
|
from torch_mlir.passmanager import PassManager
|
|
from .compiler_utils import run_pipeline_with_repro_report
|
|
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
|
|
|
|
|
|
class OutputType(Enum):
    """Selects which flavor of IR `torch_mlir.compile` hands back.

    Each member names the set of MLIR dialects that the returned module is
    expressed in once the conversion process has finished.
    """

    # `torch` dialect ops, maximally converted to value semantics,
    # decomposed, and with shapes inferred.
    TORCH = 0

    # `tosa` dialect ops: conceptually the `TORCH` output lowered to TOSA.
    TOSA = 1

    # A mix of `linalg`-on-tensors, `scf`, and `arith` ops (plus `math` and
    # `tm_tensor`): conceptually the `TORCH` output lowered so that tensor
    # computations are expressed with `linalg`-on-tensors ops.
    LINALG_ON_TENSORS = 2

    # The untouched output of the JIT IR importer. Not expected to be useful
    # to end-users, but handy for development or when reporting bugs.
    RAW = 3
|
|
|
|
|
|
class TensorPlaceholder:
    """A class that represents a formal parameter of a given shape and dtype.

    This class can be constructed explicitly from a shape and dtype:
    ```python
    placeholder = TensorPlaceholder([3, 4], torch.float32)
    ```

    This class can also be constructed from a `torch.Tensor` which is already
    known to be a valid input to the function. In this case, a set of
    dynamic axes are allowed to be specified.
    ```python
    placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1])
    # Equivalent to `TensorPlaceholder([3, -1], torch.float32)`
    ```
    """

    def __init__(self, shape: List[int], dtype: torch.dtype):
        """Create a tensor with shape `shape` and dtype `dtype`.

        Args:
            shape: The shape of the tensor. A size of `-1` indicates that the
              dimension has an unknown size.
            dtype: The dtype of the tensor.
        """
        self.shape = shape
        self.dtype = dtype

    def __repr__(self):
        return f"{type(self).__name__}({self.shape!r}, {self.dtype!r})"

    @staticmethod
    def like(tensor: torch.Tensor,
             dynamic_axes: Union[List[int], None] = None):
        """Create a tensor placeholder that is like the given tensor.

        Args:
            tensor: The tensor to create a placeholder for.
            dynamic_axes: A list of dynamic axes. If specified, the compiled
              module will allow those axes to be any size at runtime.
              Negative axes count from the last dimension, so `-1` refers to
              the last axis of `tensor`.

        Returns:
            A `TensorPlaceholder` with the dtype of `tensor` and the shape of
            `tensor`, except that every dynamic axis has size `-1`.

        Raises:
            ValueError: If an entry of `dynamic_axes` is not a valid axis of
              `tensor`.
        """
        if dynamic_axes is None:
            dynamic_axes = []
        rank = len(tensor.shape)
        # Normalize to non-negative indices up front. Without this, a negative
        # or out-of-range axis would simply never match any index below and
        # the request for a dynamic dimension would be silently dropped.
        normalized_axes = set()
        for axis in dynamic_axes:
            if not -rank <= axis < rank:
                raise ValueError(
                    f"dynamic axis {axis} is out of range for a tensor of "
                    f"rank {rank}")
            normalized_axes.add(axis % rank)
        shape = [
            -1 if i in normalized_axes else dim
            for i, dim in enumerate(tensor.shape)
        ]
        return TensorPlaceholder(shape, tensor.dtype)
|
|
|
|
|
|
# Type of a single example argument accepted by `compile`: either a concrete
# `torch.Tensor` or a `TensorPlaceholder` standing in for one.
_example_arg = Union[TensorPlaceholder, torch.Tensor]
|
|
|
|
|
|
def compile(model: torch.nn.Module,
            example_args: Union[_example_arg, List[_example_arg]],
            output_type: OutputType = OutputType.TORCH,
            use_tracing=False):
    """Convert a PyTorch model to MLIR.

    Args:
        model: The PyTorch model to convert.
        example_args: A list (or tuple) of example arguments to use when
          inferring the shapes of the arguments to `forward` method of the
          model.
          A single tensor is treated as a list of a single tensor.
          A TensorPlaceholder object is also allowed in the place of any
          Tensor, except when `use_tracing` is True.
        output_type: The kind of output to produce. See `OutputType` for more
          details.
        use_tracing: If True, use `torch.jit.trace` to convert the model to
          JIT IR rather than `torch.jit.script`.

    Returns:
        An MLIR module that contains the converted model in the specified
        output type.

    Raises:
        TypeError: If an example argument is neither a `torch.Tensor` nor a
          `TensorPlaceholder`.
        ValueError: If `use_tracing` is True and an example argument is a
          `TensorPlaceholder`, or if `output_type` is not a known
          `OutputType`.
    """

    # Special case -- many models have just one input, so canonicalize a single
    # tensor to a list of a single tensor to make the API more ergonomic.
    # This has to happen before tracing: `tuple(tensor)` would iterate over
    # the first dimension of a lone tensor and trace with the wrong inputs.
    if isinstance(example_args, (list, tuple)):
        example_args = list(example_args)
    else:
        example_args = [example_args]

    # Convert all concrete inputs to TensorPlaceholder's, for consistency.
    # Explicit raises are used instead of `assert` so that the validation is
    # still performed when Python runs with `-O`.
    arg_placeholders = []
    for arg in example_args:
        if isinstance(arg, TensorPlaceholder):
            arg_placeholders.append(arg)
        elif isinstance(arg, torch.Tensor):
            arg_placeholders.append(TensorPlaceholder.like(arg))
        else:
            raise TypeError(
                "example_args must be torch.Tensor or TensorPlaceholder "
                f"objects, got {type(arg).__name__}")

    # TODO: Don't hardcode "forward". See `torch.onnx.export` and
    # `torch.jit.trace_module` for API inspiration.
    if use_tracing:
        # Tracing runs the model on the example inputs, so it needs concrete
        # tensors; a placeholder only carries a shape and a dtype.
        if any(isinstance(arg, TensorPlaceholder) for arg in example_args):
            raise ValueError(
                "TensorPlaceholder example arguments cannot be used with "
                "use_tracing=True; pass concrete tensors instead")
        scripted = torch.jit.trace(model, tuple(example_args))
    else:
        scripted = torch.jit.script(model)

    class_annotator = ClassAnnotator()
    # NOTE(review): the leading `None` presumably stands for the `self`
    # argument of `forward`, which gets no annotation -- confirm against
    # `ClassAnnotator.annotateArgs`.
    forward_annotation = [None]
    for arg in arg_placeholders:
        # Assume that all tensors have value semantics for now.
        forward_annotation.append((arg.shape, arg.dtype, True))
    class_annotator.exportNone(scripted._c._type())
    class_annotator.exportPath(scripted._c._type(), ["forward"])
    class_annotator.annotateArgs(
        scripted._c._type(), ["forward"], forward_annotation)

    mb = ModuleBuilder()
    mb.import_module(scripted._c, class_annotator)

    if output_type == OutputType.RAW:
        return mb.module

    run_pipeline_with_repro_report(mb.module,
                                   "torchscript-module-to-torch-backend-pipeline",
                                   "Lowering TorchScript IR -> Torch Backend IR")

    if output_type == OutputType.TORCH:
        return mb.module

    if output_type == OutputType.TOSA:
        run_pipeline_with_repro_report(
            mb.module,
            "torch-backend-to-tosa-backend-pipeline",
            "Lowering Torch Backend IR -> TOSA Backend IR")
        return mb.module

    if output_type == OutputType.LINALG_ON_TENSORS:
        run_pipeline_with_repro_report(
            mb.module,
            "torch-backend-to-linalg-on-tensors-backend-pipeline",
            "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
        return mb.module

    # ValueError is a subclass of Exception, so existing handlers still match.
    raise ValueError(f"Unknown OutputType: {output_type}")
|