The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.
Latest commit 8383497704 by Ahmed S. Taei on 2022-03-26 09:12:27 -07:00: [NFC] Rename external -> externals (#699)
.github/workflows [NFC] Rename external -> externals (#699) 2022-03-26 09:12:27 -07:00
build_tools [NFC] Rename external -> externals (#699) 2022-03-26 09:12:27 -07:00
docs Introduce new shape library design. 2022-03-15 12:41:58 -07:00
e2e_testing/torchscript [tosa] Support for Aten[Unsqueeze|Contiguous|Dropout|Reshape|View] ops (#700) 2022-03-25 14:15:07 -07:00
examples Bump LLVM at 8361c5da30588d3d4a48eae648f53be1feb5cfad 2022-03-18 13:16:14 -04:00
externals [NFC] Rename external -> externals (#699) 2022-03-26 09:12:27 -07:00
include [LINALG] Add E2E support for `aten.zero_` op 2022-03-25 12:46:50 +05:30
lib [tosa] Support for Aten[Unsqueeze|Contiguous|Dropout|Reshape|View] ops (#700) 2022-03-25 14:15:07 -07:00
python Move e2e test definitions into the `torch_mlir_e2e_test` package 2022-03-25 13:56:41 -07:00
test [tosa] Support for Aten[Unsqueeze|Contiguous|Dropout|Reshape|View] ops (#700) 2022-03-25 14:15:07 -07:00
tools Bump LLVM at 8361c5da30588d3d4a48eae648f53be1feb5cfad 2022-03-18 13:16:14 -04:00
.clang-format Add stub numpy dialect. 2020-04-26 17:20:58 -07:00
.gitignore Add support for constant_pad_nd 2022-01-11 10:25:25 -05:00
.gitmodules [NFC] Rename external -> externals (#699) 2022-03-26 09:12:27 -07:00
.style.yapf Introduce a Target class and use it to define generic 32 and 64bit variants. 2020-06-13 14:43:10 -07:00
CMakeLists.txt [NFC] Rename external -> externals (#699) 2022-03-26 09:12:27 -07:00
LICENSE Dual license the torch-mlir project. 2021-10-01 10:46:08 -07:00
README.md [NFC] Rename external -> externals (#699) 2022-03-26 09:12:27 -07:00
Torch-MLIR.png Update Torch-MLIR architecture diagram 2022-03-22 11:51:52 -07:00
requirements.txt Set some wheel building optimization options. 2021-10-25 18:30:53 +00:00
setup.py Fix setup.py backwards compatibiity (#586) 2022-02-22 10:54:05 -05:00


torch-mlir

The Torch-MLIR project aims to provide first-class compiler support from the PyTorch ecosystem to the MLIR ecosystem.

This project is participating in the LLVM Incubator process: as such, it is not part of any official LLVM release. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project is not yet endorsed as a component of LLVM.

PyTorch: An open source machine learning framework that accelerates the path from research prototyping to production deployment.

MLIR: The MLIR project is a novel approach to building reusable and extensible compiler infrastructure. MLIR aims to address software fragmentation, improve compilation for heterogeneous hardware, significantly reduce the cost of building domain-specific compilers, and aid in connecting existing compilers together.

Torch-MLIR: Multiple vendors use MLIR as the middle layer, mapping from platform frameworks like PyTorch, JAX, and TensorFlow into MLIR and then progressively lowering down to their target hardware. We have seen half a dozen custom lowerings from PyTorch to MLIR. Having canonical lowerings from the PyTorch ecosystem to the MLIR ecosystem would free hardware vendors to focus on their unique value rather than implementing yet another PyTorch frontend for MLIR. This is analogous to how hardware vendors today add LLVM target support instead of each one also implementing the Clang/C++ frontend.

All the roads from PyTorch to Torch MLIR Dialect

We have a few paths to lower down to the Torch MLIR Dialect.

[Figure: Torch Lowering Architectures]

  • TorchScript: This is the most tested path down to the Torch MLIR Dialect.
  • TorchFX: This provides a path to lower from TorchFX down to MLIR. This is a functional prototype that we expect to mature as TorchFX matures.
  • Lazy Tensor Core (based on the lazy-tensor-core staging branch): This path provides the upcoming LTC path of capture. It is based on an unstable development branch, but it is the closest way for you to adapt any existing torch_xla derivatives.
  • “ACAP”: Deprecated torch_xla-based capture, mentioned here for completeness.

Project Communication

  • #torch-mlir channel on the LLVM Discord - this is the most active communication channel
  • GitHub issues on this repository
  • torch-mlir section of LLVM Discourse

Check out the code

git clone https://github.com/llvm/torch-mlir
cd torch-mlir
git submodule update --init
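The CMake configure step below points at LLVM sources under externals/llvm-project/llvm, which only exist once the submodule has been initialized. A minimal sketch for sanity-checking this before configuring, assuming you run it from the repository root; the helper name submodules_ready is hypothetical and not part of torch-mlir:

```python
import os
import sys

# LLVM source directory that the CMake configure step in this README points at.
LLVM_SOURCE_SUBDIR = os.path.join("externals", "llvm-project", "llvm")


def submodules_ready(repo_root="."):
    """Return True if the llvm-project submodule appears to be checked out.

    Hypothetical helper: an uninitialized submodule is left as an empty
    directory, so the LLVM source dir is either missing or empty.
    """
    llvm_dir = os.path.join(repo_root, LLVM_SOURCE_SUBDIR)
    return os.path.isdir(llvm_dir) and len(os.listdir(llvm_dir)) > 0


if __name__ == "__main__":
    if not submodules_ready():
        sys.exit("externals/llvm-project looks empty; run `git submodule update --init`")
    print("Submodules look initialized.")
```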

Set up your Python virtual environment and dependencies

python -m venv mlir_venv
source mlir_venv/bin/activate
# Some older pip installs may not be able to handle the recent PyTorch deps
python -m pip install --upgrade pip
# Install latest PyTorch nightlies and build requirements.
python -m pip install -r requirements.txt
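The build step below passes -DPython3_FIND_VIRTUALENV=ONLY, which tells CMake to look for the Python interpreter only in the active virtual environment, so it matters that the venv is activated in the shell you build from. A small sketch of checking this from Python (the function name in_virtualenv is just for illustration); it relies on the standard venv behavior where sys.prefix points at the environment while sys.base_prefix points at the base installation:

```python
import sys


def in_virtualenv():
    """Return True when running inside a venv created by `python -m venv`.

    Inside a venv, sys.prefix is the environment directory, while
    sys.base_prefix still points at the base interpreter installation.
    """
    return sys.prefix != getattr(sys, "base_prefix", sys.prefix)


if __name__ == "__main__":
    if in_virtualenv():
        print(f"Using virtual environment at {sys.prefix}")
    else:
        print("Not in a virtual environment; run `source mlir_venv/bin/activate` first.")
```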

Build

cmake -GNinja -Bbuild \
  -DCMAKE_C_COMPILER=clang \
  -DCMAKE_CXX_COMPILER=clang++ \
  -DPython3_FIND_VIRTUALENV=ONLY \
  -DLLVM_ENABLE_PROJECTS=mlir \
  -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
  -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=`pwd` \
  -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR=`pwd`/externals/llvm-external-projects/torch-mlir-dialects \
  -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
  -DLLVM_TARGETS_TO_BUILD=host \
  externals/llvm-project/llvm

# Additional quality of life CMake flags:
# Enable ccache:
#  -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache
# Enable LLD (links in seconds compared to minutes)
# -DCMAKE_EXE_LINKER_FLAGS_INIT="-fuse-ld=lld" -DCMAKE_MODULE_LINKER_FLAGS_INIT="-fuse-ld=lld" -DCMAKE_SHARED_LINKER_FLAGS_INIT="-fuse-ld=lld"
# Use --ld-path= instead of -fuse-ld=lld for clang > 13

# Build just torch-mlir (not all of LLVM)
cmake --build build --target tools/torch-mlir/all

# Run unit tests.
cmake --build build --target check-torch-mlir

# Build everything (including LLVM)
cmake --build build
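If you prefer to script these steps (for example in CI), the same flags can be passed to subprocess as an argument list. No shell is involved, so the semicolon-separated LLVM_EXTERNAL_PROJECTS value needs no quoting. A sketch assuming it is run from the repository root; configure_args is a hypothetical helper that simply mirrors the command above, with the optional quality-of-life flags left out:

```python
import os
import subprocess


def configure_args(repo_root, build_dir="build"):
    """Return the README's CMake configure command as an argv list."""
    root = os.path.abspath(repo_root)
    dialects_dir = os.path.join(
        root, "externals", "llvm-external-projects", "torch-mlir-dialects")
    return [
        "cmake", "-GNinja", f"-B{build_dir}",
        "-DCMAKE_C_COMPILER=clang",
        "-DCMAKE_CXX_COMPILER=clang++",
        "-DPython3_FIND_VIRTUALENV=ONLY",
        "-DLLVM_ENABLE_PROJECTS=mlir",
        # Passed without a shell, so the ';' list separator is safe as-is.
        "-DLLVM_EXTERNAL_PROJECTS=torch-mlir;torch-mlir-dialects",
        f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={root}",
        f"-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR={dialects_dir}",
        "-DMLIR_ENABLE_BINDINGS_PYTHON=ON",
        "-DLLVM_TARGETS_TO_BUILD=host",
        os.path.join(root, "externals", "llvm-project", "llvm"),
    ]


if __name__ == "__main__":
    subprocess.run(configure_args("."), check=True)
    # Build just torch-mlir (not all of LLVM), then run its unit tests.
    subprocess.run(["cmake", "--build", "build", "--target", "tools/torch-mlir/all"], check=True)
    subprocess.run(["cmake", "--build", "build", "--target", "check-torch-mlir"], check=True)
```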

Demos

Setup Python Environment

export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples
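The export above overwrites any existing PYTHONPATH. If you want to keep existing entries, a small sketch (the helper name torch_mlir_pythonpath is hypothetical) that computes the same two directories and appends whatever was already set:

```python
import os


def torch_mlir_pythonpath(repo_root, existing=None):
    """Return a PYTHONPATH value matching the README's export line.

    Unlike the shell one-liner, any existing entries are kept at the end
    instead of being overwritten.
    """
    root = os.path.abspath(repo_root)
    entries = [
        os.path.join(root, "build", "tools", "torch-mlir",
                     "python_packages", "torch_mlir"),
        os.path.join(root, "examples"),
    ]
    if existing:
        entries.append(existing)
    return os.pathsep.join(entries)


if __name__ == "__main__":
    # Usage from the repository root (the script name is up to you):
    #   export PYTHONPATH="$(python this_script.py)"
    print(torch_mlir_pythonpath(".", os.environ.get("PYTHONPATH")))
```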

TorchScript

Running execution (end-to-end) tests:

# Run E2E TorchScript tests. These compile and run the TorchScript program
# through torch-mlir with a simplified MLIR CPU backend we call RefBackend
python -m e2e_testing.torchscript.main --filter Conv2d --verbose

[Figure: Example IR for a simple 1-layer MLP, showing the compilation steps from TorchScript.]

Standalone script to convert a PyTorch ResNet18 model to MLIR and run it on the CPU backend:

# The example uses PIL and requests to get the image.
pip install requests pillow
# Run ResNet18 as a standalone script.
python examples/torchscript_resnet18_e2e.py

load image from https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/mlir/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100.0%
PyTorch prediction
[('Labrador retriever', 70.66319274902344), ('golden retriever', 4.956596374511719), ('Chesapeake Bay retriever', 4.195662975311279)]
torch-mlir prediction
[('Labrador retriever', 70.66320037841797), ('golden retriever', 4.956601619720459), ('Chesapeake Bay retriever', 4.195651531219482)]
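The two runs agree on the labels, but the scores match only to within floating-point noise rather than bit-for-bit (for example, 70.66319274902344 vs 70.66320037841797 for 'Labrador retriever'). A minimal sketch of comparing such outputs with a relative tolerance, using the numbers printed above; the tolerance of 1e-5 is an illustrative assumption, not the threshold used by the project's tests:

```python
import math

# Top-3 predictions printed by examples/torchscript_resnet18_e2e.py above.
pytorch_pred = [("Labrador retriever", 70.66319274902344),
                ("golden retriever", 4.956596374511719),
                ("Chesapeake Bay retriever", 4.195662975311279)]
torch_mlir_pred = [("Labrador retriever", 70.66320037841797),
                   ("golden retriever", 4.956601619720459),
                   ("Chesapeake Bay retriever", 4.195651531219482)]


def predictions_match(expected, actual, rel_tol=1e-5):
    """Same labels in the same order, and scores equal up to rel_tol."""
    if [label for label, _ in expected] != [label for label, _ in actual]:
        return False
    return all(math.isclose(e, a, rel_tol=rel_tol)
               for (_, e), (_, a) in zip(expected, actual))


print(predictions_match(pytorch_pred, torch_mlir_pred))  # True
print(pytorch_pred == torch_mlir_pred)                   # False: not bit-identical
```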

Jupyter notebook:

python -m ipykernel install --user --name=torch-mlir --env PYTHONPATH "$PYTHONPATH"
# Open in jupyter, and then navigate to
# `examples/resnet_inference.ipynb` and use the `torch-mlir` kernel to run.
jupyter notebook

TorchFX

The examples folder includes the Python package torchfx, which is a functional prototype of a TorchFX to MLIR pipeline. The main entry point into the torchfx package is the torchfx.builder module, which includes a function for converting the output of a TorchFX trace into MLIR. Currently, only a very limited number of PyTorch operations are supported, but support will be expanded in the future.

Example usage of torchfx

The examples folder includes torchfx_*.py scripts showing how to use the TorchFX to MLIR pipeline. To run the examples, make sure you have set up your PYTHONPATH by following the Setup Python Environment instructions.

Then, run

python torchfx_example_name.py

replacing torchfx_example_name.py with the actual torchfx example you want to run.

Lazy Tensor Core

The examples folder includes the Python package lazytensor, which implements a Lazy Tensor Core (LTC) to MLIR pipeline. The main entry point into the lazytensor package is the lazytensor.builder module, which includes the function build_module. build_module takes a computation that LTC has captured and converted to TorchScript IR, and converts it to MLIR.

Example usage of lazytensor

The examples folder includes scripts lazytensor_*.py showing how to use the Lazy Tensor to MLIR pipeline. The examples depend on the Lazy Tensor Core (LTC) of PyTorch. For information on how to obtain LTC, see here.

To run the examples, make sure you have set up your PYTHONPATH by following the Setup Python Environment instructions, and also add /path/to/pytorch/lazy_tensor_core to your PYTHONPATH as shown below:

export PYTHONPATH=$PYTHONPATH:/replace/with/path/to/pytorch/lazy_tensor_core
python lazytensor_example_name.py

replacing lazytensor_example_name.py with the actual lazytensor example you want to run.

Repository Layout

The project follows the conventions of typical MLIR-based projects:

  • include/torch-mlir and lib: structure for the C++ MLIR compiler dialects/passes.
  • test: holds test code.
  • tools: torch-mlir-opt and similar tools.
  • python: top-level directory for Python code.

Interactive Use

The build_tools/write_env_file.sh script will output a .env file in the workspace folder with the correct PYTHONPATH set. This allows tools like VSCode to work by default for debugging. This file can also be sourced manually in a shell.
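Outside of VSCode or a shell, the same settings can be loaded from Python, for example in a custom test runner. A sketch assuming the generated .env contains simple KEY=VALUE lines (optionally prefixed with export); the exact format written by write_env_file.sh is an assumption here, and load_env_file is a hypothetical helper:

```python
import os


def load_env_file(path=".env"):
    """Parse simple KEY=VALUE lines (assumed format) into a dict.

    Blank lines and '#' comments are skipped, an optional leading
    'export ' is stripped, and no quoting or variable expansion is done.
    """
    env = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line.startswith("export "):
                line = line[len("export "):]
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            env[key.strip()] = value.strip()
    return env


if __name__ == "__main__":
    # Make the settings visible to this process and any child processes.
    os.environ.update(load_env_file())
    print(os.environ.get("PYTHONPATH"))
```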

Build Python Packages

We have preliminary support for building Python packages. This can be done with the following commands:

python -m pip install --upgrade pip
python -m pip install -r requirements.txt
CMAKE_GENERATOR=Ninja python setup.py bdist_wheel
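By default, setuptools' bdist_wheel writes the resulting wheel into a dist/ directory. A small sketch for locating the most recently built wheel so it can be installed into another environment; latest_wheel is a hypothetical helper, and it assumes the default dist/ output location:

```python
import glob
import os


def latest_wheel(dist_dir="dist"):
    """Return the most recently modified .whl file in dist_dir, or None."""
    wheels = glob.glob(os.path.join(dist_dir, "*.whl"))
    return max(wheels, key=os.path.getmtime) if wheels else None


if __name__ == "__main__":
    wheel = latest_wheel()
    if wheel is None:
        raise SystemExit("No wheel found in dist/; did bdist_wheel succeed?")
    print(f"python -m pip install {wheel}")
```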