The Torch-MLIR Project

The Torch-MLIR project aims to provide first-class compiler support from the PyTorch ecosystem to the MLIR ecosystem.

This project is participating in the LLVM Incubator process: as such, it is not part of any official LLVM release. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project is not yet endorsed as a component of LLVM.

PyTorch: An open-source machine learning framework that accelerates the path from research prototyping to production deployment.

MLIR: The MLIR project is a novel approach to building reusable and extensible compiler infrastructure. MLIR aims to address software fragmentation, improve compilation for heterogeneous hardware, significantly reduce the cost of building domain-specific compilers, and aid in connecting existing compilers together.

Torch-MLIR: Multiple vendors use MLIR as the middle layer, mapping from platform frameworks such as PyTorch, JAX, and TensorFlow into MLIR and then progressively lowering to their target hardware. We have seen half a dozen custom lowerings from PyTorch to MLIR. Canonical lowerings from the PyTorch ecosystem to the MLIR ecosystem would let hardware vendors focus on their unique value rather than implementing yet another PyTorch frontend for MLIR. The goal mirrors how hardware vendors today add LLVM target support instead of each one also implementing Clang or another C++ frontend.

All the roads from PyTorch to Torch MLIR Dialect

We have a few paths to lower down to the Torch MLIR Dialect.

[Figure: Torch lowering architectures]

  • TorchScript: This is the most tested path down to the Torch MLIR Dialect, and the PyTorch ecosystem is converging on TorchScript IR as a lingua franca.
  • LazyTensorCore: Read more details here.

Project Communication

  • #torch-mlir channel on the LLVM Discord: this is the most active communication channel
  • GitHub issues here
  • torch-mlir section of LLVM Discourse
  • Weekly meetings on Mondays 9AM PST. See here for more information.
  • Weekly op office hours on Thursdays 8:30-9:30AM PST. See here for more information.

Install torch-mlir snapshot

This installs a pre-built snapshot of torch-mlir for Python 3.7/3.8/3.9/3.10 on Linux and macOS.

python -m venv mlir_venv
source mlir_venv/bin/activate
# Some older pip installs may not be able to handle the recent PyTorch deps
python -m pip install --upgrade pip
pip install --pre torch-mlir torchvision -f https://github.com/llvm/torch-mlir/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
# This will install the corresponding torch and torchvision nightlies
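The support matrix stated above (Python 3.7 to 3.10 on Linux and macOS) can be checked from Python before installing, and a quick import probe confirms the snapshot landed in the active environment afterwards. This is a minimal sketch: `snapshot_supported` and `torch_mlir_installed` are hypothetical helpers, not part of torch-mlir, and the hard-coded matrix reflects only what this README lists, so later snapshots may support other versions.

```python
import importlib.util
import platform
import sys

# Support matrix as listed in this README; later snapshots may differ.
SUPPORTED_PYTHON = {(3, 7), (3, 8), (3, 9), (3, 10)}
SUPPORTED_SYSTEMS = {"Linux", "Darwin"}  # platform.system() reports macOS as "Darwin"


def snapshot_supported(version=None, system=None):
    """Return True if the (major, minor) Python version and OS match the matrix above."""
    if version is None:
        version = sys.version_info[:2]
    if system is None:
        system = platform.system()
    return tuple(version[:2]) in SUPPORTED_PYTHON and system in SUPPORTED_SYSTEMS


def torch_mlir_installed():
    """Return True if a top-level `torch_mlir` package is findable on the import path."""
    return importlib.util.find_spec("torch_mlir") is not None


if __name__ == "__main__":
    print("snapshot supported here:", snapshot_supported())
    print("torch_mlir importable:", torch_mlir_installed())
```

Running the script inside `mlir_venv` after the `pip install` step should report both checks for the current interpreter.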

Demos

TorchScript ResNet18

Standalone script to convert a PyTorch ResNet18 model to MLIR and run it on the CPU backend:

# Get the latest example if you haven't checked out the code
# (-P saves it under examples/, matching the path used below)
wget -P examples https://raw.githubusercontent.com/llvm/torch-mlir/main/examples/torchscript_resnet18.py

# Run ResNet18 as a standalone script.
python examples/torchscript_resnet18.py

load image from https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/mlir/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100.0%
PyTorch prediction
[('Labrador retriever', 70.66319274902344), ('golden retriever', 4.956596374511719), ('Chesapeake Bay retriever', 4.195662975311279)]
torch-mlir prediction
[('Labrador retriever', 70.66320037841797), ('golden retriever', 4.956601619720459), ('Chesapeake Bay retriever', 4.195651531219482)]
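The two prediction lists agree on every label and differ only in the low-order digits of the scores, which is typical when two backends evaluate floating-point operations in a different order. The sketch below shows one way to check that agreement programmatically using the numbers printed above. The helper names and the relative tolerance of 1e-5 are assumptions chosen for illustration, not a threshold that torch-mlir itself uses.

```python
import math

# Top-3 (label, score) pairs copied from the output above.
pytorch_top3 = [
    ("Labrador retriever", 70.66319274902344),
    ("golden retriever", 4.956596374511719),
    ("Chesapeake Bay retriever", 4.195662975311279),
]
torch_mlir_top3 = [
    ("Labrador retriever", 70.66320037841797),
    ("golden retriever", 4.956601619720459),
    ("Chesapeake Bay retriever", 4.195651531219482),
]


def predictions_match(expected, actual, rel_tol=1e-5):
    """Labels must match exactly and in order; scores must agree within rel_tol."""
    if len(expected) != len(actual):
        return False
    for (exp_label, exp_score), (act_label, act_score) in zip(expected, actual):
        if exp_label != act_label:
            return False
        if not math.isclose(exp_score, act_score, rel_tol=rel_tol):
            return False
    return True


def max_abs_diff(expected, actual):
    """Largest absolute score difference between corresponding entries."""
    return max(abs(e[1] - a[1]) for e, a in zip(expected, actual))
```

Comparing with a relative tolerance rather than exact equality keeps the check meaningful for both the large top-1 score and the small runner-up scores.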

Lazy Tensor Core

View examples here.

Eager Mode

Eager mode with Torch-MLIR is a highly experimental eager-mode backend for PyTorch built on the torch-mlir framework. It works by compiling operator by operator as PyTorch eagerly executes the network. If anything in the torch-mlir compilation process fails (e.g., an unsupported operator), this mode falls back to conventional PyTorch. A simple example can be found at eager_mode.py, and a ResNet18 example at eager_mode_resnet18.py.
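The per-operator compile-with-fallback behavior described above can be sketched in plain Python. This is a conceptual illustration only, not torch-mlir's eager-mode implementation: `compile_op` stands in for the torch-mlir compilation step, `reference_ops` stands in for conventional PyTorch execution, and `CompilationError`, `EagerDispatcher`, and the per-operator cache are all hypothetical.

```python
class CompilationError(Exception):
    """Hypothetical error raised when an operator cannot be compiled."""


class EagerDispatcher:
    """Runs each operator through a compiler, falling back to a reference implementation."""

    def __init__(self, compile_op, reference_ops):
        self.compile_op = compile_op        # op name -> callable; may raise CompilationError
        self.reference_ops = reference_ops  # op name -> callable (stands in for PyTorch)
        self.cache = {}                     # op name -> compiled callable, or None on failure
        self.fallbacks = []                 # op names that were executed via the fallback

    def run(self, op_name, *args):
        if op_name not in self.cache:
            try:
                self.cache[op_name] = self.compile_op(op_name)
            except CompilationError:
                self.cache[op_name] = None  # remember the failure instead of retrying each call
        compiled = self.cache[op_name]
        if compiled is None:
            self.fallbacks.append(op_name)
            return self.reference_ops[op_name](*args)
        return compiled(*args)


# Toy setup: only "add" is "compilable"; "mul" must fall back.
COMPILABLE = {"add": lambda a, b: a + b}


def toy_compile(op_name):
    if op_name not in COMPILABLE:
        raise CompilationError(f"unsupported operator: {op_name}")
    return COMPILABLE[op_name]


reference = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}
dispatcher = EagerDispatcher(toy_compile, reference)
```

Caching the failed compilation is a design choice of this sketch so that an unsupported operator is not recompiled on every call; whether the real backend does the same is not stated here.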

Repository Layout

The project follows the conventions of typical MLIR-based projects:

  • include/torch-mlir and lib: structure for C++ MLIR compiler dialects/passes.
  • test: test code.
  • tools: torch-mlir-opt and similar tools.
  • python: top-level directory for Python code.

Developers

If you would like to develop and build torch-mlir from source, please see the Development Notes.