Implement the lazytensor package (#331)

Implement the `lazytensor` python package for converting
lazy computations captured by the Lazy Tensor Core into MLIR.
This PR also fixes a few things with `torchfx` and its example.
Ramiro Leal-Cavazos 2021-09-28 19:25:06 -05:00 committed by GitHub
parent 2b99c8b990
commit b59f2cb673
14 changed files with 366 additions and 113 deletions


@ -104,13 +104,40 @@ jupyter notebook
### TorchFX
TODO
The `examples` folder includes the Python package `torchfx`, which is a functional prototype of a TorchFX to MLIR pipeline. The main entry point into the `torchfx` package is the `torchfx.builder` module, which includes a function for converting the output of a TorchFX trace into MLIR. Currently, only a very limited set of PyTorch operations is supported, but coverage will be expanded in the future.
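A minimal sketch of the intended flow (`MyModule` is a placeholder for your own `torch.nn.Module`; the `torchfx` and `utils` packages live in the `examples` tree):
```
import torch
from torch.fx.experimental.fx_acc import acc_tracer

from torchfx.builder import build_module
from utils.annotator import annotate_forward_args
from utils.torch_mlir_types import TorchTensorType

class MyModule(torch.nn.Module):
    def forward(self, x, y):
        return torch.tanh(x)

# Trace the module, annotate its placeholder arguments with torch-mlir
# types, then convert the traced graph into an MLIR module.
traced = acc_tracer.trace(MyModule(), [torch.Tensor(2, 2), torch.Tensor(2, 2)])
arg_type = TorchTensorType(shape=[None, None], dtype=torch.float)
traced = annotate_forward_args(traced, [arg_type, arg_type])
mlir_module = build_module(traced)  # an ir.Module using the `torch` dialect
```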
#### Example usage of `torchfx`
The `examples` folder includes scripts `torchfx_*.py` showing how to use the TorchFX to MLIR pipeline. In order to run the examples, make sure you've set up your `PYTHONPATH` by following [these](#setup-env) instructions, and add `/path/to/torch-mlir/examples` to your `PYTHONPATH`.
Then, run
```
python torchfx_example_name.py
```
replacing `torchfx_example_name.py` with the actual `torchfx` example you want to run.
### Lazy Tensor Core
TODO
The `examples` folder includes the Python package `lazytensor`, which implements a Lazy Tensor Core (LTC) to MLIR pipeline. The main entry point into the `lazytensor` package is the `lazytensor.builder` module, which includes the function `build_module`. This function takes a computation captured and converted to TorchScript IR by LTC and translates it into MLIR.
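A minimal sketch of the intended use (assuming `script_function` is a `torch.jit.ScriptFunction` recovered from an LTC computation, as in the example scripts):
```
import torch

from lazytensor.builder import build_module
from utils.annotator import Annotation
from utils.torch_mlir_types import TorchTensorType

# Describe the operand types, then translate the TorchScript function
# into an MLIR module using the `torch` dialect.
annotation = Annotation([TorchTensorType(shape=(2, 3), dtype=torch.float)])
mlir_module = build_module(script_function, annotation)
```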
#### Example usage of `lazytensor`
The `examples` folder includes scripts `lazytensor_*.py` showing how to use the Lazy Tensor to MLIR pipeline. The examples depend on PyTorch's Lazy Tensor Core (LTC). For information on how to obtain LTC, see [here](https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/QUICKSTART.md).
In order to run the examples, make sure you've set up your `PYTHONPATH` by following [these](#setup-env) instructions, and also add the following to your `PYTHONPATH`:
1. `/path/to/torch-mlir/examples`
2. `/path/to/pytorch/lazy_tensor_core`
Then, run
```
python lazytensor_example_name.py
```
replacing `lazytensor_example_name.py` with the actual `lazytensor` example you want to run.
## Repository Layout


@ -0,0 +1,64 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Translator from torch.jit.ScriptFunction to MLIR.
The following defines a function that takes a torch.jit.ScriptFunction
and converts it into an MLIR module.
The expected use for this module is to call the function
`build_module(jit_function: torch.jit.ScriptFunction,
annotation: Annotation) -> ir.Module`
to convert the TorchScript function into MLIR using the `torch`
dialect.
"""
from typing import Optional
from torch.jit import ScriptFunction
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
from torch_mlir.dialects.builtin import FuncOp
from torch_mlir import ir
from utils.annotator import AnnotationConverter as ac
from utils.annotator import Annotation
def _get_func_op_with_name(module: ir.Module, name: str) -> Optional[FuncOp]:
with module.context:
name_attr = ir.StringAttr.get(name)
for op in module.body.operations:
if isinstance(op, FuncOp) and op.name == name_attr:
return op
return None
def build_module(jit_function: ScriptFunction,
annotation: Annotation) -> ir.Module:
"""
Translate input function into an MLIR module in the `torch` dialect.
Parameters
----------
jit_function: ScriptFunction
Function in TorchScript IR to turn into MLIR.
annotation: Annotation
Annotation object representing the types of
the operands of `jit_function`.
Returns
-------
ir.Module
Translation of the input function into an MLIR module.
"""
mb = ModuleBuilder()
mb.import_function(jit_function)
func_op = _get_func_op_with_name(mb.module, jit_function.name)
assert func_op is not None, 'Unable to find FuncOp in new module. Make sure function was imported correctly into ModuleBuilder'
arg_attrs = ac.to_mlir_array_attr(annotation, mb.context)
func_op.attributes['arg_attrs'] = arg_attrs
return mb.module
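# Hypothetical usage sketch (illustrative only): given a ScriptFunction
# `script_function` and a matching `Annotation`, the annotated signature
# can be inspected on the returned module:
#
#   module = build_module(script_function, annotation)
#   func_op = _get_func_op_with_name(module, script_function.name)
#   print(func_op.attributes['arg_attrs'])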


@ -0,0 +1,81 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Example of taking a Lazy Tensor computation and compiling it using torch-mlir.
This example depends on PyTorch's Lazy Tensor Core (LTC). For information
on how to obtain LTC, see here:
https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/QUICKSTART.md
To run the example, make sure the following are in your PYTHONPATH:
1. /path/to/torch-mlir/examples
2. /path/to/pytorch/lazy_tensor_core
3. /path/to/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir
then, simply call `python lazytensor_tanh.py`.
"""
import numpy as np
import torch
import lazy_tensor_core as ltc
from torch._C import CompilationUnit
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend \
import RefBackendLinalgOnTensorsBackend
from torch_mlir.passmanager import PassManager
from utils.annotator import Annotation
from utils.torch_mlir_types import TorchTensorType
from lazytensor.builder import build_module
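# Initialize the Lazy Tensor Core TorchScript backend.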
ltc._LAZYC._ltc_init_ts_backend()
device = 'lazy'
dtype = torch.float32
shape = (2, 3)
x = torch.randn(shape, device=device, dtype=dtype)
y = torch.randn(shape, device=device, dtype=dtype)
def computation(x, y):
return y * x.tanh()
# Capture lazy computation and convert to TorchScript IR
graph_str = ltc._LAZYC._get_ltc_tensors_backend([computation(x, y)])
print("LAZY GRAPH")
print(graph_str)
graph = torch._C.parse_ir(graph_str)
# Create a torch.jit.ScriptFunction out of the graph
cu = CompilationUnit()
func_name = 'my_method'
script_function = cu.create_function(func_name, graph)
# `build_module` takes the torch.jit.ScriptFunction and the
# annotation on the operand types, and outputs an `ir.Module`
# with a single function representing the ScriptFunction in
# the torch MLIR dialect
func_annotation = Annotation([TorchTensorType(shape=shape, dtype=torch.float),
TorchTensorType(shape=shape, dtype=torch.float)])
mlir_module = build_module(script_function, func_annotation)
print("MLIR")
mlir_module.dump()
# Compile the torch MLIR and execute the compiled program
with mlir_module.context:
pm = PassManager.parse('torchscript-function-to-linalg-on-tensors-backend-pipeline')
pm.run(mlir_module)
print("BEFORE LINALG-ON-TENSORS BACKEND PIPELINE")
print(mlir_module)
backend = RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(mlir_module)
jit_module = backend.load(compiled)
print("\n\nRunning Example Calculation")
print("Compiled result:")
print(jit_module.my_method(x.cpu().numpy(), y.cpu().numpy()))
print("Expected result:")
print(computation(x, y))


@ -1,20 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
#
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
from typing import Iterable, Union
from torch.fx import GraphModule
from .torch_mlir_types import TorchTensorType, PythonType
def annotate_forward_args(module: GraphModule,
types: Iterable[Union[TorchTensorType, type]]
) -> GraphModule:
operands = filter(lambda node: node.op == 'placeholder', module.graph.nodes)
for operand, type_ in zip(operands, types):
if isinstance(type_, type):
type_ = PythonType(type_)
operand.update_kwarg('torch_mlir_type', type_)
return module


@ -1,3 +1,6 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Translator from Torch.FX to MLIR.
@ -9,9 +12,6 @@ The expected use for this module is to use the function
`build_module(py_module: torch.fx.GraphModule) -> ir.Module`
to convert the output from the tracer into MLIR using the `torch`
dialect.
This file is licensed under a pytorch-style license
See frontends/pytorch/LICENSE for license information.
"""
# pylint: disable=no-member, no-name-in-module, invalid-name, fixme
@ -27,7 +27,8 @@ from torch_mlir.dialects import builtin, std
import torch.fx
from torch.fx.experimental.fx_acc import acc_ops
from .torch_mlir_types import TorchTensorType, PythonType, TorchNnModuleType
from utils.torch_mlir_types import TorchTensorType, PythonType, \
TorchNnModuleType
Environment = MutableMapping[torch.fx.Node, ir.Value]
@ -388,7 +389,7 @@ an argument named `input`'
@_add_handler(ACC_OP_HANDLERS, acc_ops.add)
def _add_handler(func_builder: _ForwardFunctionBuilder,
def _add_tensor_handler(func_builder: _ForwardFunctionBuilder,
args: Mapping[str, ir.Value]) -> ir.OpResult:
input_arg = args.get('input')
other_arg = args.get('other')
@ -396,10 +397,10 @@ def _add_handler(func_builder: _ForwardFunctionBuilder,
'A call to this handler must include an argument named `input` \
and an argument named `other`'
tensor_type = TorchTensorType().to_mlir(func_builder.context)
int_type = PythonType(int).to_mlir(func_builder.context)
attr_type = ir.Type.parse('i64', func_builder.context)
int_attr = ir.IntegerAttr.get(attr_type, 1)
alpha_arg = torch_d.ConstantIntOp(int_type,
torch_int_type = PythonType(int).to_mlir(func_builder.context)
int_type = ir.Type.parse("i64", context=func_builder.context)
int_attr = ir.IntegerAttr.get(int_type, 1)
alpha_arg = torch_d.ConstantIntOp(torch_int_type,
int_attr,
loc=func_builder.loc,
ip=func_builder.func_ip).result


@ -1,59 +0,0 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
# From the torch-mlir root, run with:
# `python -m examples.torchfx.examples.example_add_tanh_sigmoid`
# (after setting up python environment with write_env_file.sh)
import torch
from torch.fx.experimental.fx_acc import acc_tracer
import torch_mlir
from torch_mlir.dialects.torch import register_dialect
from torch_mlir.passmanager import PassManager
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from ..builder import build_module
from ..annotator import annotate_forward_args
from ..torch_mlir_types import TorchTensorType
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
# TODO: Debug issue with RefBackend
#return torch.tanh(x) + torch.sigmoid(y)
return torch.tanh(x)
module = MyModule()
traced_module = acc_tracer.trace(module, [torch.Tensor(2,2),
torch.Tensor(2,2)])
print("TRACE")
arg_type = TorchTensorType(shape=[None, None], dtype=torch.float)
traced_module = annotate_forward_args(traced_module, [arg_type, arg_type])
print(traced_module.graph)
torch_mlir_module = build_module(traced_module)
print("\n\nTORCH MLIR")
torch_mlir_module.dump()
print(torch_mlir_module.operation.verify())
with torch_mlir_module.context:
pm = PassManager.parse('torchscript-to-linalg-on-tensors-backend-pipeline')
pm.run(torch_mlir_module)
print("\n\nLOWERED MLIR")
torch_mlir_module.dump()
backend = RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(torch_mlir_module)
jit_module = backend.load(compiled)
print("\n\nRunning Forward Function")
t = torch.rand((2, 2), dtype=torch.float)
print("Compiled result:\n", jit_module.forward(t.numpy(), t.numpy()))
print("\nExpected result:\n", module.forward(t, t))


@ -1,6 +1,7 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme


@ -0,0 +1,67 @@
# -*- Python -*-
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Example of taking a module traced by TorchFX and compiling it using torch-mlir.
To run the example, make sure the following are in your PYTHONPATH:
1. /path/to/torch-mlir/examples
2. /path/to/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir
then, simply call `python torchfx_add_tanh_sigmoid.py`.
"""
import torch
import numpy as np
from torch.fx.experimental.fx_acc import acc_tracer
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend \
import RefBackendLinalgOnTensorsBackend
from torch_mlir.passmanager import PassManager
from torchfx.builder import build_module
from utils.annotator import annotate_forward_args
from utils.torch_mlir_types import TorchTensorType
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
# TODO: Debug issue with RefBackend
#return torch.tanh(x) + torch.sigmoid(y)
return torch.tanh(x)
module = MyModule()
traced_module = acc_tracer.trace(module, [torch.Tensor(2,2),
torch.Tensor(2,2)])
print("TRACE")
arg_type = TorchTensorType(shape=[None, None], dtype=torch.float)
traced_module = annotate_forward_args(traced_module, [arg_type, arg_type])
print(traced_module.graph)
mlir_module = build_module(traced_module)
print("\n\nTORCH MLIR")
mlir_module.dump()
print(mlir_module.operation.verify())
with mlir_module.context:
pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline')
pm.run(mlir_module)
print("\n\nLOWERED MLIR")
mlir_module.dump()
backend = RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(mlir_module)
jit_module = backend.load(compiled)
print("\n\nRunning Forward Function")
np_t = np.random.rand(2, 2).astype(dtype=np.float32)
t = torch.tensor(np_t, dtype=torch.float)
print("Compiled result:\n", jit_module.forward(np_t, np_t))
print("\nExpected result:\n", module.forward(t, t))


@ -0,0 +1,58 @@
# -*- Python -*-
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
from typing import Iterable, Union
from torch.fx import GraphModule
from torch_mlir import ir
from torch_mlir.dialects import builtin
from .torch_mlir_types import TorchTensorType, PythonType
class Annotation:
def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
self.types = list(map(lambda t:
PythonType(t) if isinstance(t, type) else t,
types))
def __str__(self):
result = f'Annotation instance with {len(self.types)} types\n'
for e, type_ in enumerate(self.types):
result += f' Type of argument {e + 1}: {str(type_)}\n'
return result
def __iter__(self):
return iter(self.types)
class AnnotationConverter:
@staticmethod
def to_mlir_array_attr(annotation: Annotation,
context: ir.Context) -> ir.ArrayAttr:
dict_attrs = []
for type_ in annotation:
if not isinstance(type_, TorchTensorType):
dict_attrs.append(ir.DictAttr.get({}, context=context))
continue
ir_type = type_.to_mlir(context)
with context:
type_attr = ir.TypeAttr.get(ir_type)
dict_attr = ir.DictAttr.get({'torch.type_bound': type_attr})
dict_attrs.append(dict_attr)
return ir.ArrayAttr.get(dict_attrs, context=context)
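# Hypothetical usage sketch of AnnotationConverter (illustrative only):
# the context passed in must have the `torch` dialect registered, e.g.
# the context owned by a jit_ir ModuleBuilder, as in lazytensor.builder:
#
#   import torch
#   from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
#   mb = ModuleBuilder()
#   annotation = Annotation([TorchTensorType(shape=(2, 3), dtype=torch.float)])
#   arr_attr = AnnotationConverter.to_mlir_array_attr(annotation, mb.context)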
def annotate_forward_args(module: GraphModule,
types: Iterable[Union[TorchTensorType, type]]
) -> GraphModule:
operands = filter(lambda node: node.op == 'placeholder', module.graph.nodes)
for operand, type_ in zip(operands, types):
if isinstance(type_, type):
type_ = PythonType(type_)
operand.update_kwarg('torch_mlir_type', type_)
return module


@ -1,8 +1,9 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
"""
-*- Python -*-
This file is licensed under a pytorch-style license
See frontends/pytorch/LICENSE for license information.
The following defines a set of classes for converting
types used by Python and PyTorch into MLIR types from the
`torch` dialect.
@ -16,8 +17,6 @@ Information about what types are supported by each class
can be found in docstrings of each of the classes.
"""
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
import abc
from typing import Any, Optional, Iterable
@ -62,6 +61,9 @@ class TorchTensorType(TorchMlirType):
err = "If shape is specified, dtype must also be specified"
raise TorchTensorTypeError(err)
def __str__(self):
return f'Torch Tensor (shape={self.shape}, dtype={self.dtype})'
def to_mlir(self, context: ir.Context) -> ir.Type:
if self.dtype is None:
return ir.Type.parse('!torch.tensor', context=context)
@ -90,6 +92,9 @@ class TorchNnModuleType(TorchMlirType):
def __init__(self, module_name: str):
self.module_name = module_name
def __str__(self):
return "torch.nn.Module"
def to_mlir(self, context: ir.Context) -> ir.Type:
return ir.Type.parse(f'!torch.nn.Module<"{self.module_name}">',
context=context)
@ -111,6 +116,9 @@ class PythonType(TorchMlirType):
def __init__(self, type_: Any):
self.type_ = type_
def __str__(self):
return str(self.type_)
def to_mlir(self, context: ir.Context) -> ir.Type:
asm = self._type_to_asm_dict.get(self.type_)
if asm is None:


@ -18,10 +18,17 @@ namespace mlir {
namespace torch {
namespace TorchConversion {
/// Creates a pipeline that lowers the object graph IR that is produced by
/// TorchScript import into the form expected by
/// Creates a pipeline that lowers the object graph IR that is given by a
/// TorchScript jit.ScriptModule into the form expected by
/// torch-verify-linalg-on-tensors-backend-contract.
void createTorchScriptToLinalgOnTensorsBackendPipeline(
void createTorchScriptModuleToLinalgOnTensorsBackendPipeline(
OpPassManager &pm,
const torch::Torch::TorchLoweringPipelineOptions &options);
/// Creates a pipeline that lowers the object graph IR that is given by a
/// TorchScript jit.ScriptFunction into the form expected by
/// torch-verify-linalg-on-tensors-backend-contract.
void createTorchScriptFunctionToLinalgOnTensorsBackendPipeline(
OpPassManager &pm,
const torch::Torch::TorchLoweringPipelineOptions &options);


@ -56,8 +56,6 @@ void mlir::torch::Torch::createTorchScriptToTorchBackendPipeline(
// calls, so inline everything.
// TODO: Improve shape inference.
pm.addPass(createInlinerPass());
// Incorporate user annotations and remove signature Python-isms.
pm.addPass(createAdjustCallingConventionsPass());
createGlobalizedModuleToTorchBackendPipeline(pm, options);
}
@ -86,6 +84,9 @@ void mlir::torch::Torch::createGlobalizedModuleToTorchBackendPipeline(
// Please try to keep this list somewhat up to date when adding
// "optimize hard enough that it works" transformations.
// Incorporate user annotations and remove signature Python-isms.
pm.addPass(createAdjustCallingConventionsPass());
if (options.optimize) {
// Inline global slots, which for most inference scenarios deletes them.
// This also exposes more information to intraprocedural transformations


@ -33,18 +33,17 @@ namespace {
void mlir::torch::registerTorchConversionPasses() {
::registerPasses();
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-to-linalg-on-tensors-backend-pipeline",
"Pipeline lowering torch object graph to linalg-on-tensors backend format.",
mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipeline);
"torchscript-module-to-linalg-on-tensors-backend-pipeline",
"Pipeline lowering torch object graph representing a torch.jit.ScriptModule to linalg-on-tensors backend format.",
TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-function-to-linalg-on-tensors-backend-pipeline",
"Pipeline lowering a flat list of functions representing a torch.jit.ScriptFunction to linalg-on-tensors backend format.",
TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline);
}
void mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipeline(
static void createTorchBackendToLinalgOnTensorsBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
// Conversion to the linalg-on-tensors backend contract starts from the Torch
// backend contract.
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
// Check some invariants to catch errors in a clear way.
pm.addPass(
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
@ -79,3 +78,21 @@ void mlir::torch::TorchConversion::createTorchScriptToLinalgOnTensorsBackendPipe
// correct form.
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
}
void TorchConversion::createTorchScriptModuleToLinalgOnTensorsBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
// Conversion to the linalg-on-tensors backend contract starts from the Torch
// backend contract.
Torch::createTorchScriptToTorchBackendPipeline(pm, options);
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options);
}
void TorchConversion::createTorchScriptFunctionToLinalgOnTensorsBackendPipeline(
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
// Conversion to the linalg-on-tensors backend contract starts from the Torch
// backend contract.
Torch::createGlobalizedModuleToTorchBackendPipeline(pm, options);
createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options);
}


@ -91,7 +91,7 @@ Diagnostics:
sys.stderr = StringIO()
asm_for_error_report = mb.module.operation.get_asm(
large_elements_limit=10, enable_debug_info=True)
pipeline_str = "torchscript-to-linalg-on-tensors-backend-pipeline"
pipeline_str = "torchscript-module-to-linalg-on-tensors-backend-pipeline"
# Lower module in place to make it ready for compiler backends.
with mb.module.context:
pm = PassManager.parse(pipeline_str)