mirror of https://github.com/llvm/torch-mlir
Implement the lazytensor package (#331)
Implement the `lazytensor` python package for converting lazy computations captured by the Lazy Tensor Core into MLIR. This PR also fixes a few things with `torchfx` and its examplepull/337/head
parent
2b99c8b990
commit
b59f2cb673
31
README.md
31
README.md
|
@ -104,13 +104,40 @@ jupyter notebook
|
||||||
|
|
||||||
### TorchFX
|
### TorchFX
|
||||||
|
|
||||||
TODO
|
The `examples` folder includes the Python package `torchfx`, which is a functional prototype of a TorchFX to MLIR pipeline. The main entry point into the `torchfx` package is the `torchfx.builder` module, which includes a function for converting the output of a TorchFX trace into MLIR. Currently, the number of PyTorch operations supported is very limited, but will be expanded in the future.
|
||||||
|
|
||||||
|
#### Example usage of `torchfx`
|
||||||
|
|
||||||
|
The `examples` folder includes scripts `torchfx_*.py` showing how to use the TorchFX to MLIR pipeline. In order to run the examples, make sure you've setup your `PYTHONPATH` by following [these](#setup-env) instructions, and add `/path/to/torch-mlir/examples` to your `PYTHONPATH`.
|
||||||
|
|
||||||
|
Then, run
|
||||||
|
|
||||||
|
```
|
||||||
|
python torchfx_example_name.py
|
||||||
|
```
|
||||||
|
|
||||||
|
replacing `torchfx_example_name.py` with the actual `torchfx` example you want to run.
|
||||||
|
|
||||||
|
|
||||||
### Lazy Tensor Core
|
### Lazy Tensor Core
|
||||||
|
|
||||||
TODO
|
The `examples` folder includes the Python package `lazytensor`, which implements a Lazy Tensor Core (LTC) to MLIR pipeline. The main entry point into the `lazytensor` package is the `lazytensor.builder`, which includes the function `build_module` that takes a computation captured and converted to TorchScript IR by LTC, and converts it to MLIR.
|
||||||
|
|
||||||
|
#### Example usage of `lazytensor`
|
||||||
|
|
||||||
|
The `examples` folder includes scripts `lazytensor_*.py` showing how to use the Lazy Tensor to MLIR pipeline. The examples depend on the Lazy Tensor Core (LTC) of PyTorch. For information on how to obtain LTC, see [here](https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/QUICKSTART.md).
|
||||||
|
|
||||||
|
In order to run the examples, make sure you've setup your `PYTHONPATH` by following [these](#setup-env) instructions, and also add the following to your `PYTHONPATH`:
|
||||||
|
1. `/path/to/torch-mlir/examples`
|
||||||
|
2. `/path/to/pytorch/lazy_tensor_core`
|
||||||
|
|
||||||
|
Then, run
|
||||||
|
|
||||||
|
```
|
||||||
|
python lazytensor_example_name.py
|
||||||
|
```
|
||||||
|
|
||||||
|
replacing `lazytensor_example_name.py` with the actual `lazytensor` example you want to run.
|
||||||
|
|
||||||
## Repository Layout
|
## Repository Layout
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
"""
|
||||||
|
Translator from torch.jit.ScriptFunction to MLIR.
|
||||||
|
|
||||||
|
The following defines a function that take a torch.jit.ScriptFunction
|
||||||
|
and converts it into an MLIR module.
|
||||||
|
|
||||||
|
The expected use for this module is to use the function
|
||||||
|
`build_module(jit_function: torch.jit.ScriptFunction
|
||||||
|
annotation: Annotation) -> ir.Module`
|
||||||
|
to convert the TorchScript function into MLIR using the `torch`
|
||||||
|
dialect.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from torch.jit import ScriptFunction
|
||||||
|
|
||||||
|
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||||
|
from torch_mlir.dialects.builtin import FuncOp
|
||||||
|
from torch_mlir import ir
|
||||||
|
|
||||||
|
from utils.annotator import AnnotationConverter as ac
|
||||||
|
from utils.annotator import Annotation
|
||||||
|
|
||||||
|
def _get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
|
||||||
|
with module.context:
|
||||||
|
name_attr = ir.StringAttr.get(name)
|
||||||
|
for op in module.body.operations:
|
||||||
|
if isinstance(op, FuncOp) and op.name == name_attr:
|
||||||
|
return op
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def build_module(jit_function: ScriptFunction,
|
||||||
|
annotation: Annotation) -> ir.Module:
|
||||||
|
"""
|
||||||
|
Translate input function into an MLIR module in the `torch` dialect.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
jit_function: ScriptFunction
|
||||||
|
Function in TorchScript IR to turn into MLIR.
|
||||||
|
annotation: Annotation
|
||||||
|
Annotation object representing the types of
|
||||||
|
the operands of `jit_function`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ir.Module
|
||||||
|
Translation of the input module into an MLIR module
|
||||||
|
"""
|
||||||
|
mb = ModuleBuilder()
|
||||||
|
mb.import_function(jit_function)
|
||||||
|
|
||||||
|
func_op = _get_func_op_with_name(mb.module, jit_function.name)
|
||||||
|
assert func_op is not None, 'Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder'
|
||||||
|
|
||||||
|
arg_attrs = ac.to_mlir_array_attr(annotation, mb.context)
|
||||||
|
func_op.attributes['arg_attrs'] = arg_attrs
|
||||||
|
|
||||||
|
return mb.module
|
|
@ -0,0 +1,81 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
"""
|
||||||
|
Example of taking a Lazy Tensor computation and compiling it using torch-mlir.
|
||||||
|
|
||||||
|
This example depends on the Lazy Tensor Core (LTC) of PyTorch. For information
|
||||||
|
on how to obtain LTC, see here:
|
||||||
|
https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/QUICKSTART.md
|
||||||
|
|
||||||
|
To run the example, make sure the following are in your PYTHONPATH:
|
||||||
|
1. /path/to/torch-mlir/examples
|
||||||
|
2. /path/to/pytorch/lazy_tensor_core
|
||||||
|
3. /path/to/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir
|
||||||
|
|
||||||
|
then, simply call `python lazytensor_tanh.py`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import lazy_tensor_core as ltc
|
||||||
|
from torch._C import CompilationUnit
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend \
|
||||||
|
import RefBackendLinalgOnTensorsBackend
|
||||||
|
from torch_mlir.passmanager import PassManager
|
||||||
|
|
||||||
|
from utils.annotator import Annotation
|
||||||
|
from utils.torch_mlir_types import TorchTensorType
|
||||||
|
from lazytensor.builder import build_module
|
||||||
|
|
||||||
|
ltc._LAZYC._ltc_init_ts_backend()
|
||||||
|
device = 'lazy'
|
||||||
|
dtype = torch.float32
|
||||||
|
shape = (2, 3)
|
||||||
|
|
||||||
|
x = torch.randn(shape, device=device, dtype=dtype)
|
||||||
|
y = torch.randn(shape, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def computation(x, y):
|
||||||
|
return y * x.tanh()
|
||||||
|
|
||||||
|
# Capture lazy computation and convert to TorchScript IR
|
||||||
|
graph_str = ltc._LAZYC._get_ltc_tensors_backend([computation(x, y)])
|
||||||
|
print("LAZY GRAPH")
|
||||||
|
print(graph_str)
|
||||||
|
graph = torch._C.parse_ir(graph_str)
|
||||||
|
|
||||||
|
# Create a torch.jit.ScriptFunction out of the graph
|
||||||
|
cu = CompilationUnit()
|
||||||
|
func_name = 'my_method'
|
||||||
|
script_function = cu.create_function(func_name, graph)
|
||||||
|
|
||||||
|
# `build_module` takes he torch.jit.ScriptFunction and the
|
||||||
|
# annotation on the operand types, and outputs an `ir.Module`
|
||||||
|
# with a single function representing the ScriptFunction in
|
||||||
|
# the torch MLIR dialect
|
||||||
|
func_annotation = Annotation([TorchTensorType(shape=shape, dtype=torch.float),
|
||||||
|
TorchTensorType(shape=shape, dtype=torch.float)])
|
||||||
|
mlir_module = build_module(script_function, func_annotation)
|
||||||
|
|
||||||
|
print("MLIR")
|
||||||
|
mlir_module.dump()
|
||||||
|
|
||||||
|
# Compile the torch MLIR and execute the compiled program
|
||||||
|
with mlir_module.context:
|
||||||
|
pm = PassManager.parse('torchscript-function-to-linalg-on-tensors-backend-pipeline')
|
||||||
|
pm.run(mlir_module)
|
||||||
|
|
||||||
|
print("BEFORE LINALG-ON-TENSORS BACKEND PIPELINE")
|
||||||
|
print(mlir_module)
|
||||||
|
|
||||||
|
backend = RefBackendLinalgOnTensorsBackend()
|
||||||
|
compiled = backend.compile(mlir_module)
|
||||||
|
jit_module = backend.load(compiled)
|
||||||
|
|
||||||
|
print("\n\nRunning Example Calculation")
|
||||||
|
print("Compiled result:")
|
||||||
|
print(jit_module.my_method(x.cpu().numpy(), y.cpu().numpy()))
|
||||||
|
print("Expected result:")
|
||||||
|
print(computation(x, y))
|
|
@ -1,20 +0,0 @@
|
||||||
# -*- Python -*-
|
|
||||||
# This file is licensed under a pytorch-style license
|
|
||||||
# See frontends/pytorch/LICENSE for license information.
|
|
||||||
#
|
|
||||||
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
|
||||||
|
|
||||||
from typing import Iterable, Union
|
|
||||||
from torch.fx import GraphModule
|
|
||||||
from .torch_mlir_types import TorchTensorType, PythonType
|
|
||||||
|
|
||||||
def annotate_forward_args(module: GraphModule,
|
|
||||||
types: Iterable[Union[TorchTensorType, type]]
|
|
||||||
) -> GraphModule:
|
|
||||||
operands = filter(lambda node: node.op == 'placeholder', module.graph.nodes)
|
|
||||||
for operand, type_ in zip(operands, types):
|
|
||||||
if isinstance(type_, type):
|
|
||||||
type_ = PythonType(type_)
|
|
||||||
operand.update_kwarg('torch_mlir_type', type_)
|
|
||||||
|
|
||||||
return module
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
"""
|
"""
|
||||||
Translator from Torch.FX to MLIR.
|
Translator from Torch.FX to MLIR.
|
||||||
|
|
||||||
|
@ -9,9 +12,6 @@ The expected use for this module is to use the function
|
||||||
`build_module(py_module: torch.fx.GraphModule) -> ir.Module`
|
`build_module(py_module: torch.fx.GraphModule) -> ir.Module`
|
||||||
to convert the output from the tracer into MLIR using the `torch`
|
to convert the output from the tracer into MLIR using the `torch`
|
||||||
dialect.
|
dialect.
|
||||||
|
|
||||||
This file is licensed under a pytorch-style license
|
|
||||||
See frontends/pytorch/LICENSE for license information.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=no-member, no-name-in-module, invalid-name, fixme
|
# pylint: disable=no-member, no-name-in-module, invalid-name, fixme
|
||||||
|
@ -27,7 +27,8 @@ from torch_mlir.dialects import builtin, std
|
||||||
import torch.fx
|
import torch.fx
|
||||||
from torch.fx.experimental.fx_acc import acc_ops
|
from torch.fx.experimental.fx_acc import acc_ops
|
||||||
|
|
||||||
from .torch_mlir_types import TorchTensorType, PythonType, TorchNnModuleType
|
from utils.torch_mlir_types import TorchTensorType, PythonType, \
|
||||||
|
TorchNnModuleType
|
||||||
|
|
||||||
Environment = MutableMapping[torch.fx.Node, ir.Value]
|
Environment = MutableMapping[torch.fx.Node, ir.Value]
|
||||||
|
|
||||||
|
@ -388,7 +389,7 @@ an argument named `input`'
|
||||||
|
|
||||||
|
|
||||||
@_add_handler(ACC_OP_HANDLERS, acc_ops.add)
|
@_add_handler(ACC_OP_HANDLERS, acc_ops.add)
|
||||||
def _add_handler(func_builder: _ForwardFunctionBuilder,
|
def _add_tensor_handler(func_builder: _ForwardFunctionBuilder,
|
||||||
args: Mapping[str, ir.Value]) -> ir.OpResult:
|
args: Mapping[str, ir.Value]) -> ir.OpResult:
|
||||||
input_arg = args.get('input')
|
input_arg = args.get('input')
|
||||||
other_arg = args.get('other')
|
other_arg = args.get('other')
|
||||||
|
@ -396,10 +397,10 @@ def _add_handler(func_builder: _ForwardFunctionBuilder,
|
||||||
'A call to this handler must include an argument named `input` \
|
'A call to this handler must include an argument named `input` \
|
||||||
and an argument named `other`'
|
and an argument named `other`'
|
||||||
tensor_type = TorchTensorType().to_mlir(func_builder.context)
|
tensor_type = TorchTensorType().to_mlir(func_builder.context)
|
||||||
int_type = PythonType(int).to_mlir(func_builder.context)
|
torch_int_type = PythonType(int).to_mlir(func_builder.context)
|
||||||
attr_type = ir.Type.parse('i64', func_builder.context)
|
int_type = ir.Type.parse("i64", context=func_builder.context)
|
||||||
int_attr = ir.IntegerAttr.get(attr_type, 1)
|
int_attr = ir.IntegerAttr.get(int_type, 1)
|
||||||
alpha_arg = torch_d.ConstantIntOp(int_type,
|
alpha_arg = torch_d.ConstantIntOp(torch_int_type,
|
||||||
int_attr,
|
int_attr,
|
||||||
loc=func_builder.loc,
|
loc=func_builder.loc,
|
||||||
ip=func_builder.func_ip).result
|
ip=func_builder.func_ip).result
|
||||||
|
|
|
@ -1,59 +0,0 @@
|
||||||
# -*- Python -*-
|
|
||||||
# This file is licensed under a pytorch-style license
|
|
||||||
# See frontends/pytorch/LICENSE for license information.
|
|
||||||
|
|
||||||
# From the torch-mlir root, run with:
|
|
||||||
# `python -m examples.torchfx.examples.example_add_tanh_sigmoid`
|
|
||||||
# (after setting up python environment with write_env_file.sh)
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.fx.experimental.fx_acc import acc_tracer
|
|
||||||
import torch_mlir
|
|
||||||
from torch_mlir.dialects.torch import register_dialect
|
|
||||||
from torch_mlir.passmanager import PassManager
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
|
||||||
|
|
||||||
from ..builder import build_module
|
|
||||||
from ..annotator import annotate_forward_args
|
|
||||||
from ..torch_mlir_types import TorchTensorType
|
|
||||||
|
|
||||||
|
|
||||||
class MyModule(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, x, y):
|
|
||||||
# TODO: Debug issue with RefBackend
|
|
||||||
#return torch.tanh(x) + torch.sigmoid(y)
|
|
||||||
return torch.tanh(x)
|
|
||||||
|
|
||||||
|
|
||||||
module = MyModule()
|
|
||||||
traced_module = acc_tracer.trace(module, [torch.Tensor(2,2),
|
|
||||||
torch.Tensor(2,2)])
|
|
||||||
|
|
||||||
print("TRACE")
|
|
||||||
arg_type = TorchTensorType(shape=[None, None], dtype=torch.float)
|
|
||||||
traced_module = annotate_forward_args(traced_module, [arg_type, arg_type])
|
|
||||||
print(traced_module.graph)
|
|
||||||
torch_mlir_module = build_module(traced_module)
|
|
||||||
|
|
||||||
print("\n\nTORCH MLIR")
|
|
||||||
torch_mlir_module.dump()
|
|
||||||
print(torch_mlir_module.operation.verify())
|
|
||||||
|
|
||||||
with torch_mlir_module.context:
|
|
||||||
pm = PassManager.parse('torchscript-to-linalg-on-tensors-backend-pipeline')
|
|
||||||
pm.run(torch_mlir_module)
|
|
||||||
|
|
||||||
print("\n\nLOWERED MLIR")
|
|
||||||
torch_mlir_module.dump()
|
|
||||||
|
|
||||||
backend = RefBackendLinalgOnTensorsBackend()
|
|
||||||
compiled = backend.compile(torch_mlir_module)
|
|
||||||
jit_module = backend.load(compiled)
|
|
||||||
|
|
||||||
print("\n\nRunning Forward Function")
|
|
||||||
t = torch.rand((2, 2), dtype=torch.float)
|
|
||||||
print("Compiled result:\n", jit_module.forward(t.numpy(), t.numpy()))
|
|
||||||
print("\nExpected result:\n", module.forward(t, t))
|
|
|
@ -1,6 +1,7 @@
|
||||||
# -*- Python -*-
|
# -*- Python -*-
|
||||||
# This file is licensed under a pytorch-style license
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
# See frontends/pytorch/LICENSE for license information.
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
#
|
#
|
||||||
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
"""
|
||||||
|
Example of taking a moduled traced by TorchFX and compiling it using torch-mlir.
|
||||||
|
|
||||||
|
To run the example, make sure the following are in your PYTHONPATH:
|
||||||
|
1. /path/to/torch-mlir/examples
|
||||||
|
2. /path/to/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir
|
||||||
|
|
||||||
|
then, simply call `python torchfx_add_tanh_sigmoid.py`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from torch.fx.experimental.fx_acc import acc_tracer
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend \
|
||||||
|
import RefBackendLinalgOnTensorsBackend
|
||||||
|
from torch_mlir.passmanager import PassManager
|
||||||
|
|
||||||
|
from torchfx.builder import build_module
|
||||||
|
from utils.annotator import annotate_forward_args
|
||||||
|
from utils.torch_mlir_types import TorchTensorType
|
||||||
|
|
||||||
|
|
||||||
|
class MyModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
# TODO: Debug issue with RefBackend
|
||||||
|
#return torch.tanh(x) + torch.sigmoid(y)
|
||||||
|
return torch.tanh(x)
|
||||||
|
|
||||||
|
|
||||||
|
module = MyModule()
|
||||||
|
traced_module = acc_tracer.trace(module, [torch.Tensor(2,2),
|
||||||
|
torch.Tensor(2,2)])
|
||||||
|
|
||||||
|
print("TRACE")
|
||||||
|
arg_type = TorchTensorType(shape=[None, None], dtype=torch.float)
|
||||||
|
traced_module = annotate_forward_args(traced_module, [arg_type, arg_type])
|
||||||
|
print(traced_module.graph)
|
||||||
|
mlir_module = build_module(traced_module)
|
||||||
|
|
||||||
|
print("\n\nTORCH MLIR")
|
||||||
|
mlir_module.dump()
|
||||||
|
print(mlir_module.operation.verify())
|
||||||
|
|
||||||
|
with mlir_module.context:
|
||||||
|
pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline')
|
||||||
|
pm.run(mlir_module)
|
||||||
|
|
||||||
|
print("\n\nLOWERED MLIR")
|
||||||
|
mlir_module.dump()
|
||||||
|
|
||||||
|
backend = RefBackendLinalgOnTensorsBackend()
|
||||||
|
compiled = backend.compile(mlir_module)
|
||||||
|
jit_module = backend.load(compiled)
|
||||||
|
|
||||||
|
print("\n\nRunning Forward Function")
|
||||||
|
np_t = np.random.rand(2, 2).astype(dtype=np.float32)
|
||||||
|
t = torch.tensor(np_t, dtype=torch.float)
|
||||||
|
print("Compiled result:\n", jit_module.forward(np_t, np_t))
|
||||||
|
print("\nExpected result:\n", module.forward(t, t))
|
|
@ -0,0 +1,58 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
#
|
||||||
|
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
||||||
|
|
||||||
|
from typing import Iterable, Union
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
from torch_mlir import ir
|
||||||
|
from torch_mlir.dialects import builtin
|
||||||
|
from .torch_mlir_types import TorchTensorType, PythonType
|
||||||
|
|
||||||
|
class Annotation:
|
||||||
|
def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
|
||||||
|
self.types = list(map(lambda t:
|
||||||
|
PythonType(t) if isinstance(t, type) else t,
|
||||||
|
types))
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
result = f'Annotation instance with {len(self.types)} types\n'
|
||||||
|
for e, type_ in enumerate(self.types):
|
||||||
|
result += f' Type of argument {e + 1}: {str(type_)}\n'
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.types)
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationConverter:
|
||||||
|
@staticmethod
|
||||||
|
def to_mlir_array_attr(annotation: Annotation,
|
||||||
|
context: ir.Context) -> ir.ArrayAttr:
|
||||||
|
dict_attrs = []
|
||||||
|
for type_ in annotation:
|
||||||
|
if not isinstance(type_, TorchTensorType):
|
||||||
|
dict_attrs.append(ir.DictAttr.get({}, context=context))
|
||||||
|
continue
|
||||||
|
|
||||||
|
ir_type = type_.to_mlir(context)
|
||||||
|
with context:
|
||||||
|
type_attr = ir.TypeAttr.get(ir_type)
|
||||||
|
dict_attr = ir.DictAttr.get({'torch.type_bound': type_attr})
|
||||||
|
dict_attrs.append(dict_attr)
|
||||||
|
|
||||||
|
return ir.ArrayAttr.get(dict_attrs, context=context)
|
||||||
|
|
||||||
|
|
||||||
|
def annotate_forward_args(module: GraphModule,
|
||||||
|
types: Iterable[Union[TorchTensorType, type]]
|
||||||
|
) -> GraphModule:
|
||||||
|
operands = filter(lambda node: node.op == 'placeholder', module.graph.nodes)
|
||||||
|
for operand, type_ in zip(operands, types):
|
||||||
|
if isinstance(type_, type):
|
||||||
|
type_ = PythonType(type_)
|
||||||
|
operand.update_kwarg('torch_mlir_type', type_)
|
||||||
|
|
||||||
|
return module
|
|
@ -1,8 +1,9 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
#
|
||||||
|
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
||||||
"""
|
"""
|
||||||
-*- Python -*-
|
|
||||||
This file is licensed under a pytorch-style license
|
|
||||||
See frontends/pytorch/LICENSE for license information.
|
|
||||||
|
|
||||||
The following defines a set of classes for converting
|
The following defines a set of classes for converting
|
||||||
types used by Python and PyTorch into MLIR types from the
|
types used by Python and PyTorch into MLIR types from the
|
||||||
`torch` dialect.
|
`torch` dialect.
|
||||||
|
@ -16,8 +17,6 @@ Information about what types are supported by each class
|
||||||
can be found in docstrings of each of the classes.
|
can be found in docstrings of each of the classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from typing import Any, Optional, Iterable
|
from typing import Any, Optional, Iterable
|
||||||
|
|
||||||
|
@ -62,6 +61,9 @@ class TorchTensorType(TorchMlirType):
|
||||||
err = "If shape is specified, dtype must also be specified"
|
err = "If shape is specified, dtype must also be specified"
|
||||||
raise TorchTensorTypeError(err)
|
raise TorchTensorTypeError(err)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f'Torch Tensor (shape={self.shape}, dtype={self.dtype})'
|
||||||
|
|
||||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||||
if self.dtype is None:
|
if self.dtype is None:
|
||||||
return ir.Type.parse('!torch.tensor', context=context)
|
return ir.Type.parse('!torch.tensor', context=context)
|
||||||
|
@ -90,6 +92,9 @@ class TorchNnModuleType(TorchMlirType):
|
||||||
def __init__(self, module_name: str):
|
def __init__(self, module_name: str):
|
||||||
self.module_name = module_name
|
self.module_name = module_name
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "torch.nn.Module"
|
||||||
|
|
||||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||||
return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">',
|
return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">',
|
||||||
context=context)
|
context=context)
|
||||||
|
@ -111,6 +116,9 @@ class PythonType(TorchMlirType):
|
||||||
def __init__(self, type_: Any):
|
def __init__(self, type_: Any):
|
||||||
self.type_ = type_
|
self.type_ = type_
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return str(self.type_)
|
||||||
|
|
||||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||||
asm = self._type_to_asm_dict.get(self.type_)
|
asm = self._type_to_asm_dict.get(self.type_)
|
||||||
if asm is None:
|
if asm is None:
|
|
@ -18,10 +18,17 @@ namespace mlir {
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace TorchConversion {
|
namespace TorchConversion {
|
||||||
|
|
||||||
/// Creates a pipeline that lowers the object graph IR that is produced by
|
/// Creates a pipeline that lowers the object graph IR that is given by a
|
||||||
/// TorchScript import into the form expected by
|
/// TorchScript jit.ScriptModule into the form expected by
|
||||||
/// torch-verify-linalg-on-tensors-verify-backend-contract.
|
/// torch-verify-linalg-on-tensors-verify-backend-contract.
|
||||||
void createTorchScriptToLinalgOnTensorsBackendPipeline(
|
void createTorchScriptModuleToLinalgOnTensorsBackendPipeline(
|
||||||
|
OpPassManager &pm,
|
||||||
|
const torch::Torch::TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
|
/// Creates a pipeline that lowers the object graph IR that is given by a
|
||||||
|
/// TorchScript jit.ScriptFunction into the form expected by
|
||||||
|
/// torch-verify-linalg-on-tensors-verify-backend-contract.
|
||||||
|
void createTorchScriptFunctionToLinalgOnTensorsBackendPipeline(
|
||||||
OpPassManager &pm,
|
OpPassManager &pm,
|
||||||
const torch::Torch::TorchLoweringPipelineOptions &options);
|
const torch::Torch::TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
|
|
|
@ -56,8 +56,6 @@ void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline(
|
||||||
// calls, so inline everything.
|
// calls, so inline everything.
|
||||||
// TODO: Improve shape inference.
|
// TODO: Improve shape inference.
|
||||||
pm.addPass(createInlinerPass());
|
pm.addPass(createInlinerPass());
|
||||||
// Incorporate user annotations and remove signature Python-isms.
|
|
||||||
pm.addPass(createAdjustCallingConventionsPass());
|
|
||||||
|
|
||||||
createGlobalizedModuleToTorchBackendPipeline(pm, options);
|
createGlobalizedModuleToTorchBackendPipeline(pm, options);
|
||||||
}
|
}
|
||||||
|
@ -86,6 +84,9 @@ void mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline(
|
||||||
// Please try to keep this list somewhat up to date when adding
|
// Please try to keep this list somewhat up to date when adding
|
||||||
// "optimize hard enough that it works" transformations.
|
// "optimize hard enough that it works" transformations.
|
||||||
|
|
||||||
|
// Incorporate user annotations and remove signature Python-isms.
|
||||||
|
pm.addPass(createAdjustCallingConventionsPass());
|
||||||
|
|
||||||
if (options.optimize) {
|
if (options.optimize) {
|
||||||
// Inline global slots, which for most inference scenarios deletes them.
|
// Inline global slots, which for most inference scenarios deletes them.
|
||||||
// This also exposes more information to intraprocedural transformations
|
// This also exposes more information to intraprocedural transformations
|
||||||
|
|
|
@ -33,18 +33,17 @@ namespace {
|
||||||
void mlir::torch::registerTorchConversionPasses() {
|
void mlir::torch::registerTorchConversionPasses() {
|
||||||
::registerPasses();
|
::registerPasses();
|
||||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
"torchscript-to-linalg-on-tensors-backend-pipeline",
|
"torchscript-module-to-linalg-on-tensors-backend-pipeline",
|
||||||
"Pipeline lowering torch object graph to linalg-on-tensors backend format.",
|
"Pipeline lowering torch object graph representing a torch.jit.ScriptModule to linalg-on-tensors backend format.",
|
||||||
mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipeline);
|
TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline);
|
||||||
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
|
"torchscript-function-to-linalg-on-tensors-backend-pipeline",
|
||||||
|
"Pipeline lowering a flat list of functions representing a torch.jit.ScriptFunction to linalg-on-tensors backend format.",
|
||||||
|
TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipeline(
|
static void createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||||
|
|
||||||
// Conversion to the linalg-on-tensors backend contract starts from the Torch
|
|
||||||
// backend contract.
|
|
||||||
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
|
|
||||||
|
|
||||||
// Check some invariants to catch errors in a clear way.
|
// Check some invariants to catch errors in a clear way.
|
||||||
pm.addPass(
|
pm.addPass(
|
||||||
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
|
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
|
||||||
|
@ -79,3 +78,21 @@ void mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipe
|
||||||
// correct form.
|
// correct form.
|
||||||
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
|
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline(
|
||||||
|
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||||
|
|
||||||
|
// Conversion to the linalg-on-tensors backend contract starts from the Torch
|
||||||
|
// backend contract.
|
||||||
|
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
|
||||||
|
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline(
|
||||||
|
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||||
|
|
||||||
|
// Conversion to the linalg-on-tensors backend contract starts from the Torch
|
||||||
|
// backend contract.
|
||||||
|
Torch::createGlobalizedModuleToTorchBackendPipeline(pm, options);
|
||||||
|
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options);
|
||||||
|
}
|
||||||
|
|
|
@ -91,7 +91,7 @@ Diagnostics:
|
||||||
sys.stderr = StringIO()
|
sys.stderr = StringIO()
|
||||||
asm_for_error_report = mb.module.operation.get_asm(
|
asm_for_error_report = mb.module.operation.get_asm(
|
||||||
large_elements_limit=10, enable_debug_info=True)
|
large_elements_limit=10, enable_debug_info=True)
|
||||||
pipeline_str = "torchscript-to-linalg-on-tensors-backend-pipeline"
|
pipeline_str = "torchscript-module-to-linalg-on-tensors-backend-pipeline"
|
||||||
# Lower module in place to make it ready for compiler backends.
|
# Lower module in place to make it ready for compiler backends.
|
||||||
with mb.module.context:
|
with mb.module.context:
|
||||||
pm = PassManager.parse(pipeline_str)
|
pm = PassManager.parse(pipeline_str)
|
||||||
|
|
Loading…
Reference in New Issue