The Torch-MLIR Project

The Torch-MLIR project aims to provide first-class compiler support from the PyTorch ecosystem to the MLIR ecosystem.

This project is participating in the LLVM Incubator process: as such, it is not part of any official LLVM release. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project is not yet endorsed as a component of LLVM.

PyTorch

PyTorch is an open source machine learning framework that facilitates the seamless transition from research and prototyping to production-level deployment.

MLIR

The MLIR project offers a novel approach for building extensible and reusable compiler architectures, which address the issue of software fragmentation, reduce the cost of developing domain-specific compilers, improve compilation for heterogeneous hardware, and promote compatibility between existing compilers.

Torch-MLIR

Several vendors have adopted MLIR as the middle layer in their systems, enabling them to map frameworks such as PyTorch, JAX, and TensorFlow into MLIR and subsequently lower them to their target hardware. We have observed half a dozen custom lowerings from PyTorch to MLIR. Providing canonical lowerings from the PyTorch ecosystem to the MLIR ecosystem lets hardware vendors focus on their unique value rather than implementing yet another PyTorch frontend for MLIR. The ultimate aim is to mirror how hardware vendors add LLVM target support today, rather than each one implementing Clang or a C++ frontend.

All the roads from PyTorch to the Torch MLIR Dialect

There are a few paths for lowering to the Torch MLIR Dialect.

[Figure: Simplified architecture diagram]

  • TorchScript: This is the most tested path down to the Torch MLIR Dialect.
  • LazyTensorCore: Read more details here.
  • TorchDynamo/PyTorch 2.0: We also have basic support for this path; see our long-term roadmap and Thoughts on PyTorch 2.0 for more details.

Project Communication

  • #torch-mlir channel on the LLVM Discord (the most active communication channel)
  • GitHub issues here
  • torch-mlir section of LLVM Discourse

Meetings

Community Meeting / Developer Hour:

  • 1st and 3rd Monday of the month at 9 am PST
  • 2nd and 4th Monday of the month at 5 pm PST

Office Hours:

  • Every Thursday at 8:30 am PST

Meeting links can be found here.

Install torch-mlir snapshot

At the time of writing, we release pre-built snapshots of torch-mlir for Python 3.11 and Python 3.10.
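To check whether the interpreter you plan to use is one of the versions listed above, a short script like the following can help. It is a minimal sketch: the `SUPPORTED_VERSIONS` set and the `has_prebuilt_snapshot` helper are illustrative names, and the set simply mirrors the versions named in this README at the time of writing, so it is an assumption that will go stale as new snapshots are published.

```python
import sys

# Assumption: mirrors the Python versions this README lists for pre-built
# torch-mlir snapshots (3.10 and 3.11); update as new snapshots appear.
SUPPORTED_VERSIONS = {(3, 10), (3, 11)}


def has_prebuilt_snapshot(version_info=sys.version_info) -> bool:
    """Return True if the (major, minor) version has a pre-built snapshot."""
    return tuple(version_info[:2]) in SUPPORTED_VERSIONS


if __name__ == "__main__":
    major, minor = sys.version_info[:2]
    if has_prebuilt_snapshot():
        print(f"Python {major}.{minor}: pre-built torch-mlir snapshots available")
    else:
        print(f"Python {major}.{minor}: no pre-built snapshot; consider building from source")
```

If the check fails, building from source (see the Developers section) is the fallback.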

If you have a supported Python version, the following commands create and activate a virtual environment.

python3.11 -m venv mlir_venv
source mlir_venv/bin/activate

Alternatively, if you use conda to switch between multiple Python versions, you can create a conda environment with Python 3.11.

conda create -n torch-mlir python=3.11
conda activate torch-mlir
python -m pip install --upgrade pip

Then, install torch-mlir together with the corresponding torch and torchvision nightlies.

pip install --pre torch-mlir torchvision \
  --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
  -f https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels

Demos

FxImporter ResNet18

# Get the latest example if you haven't checked out the code
wget https://raw.githubusercontent.com/llvm/torch-mlir/main/projects/pt1/examples/fximporter_resnet18.py

# Run ResNet18 as a standalone script from the repository root
# (if you downloaded it with wget, run `python fximporter_resnet18.py` instead).
python projects/pt1/examples/fximporter_resnet18.py

# Output
load image from https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg
...
PyTorch prediction
[('Labrador retriever', 70.65674591064453), ('golden retriever', 4.988346099853516), ('Saluki, gazelle hound', 4.477451324462891)]
torch-mlir prediction
[('Labrador retriever', 70.6567153930664), ('golden retriever', 4.988325119018555), ('Saluki, gazelle hound', 4.477458477020264)]
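The two predictions agree on the labels and differ only in the low-order digits of the scores, most likely because the two backends may evaluate floating-point operations in a different order. A minimal sketch of how such results could be compared, using the numbers printed above; the `predictions_match` helper and its `rel_tol=1e-4` threshold are illustrative assumptions, not part of the example script.

```python
import math

# Top-3 (label, percentage) pairs copied from the FxImporter ResNet18 output above.
pytorch_top3 = [
    ("Labrador retriever", 70.65674591064453),
    ("golden retriever", 4.988346099853516),
    ("Saluki, gazelle hound", 4.477451324462891),
]
torch_mlir_top3 = [
    ("Labrador retriever", 70.6567153930664),
    ("golden retriever", 4.988325119018555),
    ("Saluki, gazelle hound", 4.477458477020264),
]


def predictions_match(expected, actual, rel_tol=1e-4):
    """Check that labels match in order and scores agree within rel_tol.

    The default tolerance is an assumed threshold, not one used by torch-mlir.
    """
    if len(expected) != len(actual):
        return False
    return all(
        e_label == a_label and math.isclose(e_score, a_score, rel_tol=rel_tol)
        for (e_label, e_score), (a_label, a_score) in zip(expected, actual)
    )


print(predictions_match(pytorch_top3, torch_mlir_top3))  # prints True
```

Comparing with a relative tolerance rather than exact equality avoids flagging harmless floating-point differences while still catching a changed label or a large score drift.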

TorchScript ResNet18

Standalone script to convert a PyTorch ResNet18 model to MLIR and run it on the CPU backend:

# Get the latest example if you haven't checked out the code
wget https://raw.githubusercontent.com/llvm/torch-mlir/main/projects/pt1/examples/torchscript_resnet18.py

# Run ResNet18 as a standalone script from the repository root
# (if you downloaded it with wget, run `python torchscript_resnet18.py` instead).
python projects/pt1/examples/torchscript_resnet18.py

load image from https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/mlir/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100.0%
PyTorch prediction
[('Labrador retriever', 70.66319274902344), ('golden retriever', 4.956596374511719), ('Chesapeake Bay retriever', 4.195662975311279)]
torch-mlir prediction
[('Labrador retriever', 70.66320037841797), ('golden retriever', 4.956601619720459), ('Chesapeake Bay retriever', 4.195651531219482)]
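The (label, percentage) pairs in both demo outputs look like softmax probabilities scaled to percent, keeping the three highest-scoring classes. Assuming that is how the example scripts produce them, the computation can be sketched in plain Python. `top_k_percentages` is a hypothetical helper, and the toy logits and labels are made up for illustration (ResNet18 actually scores 1000 ImageNet classes).

```python
import math


def top_k_percentages(logits, labels, k=3):
    """Return the k highest-probability (label, percent) pairs.

    Softmax is computed after subtracting the max logit, which leaves the
    result unchanged but avoids overflow in math.exp.
    """
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    percents = [100.0 * e / total for e in exps]
    # Rank class indices by probability, highest first.
    ranked = sorted(range(len(logits)), key=lambda i: percents[i], reverse=True)
    return [(labels[i], percents[i]) for i in ranked[:k]]


# Toy values for illustration only; not real ResNet18 outputs.
labels = ["Labrador retriever", "golden retriever", "Chesapeake Bay retriever", "tabby cat"]
logits = [5.0, 2.4, 2.2, -1.0]
print(top_k_percentages(logits, labels))
```

Because only the top three classes are printed, their percentages need not sum to 100; the remaining probability mass is spread over the other classes.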

Lazy Tensor Core

View examples here.

Repository Layout

The project follows the conventions of typical MLIR-based projects:

  • include/torch-mlir and lib: C++ MLIR compiler dialects and passes.
  • test: test code.
  • tools: torch-mlir-opt and similar tools.
  • python: top-level directory for Python code.

Developers

If you would like to develop and build torch-mlir from source, please see the Development Notes.