mirror of https://github.com/llvm/torch-mlir
Implement the lazytensor package (#331)
Implement the `lazytensor` python package for converting lazy computations captured by the Lazy Tensor Core into MLIR. This PR also fixes a few things with `torchfx` and its examplepull/337/head
parent
2b99c8b990
commit
b59f2cb673
31
README.md
31
README.md
|
@ -104,13 +104,40 @@ jupyter notebook
|
|||
|
||||
### TorchFX
|
||||
|
||||
TODO
|
||||
The `examples` folder includes the Python package `torchfx`, which is a functional prototype of a TorchFX to MLIR pipeline. The main entry point into the `torchfx` package is the `torchfx.builder` module, which includes a function for converting the output of a TorchFX trace into MLIR. Currently, the number of PyTorch operations supported is very limited, but will be expanded in the future.
|
||||
|
||||
#### Example usage of `torchfx`
|
||||
|
||||
The `examples` folder includes scripts `torchfx_*.py` showing how to use the TorchFX to MLIR pipeline. In order to run the examples, make sure you've setup your `PYTHONPATH` by following [these](#setup-env) instructions, and add `/path/to/torch-mlir/examples` to your `PYTHONPATH`.
|
||||
|
||||
Then, run
|
||||
|
||||
```
|
||||
python torchfx_example_name.py
|
||||
```
|
||||
|
||||
replacing `torchfx_example_name.py` with the actual `torchfx` example you want to run.
|
||||
|
||||
|
||||
### Lazy Tensor Core
|
||||
|
||||
TODO
|
||||
The `examples` folder includes the Python package `lazytensor`, which implements a Lazy Tensor Core (LTC) to MLIR pipeline. The main entry point into the `lazytensor` package is the `lazytensor.builder`, which includes the function `build_module` that takes a computation captured and converted to TorchScript IR by LTC, and converts it to MLIR.
|
||||
|
||||
#### Example usage of `lazytensor`
|
||||
|
||||
The `examples` folder includes scripts `lazytensor_*.py` showing how to use the Lazy Tensor to MLIR pipeline. The examples depend on the Lazy Tensor Core (LTC) of PyTorch. For information on how to obtain LTC, see [here](https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/QUICKSTART.md).
|
||||
|
||||
In order to run the examples, make sure you've setup your `PYTHONPATH` by following [these](#setup-env) instructions, and also add the following to your `PYTHONPATH`:
|
||||
1. `/path/to/torch-mlir/examples`
|
||||
2. `/path/to/pytorch/lazy_tensor_core`
|
||||
|
||||
Then, run
|
||||
|
||||
```
|
||||
python lazytensor_example_name.py
|
||||
```
|
||||
|
||||
replacing `lazytensor_example_name.py` with the actual `lazytensor` example you want to run.
|
||||
|
||||
## Repository Layout
|
||||
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""
|
||||
Translator from torch.jit.ScriptFunction to MLIR.
|
||||
|
||||
The following defines a function that take a torch.jit.ScriptFunction
|
||||
and converts it into an MLIR module.
|
||||
|
||||
The expected use for this module is to use the function
|
||||
`build_module(jit_function: torch.jit.ScriptFunction
|
||||
annotation: Annotation) -> ir.Module`
|
||||
to convert the TorchScript function into MLIR using the `torch`
|
||||
dialect.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from torch.jit import ScriptFunction
|
||||
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||
from torch_mlir.dialects.builtin import FuncOp
|
||||
from torch_mlir import ir
|
||||
|
||||
from utils.annotator import AnnotationConverter as ac
|
||||
from utils.annotator import Annotation
|
||||
|
||||
def _get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
|
||||
with module.context:
|
||||
name_attr = ir.StringAttr.get(name)
|
||||
for op in module.body.operations:
|
||||
if isinstance(op, FuncOp) and op.name == name_attr:
|
||||
return op
|
||||
|
||||
return None
|
||||
|
||||
def build_module(jit_function: ScriptFunction,
|
||||
annotation: Annotation) -> ir.Module:
|
||||
"""
|
||||
Translate input function into an MLIR module in the `torch` dialect.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
jit_function: ScriptFunction
|
||||
Function in TorchScript IR to turn into MLIR.
|
||||
annotation: Annotation
|
||||
Annotation object representing the types of
|
||||
the operands of `jit_function`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ir.Module
|
||||
Translation of the input module into an MLIR module
|
||||
"""
|
||||
mb = ModuleBuilder()
|
||||
mb.import_function(jit_function)
|
||||
|
||||
func_op = _get_func_op_with_name(mb.module, jit_function.name)
|
||||
assert func_op is not None, 'Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder'
|
||||
|
||||
arg_attrs = ac.to_mlir_array_attr(annotation, mb.context)
|
||||
func_op.attributes['arg_attrs'] = arg_attrs
|
||||
|
||||
return mb.module
|
|
@ -0,0 +1,81 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""
|
||||
Example of taking a Lazy Tensor computation and compiling it using torch-mlir.
|
||||
|
||||
This example depends on the Lazy Tensor Core (LTC) of PyTorch. For information
|
||||
on how to obtain LTC, see here:
|
||||
https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/QUICKSTART.md
|
||||
|
||||
To run the example, make sure the following are in your PYTHONPATH:
|
||||
1. /path/to/torch-mlir/examples
|
||||
2. /path/to/pytorch/lazy_tensor_core
|
||||
3. /path/to/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir
|
||||
|
||||
then, simply call `python lazytensor_tanh.py`.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import lazy_tensor_core as ltc
|
||||
from torch._C import CompilationUnit
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend \
|
||||
import RefBackendLinalgOnTensorsBackend
|
||||
from torch_mlir.passmanager import PassManager
|
||||
|
||||
from utils.annotator import Annotation
|
||||
from utils.torch_mlir_types import TorchTensorType
|
||||
from lazytensor.builder import build_module
|
||||
|
||||
ltc._LAZYC._ltc_init_ts_backend()
|
||||
device = 'lazy'
|
||||
dtype = torch.float32
|
||||
shape = (2, 3)
|
||||
|
||||
x = torch.randn(shape, device=device, dtype=dtype)
|
||||
y = torch.randn(shape, device=device, dtype=dtype)
|
||||
|
||||
def computation(x, y):
|
||||
return y * x.tanh()
|
||||
|
||||
# Capture lazy computation and convert to TorchScript IR
|
||||
graph_str = ltc._LAZYC._get_ltc_tensors_backend([computation(x, y)])
|
||||
print("LAZY GRAPH")
|
||||
print(graph_str)
|
||||
graph = torch._C.parse_ir(graph_str)
|
||||
|
||||
# Create a torch.jit.ScriptFunction out of the graph
|
||||
cu = CompilationUnit()
|
||||
func_name = 'my_method'
|
||||
script_function = cu.create_function(func_name, graph)
|
||||
|
||||
# `build_module` takes he torch.jit.ScriptFunction and the
|
||||
# annotation on the operand types, and outputs an `ir.Module`
|
||||
# with a single function representing the ScriptFunction in
|
||||
# the torch MLIR dialect
|
||||
func_annotation = Annotation([TorchTensorType(shape=shape, dtype=torch.float),
|
||||
TorchTensorType(shape=shape, dtype=torch.float)])
|
||||
mlir_module = build_module(script_function, func_annotation)
|
||||
|
||||
print("MLIR")
|
||||
mlir_module.dump()
|
||||
|
||||
# Compile the torch MLIR and execute the compiled program
|
||||
with mlir_module.context:
|
||||
pm = PassManager.parse('torchscript-function-to-linalg-on-tensors-backend-pipeline')
|
||||
pm.run(mlir_module)
|
||||
|
||||
print("BEFORE LINALG-ON-TENSORS BACKEND PIPELINE")
|
||||
print(mlir_module)
|
||||
|
||||
backend = RefBackendLinalgOnTensorsBackend()
|
||||
compiled = backend.compile(mlir_module)
|
||||
jit_module = backend.load(compiled)
|
||||
|
||||
print("\n\nRunning Example Calculation")
|
||||
print("Compiled result:")
|
||||
print(jit_module.my_method(x.cpu().numpy(), y.cpu().numpy()))
|
||||
print("Expected result:")
|
||||
print(computation(x, y))
|
|
@ -1,20 +0,0 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
#
|
||||
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
||||
|
||||
from typing import Iterable, Union
|
||||
from torch.fx import GraphModule
|
||||
from .torch_mlir_types import TorchTensorType, PythonType
|
||||
|
||||
def annotate_forward_args(module: GraphModule,
|
||||
types: Iterable[Union[TorchTensorType, type]]
|
||||
) -> GraphModule:
|
||||
operands = filter(lambda node: node.op == 'placeholder', module.graph.nodes)
|
||||
for operand, type_ in zip(operands, types):
|
||||
if isinstance(type_, type):
|
||||
type_ = PythonType(type_)
|
||||
operand.update_kwarg('torch_mlir_type', type_)
|
||||
|
||||
return module
|
|
@ -1,3 +1,6 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""
|
||||
Translator from Torch.FX to MLIR.
|
||||
|
||||
|
@ -9,9 +12,6 @@ The expected use for this module is to use the function
|
|||
`build_module(py_module: torch.fx.GraphModule) -> ir.Module`
|
||||
to convert the output from the tracer into MLIR using the `torch`
|
||||
dialect.
|
||||
|
||||
This file is licensed under a pytorch-style license
|
||||
See frontends/pytorch/LICENSE for license information.
|
||||
"""
|
||||
|
||||
# pylint: disable=no-member, no-name-in-module, invalid-name, fixme
|
||||
|
@ -27,7 +27,8 @@ from torch_mlir.dialects import builtin, std
|
|||
import torch.fx
|
||||
from torch.fx.experimental.fx_acc import acc_ops
|
||||
|
||||
from .torch_mlir_types import TorchTensorType, PythonType, TorchNnModuleType
|
||||
from utils.torch_mlir_types import TorchTensorType, PythonType, \
|
||||
TorchNnModuleType
|
||||
|
||||
Environment = MutableMapping[torch.fx.Node, ir.Value]
|
||||
|
||||
|
@ -388,7 +389,7 @@ an argument named `input`'
|
|||
|
||||
|
||||
@_add_handler(ACC_OP_HANDLERS, acc_ops.add)
|
||||
def _add_handler(func_builder: _ForwardFunctionBuilder,
|
||||
def _add_tensor_handler(func_builder: _ForwardFunctionBuilder,
|
||||
args: Mapping[str, ir.Value]) -> ir.OpResult:
|
||||
input_arg = args.get('input')
|
||||
other_arg = args.get('other')
|
||||
|
@ -396,10 +397,10 @@ def _add_handler(func_builder: _ForwardFunctionBuilder,
|
|||
'A call to this handler must include an argument named `input` \
|
||||
and an argument named `other`'
|
||||
tensor_type = TorchTensorType().to_mlir(func_builder.context)
|
||||
int_type = PythonType(int).to_mlir(func_builder.context)
|
||||
attr_type = ir.Type.parse('i64', func_builder.context)
|
||||
int_attr = ir.IntegerAttr.get(attr_type, 1)
|
||||
alpha_arg = torch_d.ConstantIntOp(int_type,
|
||||
torch_int_type = PythonType(int).to_mlir(func_builder.context)
|
||||
int_type = ir.Type.parse("i64", context=func_builder.context)
|
||||
int_attr = ir.IntegerAttr.get(int_type, 1)
|
||||
alpha_arg = torch_d.ConstantIntOp(torch_int_type,
|
||||
int_attr,
|
||||
loc=func_builder.loc,
|
||||
ip=func_builder.func_ip).result
|
||||
|
|
|
@ -1,59 +0,0 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
# From the torch-mlir root, run with:
|
||||
# `python -m examples.torchfx.examples.example_add_tanh_sigmoid`
|
||||
# (after setting up python environment with write_env_file.sh)
|
||||
|
||||
import torch
|
||||
from torch.fx.experimental.fx_acc import acc_tracer
|
||||
import torch_mlir
|
||||
from torch_mlir.dialects.torch import register_dialect
|
||||
from torch_mlir.passmanager import PassManager
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||
|
||||
from ..builder import build_module
|
||||
from ..annotator import annotate_forward_args
|
||||
from ..torch_mlir_types import TorchTensorType
|
||||
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
# TODO: Debug issue with RefBackend
|
||||
#return torch.tanh(x) + torch.sigmoid(y)
|
||||
return torch.tanh(x)
|
||||
|
||||
|
||||
module = MyModule()
|
||||
traced_module = acc_tracer.trace(module, [torch.Tensor(2,2),
|
||||
torch.Tensor(2,2)])
|
||||
|
||||
print("TRACE")
|
||||
arg_type = TorchTensorType(shape=[None, None], dtype=torch.float)
|
||||
traced_module = annotate_forward_args(traced_module, [arg_type, arg_type])
|
||||
print(traced_module.graph)
|
||||
torch_mlir_module = build_module(traced_module)
|
||||
|
||||
print("\n\nTORCH MLIR")
|
||||
torch_mlir_module.dump()
|
||||
print(torch_mlir_module.operation.verify())
|
||||
|
||||
with torch_mlir_module.context:
|
||||
pm = PassManager.parse('torchscript-to-linalg-on-tensors-backend-pipeline')
|
||||
pm.run(torch_mlir_module)
|
||||
|
||||
print("\n\nLOWERED MLIR")
|
||||
torch_mlir_module.dump()
|
||||
|
||||
backend = RefBackendLinalgOnTensorsBackend()
|
||||
compiled = backend.compile(torch_mlir_module)
|
||||
jit_module = backend.load(compiled)
|
||||
|
||||
print("\n\nRunning Forward Function")
|
||||
t = torch.rand((2, 2), dtype=torch.float)
|
||||
print("Compiled result:\n", jit_module.forward(t.numpy(), t.numpy()))
|
||||
print("\nExpected result:\n", module.forward(t, t))
|
|
@ -1,6 +1,7 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
#
|
||||
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
||||
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
# -*- Python -*-
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""
|
||||
Example of taking a moduled traced by TorchFX and compiling it using torch-mlir.
|
||||
|
||||
To run the example, make sure the following are in your PYTHONPATH:
|
||||
1. /path/to/torch-mlir/examples
|
||||
2. /path/to/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir
|
||||
|
||||
then, simply call `python torchfx_add_tanh_sigmoid.py`.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.fx.experimental.fx_acc import acc_tracer
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend \
|
||||
import RefBackendLinalgOnTensorsBackend
|
||||
from torch_mlir.passmanager import PassManager
|
||||
|
||||
from torchfx.builder import build_module
|
||||
from utils.annotator import annotate_forward_args
|
||||
from utils.torch_mlir_types import TorchTensorType
|
||||
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
# TODO: Debug issue with RefBackend
|
||||
#return torch.tanh(x) + torch.sigmoid(y)
|
||||
return torch.tanh(x)
|
||||
|
||||
|
||||
module = MyModule()
|
||||
traced_module = acc_tracer.trace(module, [torch.Tensor(2,2),
|
||||
torch.Tensor(2,2)])
|
||||
|
||||
print("TRACE")
|
||||
arg_type = TorchTensorType(shape=[None, None], dtype=torch.float)
|
||||
traced_module = annotate_forward_args(traced_module, [arg_type, arg_type])
|
||||
print(traced_module.graph)
|
||||
mlir_module = build_module(traced_module)
|
||||
|
||||
print("\n\nTORCH MLIR")
|
||||
mlir_module.dump()
|
||||
print(mlir_module.operation.verify())
|
||||
|
||||
with mlir_module.context:
|
||||
pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline')
|
||||
pm.run(mlir_module)
|
||||
|
||||
print("\n\nLOWERED MLIR")
|
||||
mlir_module.dump()
|
||||
|
||||
backend = RefBackendLinalgOnTensorsBackend()
|
||||
compiled = backend.compile(mlir_module)
|
||||
jit_module = backend.load(compiled)
|
||||
|
||||
print("\n\nRunning Forward Function")
|
||||
np_t = np.random.rand(2, 2).astype(dtype=np.float32)
|
||||
t = torch.tensor(np_t, dtype=torch.float)
|
||||
print("Compiled result:\n", jit_module.forward(np_t, np_t))
|
||||
print("\nExpected result:\n", module.forward(t, t))
|
|
@ -0,0 +1,58 @@
|
|||
# -*- Python -*-
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
#
|
||||
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
||||
|
||||
from typing import Iterable, Union
|
||||
from torch.fx import GraphModule
|
||||
from torch_mlir import ir
|
||||
from torch_mlir.dialects import builtin
|
||||
from .torch_mlir_types import TorchTensorType, PythonType
|
||||
|
||||
class Annotation:
|
||||
def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
|
||||
self.types = list(map(lambda t:
|
||||
PythonType(t) if isinstance(t, type) else t,
|
||||
types))
|
||||
|
||||
def __str__(self):
|
||||
result = f'Annotation instance with {len(self.types)} types\n'
|
||||
for e, type_ in enumerate(self.types):
|
||||
result += f' Type of argument {e + 1}: {str(type_)}\n'
|
||||
return result
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.types)
|
||||
|
||||
|
||||
class AnnotationConverter:
|
||||
@staticmethod
|
||||
def to_mlir_array_attr(annotation: Annotation,
|
||||
context: ir.Context) -> ir.ArrayAttr:
|
||||
dict_attrs = []
|
||||
for type_ in annotation:
|
||||
if not isinstance(type_, TorchTensorType):
|
||||
dict_attrs.append(ir.DictAttr.get({}, context=context))
|
||||
continue
|
||||
|
||||
ir_type = type_.to_mlir(context)
|
||||
with context:
|
||||
type_attr = ir.TypeAttr.get(ir_type)
|
||||
dict_attr = ir.DictAttr.get({'torch.type_bound': type_attr})
|
||||
dict_attrs.append(dict_attr)
|
||||
|
||||
return ir.ArrayAttr.get(dict_attrs, context=context)
|
||||
|
||||
|
||||
def annotate_forward_args(module: GraphModule,
|
||||
types: Iterable[Union[TorchTensorType, type]]
|
||||
) -> GraphModule:
|
||||
operands = filter(lambda node: node.op == 'placeholder', module.graph.nodes)
|
||||
for operand, type_ in zip(operands, types):
|
||||
if isinstance(type_, type):
|
||||
type_ = PythonType(type_)
|
||||
operand.update_kwarg('torch_mlir_type', type_)
|
||||
|
||||
return module
|
|
@ -1,8 +1,9 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
#
|
||||
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
||||
"""
|
||||
-*- Python -*-
|
||||
This file is licensed under a pytorch-style license
|
||||
See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
The following defines a set of classes for converting
|
||||
types used by Python and PyTorch into MLIR types from the
|
||||
`torch` dialect.
|
||||
|
@ -16,8 +17,6 @@ Information about what types are supported by each class
|
|||
can be found in docstrings of each of the classes.
|
||||
"""
|
||||
|
||||
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
||||
|
||||
import abc
|
||||
from typing import Any, Optional, Iterable
|
||||
|
||||
|
@ -62,6 +61,9 @@ class TorchTensorType(TorchMlirType):
|
|||
err = "If shape is specified, dtype must also be specified"
|
||||
raise TorchTensorTypeError(err)
|
||||
|
||||
def __str__(self):
|
||||
return f'Torch Tensor (shape={self.shape}, dtype={self.dtype})'
|
||||
|
||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||
if self.dtype is None:
|
||||
return ir.Type.parse('!torch.tensor', context=context)
|
||||
|
@ -90,6 +92,9 @@ class TorchNnModuleType(TorchMlirType):
|
|||
def __init__(self, module_name: str):
|
||||
self.module_name = module_name
|
||||
|
||||
def __str__(self):
|
||||
return "torch.nn.Module"
|
||||
|
||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||
return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">',
|
||||
context=context)
|
||||
|
@ -111,6 +116,9 @@ class PythonType(TorchMlirType):
|
|||
def __init__(self, type_: Any):
|
||||
self.type_ = type_
|
||||
|
||||
def __str__(self):
|
||||
return str(self.type_)
|
||||
|
||||
def to_mlir(self, context: ir.Context) -> ir.Type:
|
||||
asm = self._type_to_asm_dict.get(self.type_)
|
||||
if asm is None:
|
|
@ -18,10 +18,17 @@ namespace mlir {
|
|||
namespace torch {
|
||||
namespace TorchConversion {
|
||||
|
||||
/// Creates a pipeline that lowers the object graph IR that is produced by
|
||||
/// TorchScript import into the form expected by
|
||||
/// Creates a pipeline that lowers the object graph IR that is given by a
|
||||
/// TorchScript jit.ScriptModule into the form expected by
|
||||
/// torch-verify-linalg-on-tensors-verify-backend-contract.
|
||||
void createTorchScriptToLinalgOnTensorsBackendPipeline(
|
||||
void createTorchScriptModuleToLinalgOnTensorsBackendPipeline(
|
||||
OpPassManager &pm,
|
||||
const torch::Torch::TorchLoweringPipelineOptions &options);
|
||||
|
||||
/// Creates a pipeline that lowers the object graph IR that is given by a
|
||||
/// TorchScript jit.ScriptFunction into the form expected by
|
||||
/// torch-verify-linalg-on-tensors-verify-backend-contract.
|
||||
void createTorchScriptFunctionToLinalgOnTensorsBackendPipeline(
|
||||
OpPassManager &pm,
|
||||
const torch::Torch::TorchLoweringPipelineOptions &options);
|
||||
|
||||
|
|
|
@ -56,8 +56,6 @@ void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline(
|
|||
// calls, so inline everything.
|
||||
// TODO: Improve shape inference.
|
||||
pm.addPass(createInlinerPass());
|
||||
// Incorporate user annotations and remove signature Python-isms.
|
||||
pm.addPass(createAdjustCallingConventionsPass());
|
||||
|
||||
createGlobalizedModuleToTorchBackendPipeline(pm, options);
|
||||
}
|
||||
|
@ -86,6 +84,9 @@ void mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline(
|
|||
// Please try to keep this list somewhat up to date when adding
|
||||
// "optimize hard enough that it works" transformations.
|
||||
|
||||
// Incorporate user annotations and remove signature Python-isms.
|
||||
pm.addPass(createAdjustCallingConventionsPass());
|
||||
|
||||
if (options.optimize) {
|
||||
// Inline global slots, which for most inference scenarios deletes them.
|
||||
// This also exposes more information to intraprocedural transformations
|
||||
|
|
|
@ -33,18 +33,17 @@ namespace {
|
|||
void mlir::torch::registerTorchConversionPasses() {
|
||||
::registerPasses();
|
||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||
"torchscript-to-linalg-on-tensors-backend-pipeline",
|
||||
"Pipeline lowering torch object graph to linalg-on-tensors backend format.",
|
||||
mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipeline);
|
||||
"torchscript-module-to-linalg-on-tensors-backend-pipeline",
|
||||
"Pipeline lowering torch object graph representing a torch.jit.ScriptModule to linalg-on-tensors backend format.",
|
||||
TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline);
|
||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||
"torchscript-function-to-linalg-on-tensors-backend-pipeline",
|
||||
"Pipeline lowering a flat list of functions representing a torch.jit.ScriptFunction to linalg-on-tensors backend format.",
|
||||
TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline);
|
||||
}
|
||||
|
||||
void mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipeline(
|
||||
static void createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||
|
||||
// Conversion to the linalg-on-tensors backend contract starts from the Torch
|
||||
// backend contract.
|
||||
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
|
||||
|
||||
// Check some invariants to catch errors in a clear way.
|
||||
pm.addPass(
|
||||
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
|
||||
|
@ -79,3 +78,21 @@ void mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipe
|
|||
// correct form.
|
||||
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
|
||||
}
|
||||
|
||||
void TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline(
|
||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||
|
||||
// Conversion to the linalg-on-tensors backend contract starts from the Torch
|
||||
// backend contract.
|
||||
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
|
||||
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options);
|
||||
}
|
||||
|
||||
void TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline(
|
||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||
|
||||
// Conversion to the linalg-on-tensors backend contract starts from the Torch
|
||||
// backend contract.
|
||||
Torch::createGlobalizedModuleToTorchBackendPipeline(pm, options);
|
||||
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options);
|
||||
}
|
||||
|
|
|
@ -91,7 +91,7 @@ Diagnostics:
|
|||
sys.stderr = StringIO()
|
||||
asm_for_error_report = mb.module.operation.get_asm(
|
||||
large_elements_limit=10, enable_debug_info=True)
|
||||
pipeline_str = "torchscript-to-linalg-on-tensors-backend-pipeline"
|
||||
pipeline_str = "torchscript-module-to-linalg-on-tensors-backend-pipeline"
|
||||
# Lower module in place to make it ready for compiler backends.
|
||||
with mb.module.context:
|
||||
pm = PassManager.parse(pipeline_str)
|
||||
|
|
Loading…
Reference in New Issue