Implement the lazytensor package (#331)

Implement the `lazytensor` python package for converting
lazy computations captured by the Lazy Tensor Core into MLIR.
This PR also fixes a few things with `torchfx` and its example
pull/337/head
Ramiro Leal-Cavazos 2021-09-28 19:25:06 -05:00 committed by GitHub
parent 2b99c8b990
commit b59f2cb673
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 366 additions and 113 deletions

View File

@ -104,13 +104,40 @@ jupyter notebook
### TorchFX ### TorchFX
TODO The `examples` folder includes the Python package `torchfx`, which is a functional prototype of a TorchFX to MLIR pipeline. The main entry point into the `torchfx` package is the `torchfx.builder` module, which includes a function for converting the output of a TorchFX trace into MLIR. Currently, the number of PyTorch operations supported is very limited, but will be expanded in the future.
#### Example usage of `torchfx`
The `examples` folder includes scripts `torchfx_*.py` showing how to use the TorchFX to MLIR pipeline. In order to run the examples, make sure you've setup your `PYTHONPATH` by following [these](#setup-env) instructions, and add `/path/to/torch-mlir/examples` to your `PYTHONPATH`.
Then, run
```
python torchfx_example_name.py
```
replacing `torchfx_example_name.py` with the actual `torchfx` example you want to run.
### Lazy Tensor Core ### Lazy Tensor Core
TODO The `examples` folder includes the Python package `lazytensor`, which implements a Lazy Tensor Core (LTC) to MLIR pipeline. The main entry point into the `lazytensor` package is the `lazytensor.builder`, which includes the function `build_module` that takes a computation captured and converted to TorchScript IR by LTC, and converts it to MLIR.
#### Example usage of `lazytensor`
The `examples` folder includes scripts `lazytensor_*.py` showing how to use the Lazy Tensor to MLIR pipeline. The examples depend on the Lazy Tensor Core (LTC) of PyTorch. For information on how to obtain LTC, see [here](https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/QUICKSTART.md).
In order to run the examples, make sure you've setup your `PYTHONPATH` by following [these](#setup-env) instructions, and also add the following to your `PYTHONPATH`:
1. `/path/to/torch-mlir/examples`
2. `/path/to/pytorch/lazy_tensor_core`
Then, run
```
python lazytensor_example_name.py
```
replacing `lazytensor_example_name.py` with the actual `lazytensor` example you want to run.
## Repository Layout ## Repository Layout

View File

@ -0,0 +1,64 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Translator from torch.jit.ScriptFunction to MLIR.
The following defines a function that take a torch.jit.ScriptFunction
and converts it into an MLIR module.
The expected use for this module is to use the function
`build_module(jit_function: torch.jit.ScriptFunction
annotation: Annotation) -> ir.Module`
to convert the TorchScript function into MLIR using the `torch`
dialect.
"""
from typing import Optional
from torch.jit import ScriptFunction
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
from torch_mlir.dialects.builtin import FuncOp
from torch_mlir import ir
from utils.annotator import AnnotationConverter as ac
from utils.annotator import Annotation
def _get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
with module.context:
name_attr = ir.StringAttr.get(name)
for op in module.body.operations:
if isinstance(op, FuncOp) and op.name == name_attr:
return op
return None
def build_module(jit_function: ScriptFunction,
annotation: Annotation) -> ir.Module:
"""
Translate input function into an MLIR module in the `torch` dialect.
Parameters
----------
jit_function: ScriptFunction
Function in TorchScript IR to turn into MLIR.
annotation: Annotation
Annotation object representing the types of
the operands of `jit_function`.
Returns
-------
ir.Module
Translation of the input module into an MLIR module
"""
mb = ModuleBuilder()
mb.import_function(jit_function)
func_op = _get_func_op_with_name(mb.module, jit_function.name)
assert func_op is not None, 'Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder'
arg_attrs = ac.to_mlir_array_attr(annotation, mb.context)
func_op.attributes['arg_attrs'] = arg_attrs
return mb.module

View File

@ -0,0 +1,81 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Example of taking a Lazy Tensor computation and compiling it using torch-mlir.
This example depends on the Lazy Tensor Core (LTC) of PyTorch. For information
on how to obtain LTC, see here:
https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/QUICKSTART.md
To run the example, make sure the following are in your PYTHONPATH:
1. /path/to/torch-mlir/examples
2. /path/to/pytorch/lazy_tensor_core
3. /path/to/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir
then, simply call `python lazytensor_tanh.py`.
"""
import numpy as np
import torch
import lazy_tensor_core as ltc
from torch._C import CompilationUnit
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend \
import RefBackendLinalgOnTensorsBackend
from torch_mlir.passmanager import PassManager
from utils.annotator import Annotation
from utils.torch_mlir_types import TorchTensorType
from lazytensor.builder import build_module
ltc._LAZYC._ltc_init_ts_backend()
device = 'lazy'
dtype = torch.float32
shape = (2, 3)
x = torch.randn(shape, device=device, dtype=dtype)
y = torch.randn(shape, device=device, dtype=dtype)
def computation(x, y):
return y * x.tanh()
# Capture lazy computation and convert to TorchScript IR
graph_str = ltc._LAZYC._get_ltc_tensors_backend([computation(x, y)])
print("LAZY GRAPH")
print(graph_str)
graph = torch._C.parse_ir(graph_str)
# Create a torch.jit.ScriptFunction out of the graph
cu = CompilationUnit()
func_name = 'my_method'
script_function = cu.create_function(func_name, graph)
# `build_module` takes he torch.jit.ScriptFunction and the
# annotation on the operand types, and outputs an `ir.Module`
# with a single function representing the ScriptFunction in
# the torch MLIR dialect
func_annotation = Annotation([TorchTensorType(shape=shape, dtype=torch.float),
TorchTensorType(shape=shape, dtype=torch.float)])
mlir_module = build_module(script_function, func_annotation)
print("MLIR")
mlir_module.dump()
# Compile the torch MLIR and execute the compiled program
with mlir_module.context:
pm = PassManager.parse('torchscript-function-to-linalg-on-tensors-backend-pipeline')
pm.run(mlir_module)
print("BEFORE LINALG-ON-TENSORS BACKEND PIPELINE")
print(mlir_module)
backend = RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(mlir_module)
jit_module = backend.load(compiled)
print("\n\nRunning Example Calculation")
print("Compiled result:")
print(jit_module.my_method(x.cpu().numpy(), y.cpu().numpy()))
print("Expected result:")
print(computation(x, y))

View File

@ -1,20 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
#
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
from typing import Iterable, Union
from torch.fx import GraphModule
from .torch_mlir_types import TorchTensorType, PythonType
def annotate_forward_args(module: GraphModule,
types: Iterable[Union[TorchTensorType, type]]
) -> GraphModule:
operands = filter(lambda node: node.op == 'placeholder', module.graph.nodes)
for operand, type_ in zip(operands, types):
if isinstance(type_, type):
type_ = PythonType(type_)
operand.update_kwarg('torch_mlir_type', type_)
return module

View File

@ -1,3 +1,6 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
""" """
Translator from Torch.FX to MLIR. Translator from Torch.FX to MLIR.
@ -9,9 +12,6 @@ The expected use for this module is to use the function
`build_module(py_module: torch.fx.GraphModule) -> ir.Module` `build_module(py_module: torch.fx.GraphModule) -> ir.Module`
to convert the output from the tracer into MLIR using the `torch` to convert the output from the tracer into MLIR using the `torch`
dialect. dialect.
This file is licensed under a pytorch-style license
See frontends/pytorch/LICENSE for license information.
""" """
# pylint: disable=no-member, no-name-in-module, invalid-name, fixme # pylint: disable=no-member, no-name-in-module, invalid-name, fixme
@ -27,7 +27,8 @@ from torch_mlir.dialects import builtin, std
import torch.fx import torch.fx
from torch.fx.experimental.fx_acc import acc_ops from torch.fx.experimental.fx_acc import acc_ops
from .torch_mlir_types import TorchTensorType, PythonType, TorchNnModuleType from utils.torch_mlir_types import TorchTensorType, PythonType, \
TorchNnModuleType
Environment = MutableMapping[torch.fx.Node, ir.Value] Environment = MutableMapping[torch.fx.Node, ir.Value]
@ -388,7 +389,7 @@ an argument named `input`'
@_add_handler(ACC_OP_HANDLERS, acc_ops.add) @_add_handler(ACC_OP_HANDLERS, acc_ops.add)
def _add_handler(func_builder: _ForwardFunctionBuilder, def _add_tensor_handler(func_builder: _ForwardFunctionBuilder,
args: Mapping[str, ir.Value]) -> ir.OpResult: args: Mapping[str, ir.Value]) -> ir.OpResult:
input_arg = args.get('input') input_arg = args.get('input')
other_arg = args.get('other') other_arg = args.get('other')
@ -396,10 +397,10 @@ def _add_handler(func_builder: _ForwardFunctionBuilder,
'A call to this handler must include an argument named `input` \ 'A call to this handler must include an argument named `input` \
and an argument named `other`' and an argument named `other`'
tensor_type = TorchTensorType().to_mlir(func_builder.context) tensor_type = TorchTensorType().to_mlir(func_builder.context)
int_type = PythonType(int).to_mlir(func_builder.context) torch_int_type = PythonType(int).to_mlir(func_builder.context)
attr_type = ir.Type.parse('i64', func_builder.context) int_type = ir.Type.parse("i64", context=func_builder.context)
int_attr = ir.IntegerAttr.get(attr_type, 1) int_attr = ir.IntegerAttr.get(int_type, 1)
alpha_arg = torch_d.ConstantIntOp(int_type, alpha_arg = torch_d.ConstantIntOp(torch_int_type,
int_attr, int_attr,
loc=func_builder.loc, loc=func_builder.loc,
ip=func_builder.func_ip).result ip=func_builder.func_ip).result

View File

@ -1,59 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
# From the torch-mlir root, run with:
# `python -m examples.torchfx.examples.example_add_tanh_sigmoid`
# (after setting up python environment with write_env_file.sh)
import torch
from torch.fx.experimental.fx_acc import acc_tracer
import torch_mlir
from torch_mlir.dialects.torch import register_dialect
from torch_mlir.passmanager import PassManager
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from ..builder import build_module
from ..annotator import annotate_forward_args
from ..torch_mlir_types import TorchTensorType
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
# TODO: Debug issue with RefBackend
#return torch.tanh(x) + torch.sigmoid(y)
return torch.tanh(x)
module = MyModule()
traced_module = acc_tracer.trace(module, [torch.Tensor(2,2),
torch.Tensor(2,2)])
print("TRACE")
arg_type = TorchTensorType(shape=[None, None], dtype=torch.float)
traced_module = annotate_forward_args(traced_module, [arg_type, arg_type])
print(traced_module.graph)
torch_mlir_module = build_module(traced_module)
print("\n\nTORCH MLIR")
torch_mlir_module.dump()
print(torch_mlir_module.operation.verify())
with torch_mlir_module.context:
pm = PassManager.parse('torchscript-to-linalg-on-tensors-backend-pipeline')
pm.run(torch_mlir_module)
print("\n\nLOWERED MLIR")
torch_mlir_module.dump()
backend = RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(torch_mlir_module)
jit_module = backend.load(compiled)
print("\n\nRunning Forward Function")
t = torch.rand((2, 2), dtype=torch.float)
print("Compiled result:\n", jit_module.forward(t.numpy(), t.numpy()))
print("\nExpected result:\n", module.forward(t, t))

View File

@ -1,6 +1,7 @@
# -*- Python -*- # -*- Python -*-
# This file is licensed under a pytorch-style license # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See frontends/pytorch/LICENSE for license information. # See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# #
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme # pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme

View File

@ -0,0 +1,67 @@
# -*- Python -*-
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Example of taking a moduled traced by TorchFX and compiling it using torch-mlir.
To run the example, make sure the following are in your PYTHONPATH:
1. /path/to/torch-mlir/examples
2. /path/to/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir
then, simply call `python torchfx_add_tanh_sigmoid.py`.
"""
import torch
import numpy as np
from torch.fx.experimental.fx_acc import acc_tracer
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend \
import RefBackendLinalgOnTensorsBackend
from torch_mlir.passmanager import PassManager
from torchfx.builder import build_module
from utils.annotator import annotate_forward_args
from utils.torch_mlir_types import TorchTensorType
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
# TODO: Debug issue with RefBackend
#return torch.tanh(x) + torch.sigmoid(y)
return torch.tanh(x)
module = MyModule()
traced_module = acc_tracer.trace(module, [torch.Tensor(2,2),
torch.Tensor(2,2)])
print("TRACE")
arg_type = TorchTensorType(shape=[None, None], dtype=torch.float)
traced_module = annotate_forward_args(traced_module, [arg_type, arg_type])
print(traced_module.graph)
mlir_module = build_module(traced_module)
print("\n\nTORCH MLIR")
mlir_module.dump()
print(mlir_module.operation.verify())
with mlir_module.context:
pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline')
pm.run(mlir_module)
print("\n\nLOWERED MLIR")
mlir_module.dump()
backend = RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(mlir_module)
jit_module = backend.load(compiled)
print("\n\nRunning Forward Function")
np_t = np.random.rand(2, 2).astype(dtype=np.float32)
t = torch.tensor(np_t, dtype=torch.float)
print("Compiled result:\n", jit_module.forward(np_t, np_t))
print("\nExpected result:\n", module.forward(t, t))

View File

@ -0,0 +1,58 @@
# -*- Python -*-
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
from typing import Iterable, Union
from torch.fx import GraphModule
from torch_mlir import ir
from torch_mlir.dialects import builtin
from .torch_mlir_types import TorchTensorType, PythonType
class Annotation:
def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
self.types = list(map(lambda t:
PythonType(t) if isinstance(t, type) else t,
types))
def __str__(self):
result = f'Annotation instance with {len(self.types)} types\n'
for e, type_ in enumerate(self.types):
result += f' Type of argument {e + 1}: {str(type_)}\n'
return result
def __iter__(self):
return iter(self.types)
class AnnotationConverter:
@staticmethod
def to_mlir_array_attr(annotation: Annotation,
context: ir.Context) -> ir.ArrayAttr:
dict_attrs = []
for type_ in annotation:
if not isinstance(type_, TorchTensorType):
dict_attrs.append(ir.DictAttr.get({}, context=context))
continue
ir_type = type_.to_mlir(context)
with context:
type_attr = ir.TypeAttr.get(ir_type)
dict_attr = ir.DictAttr.get({'torch.type_bound': type_attr})
dict_attrs.append(dict_attr)
return ir.ArrayAttr.get(dict_attrs, context=context)
def annotate_forward_args(module: GraphModule,
types: Iterable[Union[TorchTensorType, type]]
) -> GraphModule:
operands = filter(lambda node: node.op == 'placeholder', module.graph.nodes)
for operand, type_ in zip(operands, types):
if isinstance(type_, type):
type_ = PythonType(type_)
operand.update_kwarg('torch_mlir_type', type_)
return module

View File

@ -1,8 +1,9 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
""" """
-*- Python -*-
This file is licensed under a pytorch-style license
See frontends/pytorch/LICENSE for license information.
The following defines a set of classes for converting The following defines a set of classes for converting
types used by Python and PyTorch into MLIR types from the types used by Python and PyTorch into MLIR types from the
`torch` dialect. `torch` dialect.
@ -16,8 +17,6 @@ Information about what types are supported by each class
can be found in docstrings of each of the classes. can be found in docstrings of each of the classes.
""" """
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
import abc import abc
from typing import Any, Optional, Iterable from typing import Any, Optional, Iterable
@ -62,6 +61,9 @@ class TorchTensorType(TorchMlirType):
err = "If shape is specified, dtype must also be specified" err = "If shape is specified, dtype must also be specified"
raise TorchTensorTypeError(err) raise TorchTensorTypeError(err)
def __str__(self):
return f'Torch Tensor (shape={self.shape}, dtype={self.dtype})'
def to_mlir(self, context: ir.Context) -> ir.Type: def to_mlir(self, context: ir.Context) -> ir.Type:
if self.dtype is None: if self.dtype is None:
return ir.Type.parse('!torch.tensor', context=context) return ir.Type.parse('!torch.tensor', context=context)
@ -90,6 +92,9 @@ class TorchNnModuleType(TorchMlirType):
def __init__(self, module_name: str): def __init__(self, module_name: str):
self.module_name = module_name self.module_name = module_name
def __str__(self):
return "torch.nn.Module"
def to_mlir(self, context: ir.Context) -> ir.Type: def to_mlir(self, context: ir.Context) -> ir.Type:
return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">', return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">',
context=context) context=context)
@ -111,6 +116,9 @@ class PythonType(TorchMlirType):
def __init__(self, type_: Any): def __init__(self, type_: Any):
self.type_ = type_ self.type_ = type_
def __str__(self):
return str(self.type_)
def to_mlir(self, context: ir.Context) -> ir.Type: def to_mlir(self, context: ir.Context) -> ir.Type:
asm = self._type_to_asm_dict.get(self.type_) asm = self._type_to_asm_dict.get(self.type_)
if asm is None: if asm is None:

View File

@ -18,10 +18,17 @@ namespace mlir {
namespace torch { namespace torch {
namespace TorchConversion { namespace TorchConversion {
/// Creates a pipeline that lowers the object graph IR that is produced by /// Creates a pipeline that lowers the object graph IR that is given by a
/// TorchScript import into the form expected by /// TorchScript jit.ScriptModule into the form expected by
/// torch-verify-linalg-on-tensors-verify-backend-contract. /// torch-verify-linalg-on-tensors-verify-backend-contract.
void createTorchScriptToLinalgOnTensorsBackendPipeline( void createTorchScriptModuleToLinalgOnTensorsBackendPipeline(
OpPassManager &pm,
const torch::Torch::TorchLoweringPipelineOptions &options);
/// Creates a pipeline that lowers the object graph IR that is given by a
/// TorchScript jit.ScriptFunction into the form expected by
/// torch-verify-linalg-on-tensors-verify-backend-contract.
void createTorchScriptFunctionToLinalgOnTensorsBackendPipeline(
OpPassManager &pm, OpPassManager &pm,
const torch::Torch::TorchLoweringPipelineOptions &options); const torch::Torch::TorchLoweringPipelineOptions &options);

View File

@ -56,8 +56,6 @@ void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline(
// calls, so inline everything. // calls, so inline everything.
// TODO: Improve shape inference. // TODO: Improve shape inference.
pm.addPass(createInlinerPass()); pm.addPass(createInlinerPass());
// Incorporate user annotations and remove signature Python-isms.
pm.addPass(createAdjustCallingConventionsPass());
createGlobalizedModuleToTorchBackendPipeline(pm, options); createGlobalizedModuleToTorchBackendPipeline(pm, options);
} }
@ -86,6 +84,9 @@ void mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline(
// Please try to keep this list somewhat up to date when adding // Please try to keep this list somewhat up to date when adding
// "optimize hard enough that it works" transformations. // "optimize hard enough that it works" transformations.
// Incorporate user annotations and remove signature Python-isms.
pm.addPass(createAdjustCallingConventionsPass());
if (options.optimize) { if (options.optimize) {
// Inline global slots, which for most inference scenarios deletes them. // Inline global slots, which for most inference scenarios deletes them.
// This also exposes more information to intraprocedural transformations // This also exposes more information to intraprocedural transformations

View File

@ -33,18 +33,17 @@ namespace {
void mlir::torch::registerTorchConversionPasses() { void mlir::torch::registerTorchConversionPasses() {
::registerPasses(); ::registerPasses();
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>( mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-to-linalg-on-tensors-backend-pipeline", "torchscript-module-to-linalg-on-tensors-backend-pipeline",
"Pipeline lowering torch object graph to linalg-on-tensors backend format.", "Pipeline lowering torch object graph representing a torch.jit.ScriptModule to linalg-on-tensors backend format.",
mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipeline); TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-function-to-linalg-on-tensors-backend-pipeline",
"Pipeline lowering a flat list of functions representing a torch.jit.ScriptFunction to linalg-on-tensors backend format.",
TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline);
} }
void mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipeline( static void createTorchBackendToLinalgOnTensorsBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) { OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
// Conversion to the linalg-on-tensors backend contract starts from the Torch
// backend contract.
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
// Check some invariants to catch errors in a clear way. // Check some invariants to catch errors in a clear way.
pm.addPass( pm.addPass(
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()); TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
@ -79,3 +78,21 @@ void mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipe
// correct form. // correct form.
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()); pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
} }
void TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
// Conversion to the linalg-on-tensors backend contract starts from the Torch
// backend contract.
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options);
}
void TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
// Conversion to the linalg-on-tensors backend contract starts from the Torch
// backend contract.
Torch::createGlobalizedModuleToTorchBackendPipeline(pm, options);
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options);
}

View File

@ -91,7 +91,7 @@ Diagnostics:
sys.stderr = StringIO() sys.stderr = StringIO()
asm_for_error_report = mb.module.operation.get_asm( asm_for_error_report = mb.module.operation.get_asm(
large_elements_limit=10, enable_debug_info=True) large_elements_limit=10, enable_debug_info=True)
pipeline_str = "torchscript-to-linalg-on-tensors-backend-pipeline" pipeline_str = "torchscript-module-to-linalg-on-tensors-backend-pipeline"
# Lower module in place to make it ready for compiler backends. # Lower module in place to make it ready for compiler backends.
with mb.module.context: with mb.module.context:
pm = PassManager.parse(pipeline_str) pm = PassManager.parse(pipeline_str)