torch-mlir/examples/torchfx_add_tanh_sigmoid.py

68 lines
2.1 KiB
Python
Raw Normal View History

# -*- Python -*-
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Example of taking a moduled traced by TorchFX and compiling it using torch-mlir.
To run the example, make sure the following are in your PYTHONPATH:
1. /path/to/torch-mlir/examples
2. /path/to/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir
then, simply call `python torchfx_add_tanh_sigmoid.py`.
"""
import torch
import numpy as np
from torch.fx.experimental.fx_acc import acc_tracer
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend \
import RefBackendLinalgOnTensorsBackend
from torch_mlir.passmanager import PassManager
from torchfx.builder import build_module
from utils.annotator import annotate_forward_args
from utils.torch_mlir_types import TorchTensorType
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
# TODO: Debug issue with RefBackend
#return torch.tanh(x) + torch.sigmoid(y)
return torch.tanh(x)
module = MyModule()
traced_module = acc_tracer.trace(module, [torch.Tensor(2,2),
torch.Tensor(2,2)])
print("TRACE")
arg_type = TorchTensorType(shape=[None, None], dtype=torch.float)
traced_module = annotate_forward_args(traced_module, [arg_type, arg_type])
print(traced_module.graph)
mlir_module = build_module(traced_module)
print("\n\nTORCH MLIR")
mlir_module.dump()
print(mlir_module.operation.verify())
with mlir_module.context:
pm = PassManager.parse('torchscript-module-to-linalg-on-tensors-backend-pipeline')
pm.run(mlir_module)
print("\n\nLOWERED MLIR")
mlir_module.dump()
backend = RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(mlir_module)
jit_module = backend.load(compiled)
print("\n\nRunning Forward Function")
np_t = np.random.rand(2, 2).astype(dtype=np.float32)
t = torch.tensor(np_t, dtype=torch.float)
print("Compiled result:\n", jit_module.forward(np_t, np_t))
print("\nExpected result:\n", module.forward(t, t))