mirror of https://github.com/llvm/torch-mlir
82 lines
2.7 KiB
Python
82 lines
2.7 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
"""
Example of taking a Lazy Tensor computation and compiling it using torch-mlir.

This example depends on the Lazy Tensor Core (LTC) of PyTorch. For information
on how to obtain LTC, see here:
https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/QUICKSTART.md

To run the example, make sure the following are in your PYTHONPATH:
    1. /path/to/torch-mlir/examples
    2. /path/to/pytorch/lazy_tensor_core
    3. /path/to/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir

Then simply run `python lazytensor_tanh.py`.
"""
|
|
|
|
import numpy as np
|
|
import torch
|
|
import lazy_tensor_core as ltc
|
|
from torch._C import CompilationUnit
|
|
|
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend \
|
|
import RefBackendLinalgOnTensorsBackend
|
|
from torch_mlir.passmanager import PassManager
|
|
|
|
from utils.annotator import Annotation
|
|
from utils.torch_mlir_types import TorchTensorType
|
|
from lazytensor.builder import build_module
|
|
|
|
# Initialise the LTC backend (presumably the TorchScript one, going by the
# function name) before any tensor is created on the 'lazy' device below.
ltc._LAZYC._ltc_init_ts_backend()

# Shared configuration for the example operands; `shape` is reused later when
# annotating the operand types of the compiled function.
device = 'lazy'
dtype = torch.float32
shape = (2, 3)

# Random example operands living on the lazy device. `x` is drawn before `y`,
# so the order of these two lines determines which random values each gets.
x = torch.randn(shape, device=device, dtype=dtype)
y = torch.randn(shape, device=device, dtype=dtype)
|
|
|
|
def computation(x, y):
    """Return ``y * tanh(x)``.

    This is the computation the example captures as a lazy graph and then
    compiles; it is also called directly to print the expected result.

    Args:
        x: Tensor whose ``tanh`` is taken (via its ``.tanh()`` method).
        y: Tensor multiplied with ``x.tanh()``.

    Returns:
        The product ``y * x.tanh()``.
    """
    # Keep `y` as the left operand, exactly as in the original expression.
    return y * x.tanh()
|
|
|
|
# Capture lazy computation and convert to TorchScript IR
graph_str = ltc._LAZYC._get_ltc_tensors_backend([computation(x, y)])
print("LAZY GRAPH")
print(graph_str)
graph = torch._C.parse_ir(graph_str)

# Create a torch.jit.ScriptFunction out of the graph
cu = CompilationUnit()
func_name = 'my_method'
script_function = cu.create_function(func_name, graph)

# `build_module` takes the torch.jit.ScriptFunction and the
# annotation on the operand types, and outputs an `ir.Module`
# with a single function representing the ScriptFunction in
# the torch MLIR dialect.
# The operand annotations reuse the module-level `shape` and `dtype` so they
# cannot drift out of sync with the tensors `x` and `y` created above.
func_annotation = Annotation([TorchTensorType(shape=shape, dtype=dtype),
                              TorchTensorType(shape=shape, dtype=dtype)])
mlir_module = build_module(script_function, func_annotation)

print("MLIR")
mlir_module.dump()

# Compile the torch MLIR and execute the compiled program
# The pass manager is parsed and run while the module's MLIR context is active.
with mlir_module.context:
    pm = PassManager.parse('torchscript-function-to-linalg-on-tensors-backend-pipeline')
    pm.run(mlir_module)

# Printed after the lowering pipeline above and before the module is handed
# to the reference backend.
print("BEFORE LINALG-ON-TENSORS BACKEND PIPELINE")
print(mlir_module)

backend = RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(mlir_module)
jit_module = backend.load(compiled)

print("\n\nRunning Example Calculation")
print("Compiled result:")
# Look the entry point up by `func_name` (the name it was created under)
# instead of repeating the literal; the operands are passed as NumPy arrays.
print(getattr(jit_module, func_name)(x.cpu().numpy(), y.cpu().numpy()))
print("Expected result:")
print(computation(x, y))
|