mirror of https://github.com/llvm/torch-mlir
# Long-Term Roadmap for Torch-MLIR

## Overview

Latest update: 2022Q4

Torch-MLIR is about one year old now, and has successfully delivered a lot of
value to the community. In this document we outline the major architectural
changes that will make Torch-MLIR more robust, accessible, and useful to the
community on a 1-2 year timeline.

First, let's recap the goals of Torch-MLIR.

Technically, the goal of Torch-MLIR is to bridge the PyTorch and MLIR
ecosystems. That's vague, but it captures a very important property: Torch-MLIR
is not in the business of "innovating" on either the frontend or backend side.
The project scope is to be an enabling connector between the two systems.

Non-technically, Torch-MLIR's goal is not to be an end-to-end product, but a
reliable piece of "off the shelf" infrastructure that system designers use as
part of their larger end-to-end flows. The main users are expected to be
"integrators", not end-users writing Python. This has the following facets:

- Community: Users of Torch-MLIR should feel empowered to participate in the
  community to get their questions resolved, or propose (and even implement)
  changes needed for their use cases.
- Ecosystem alignment: Users of Torch-MLIR should feel that the project is
  aligned with all of the projects that it collaborates with, making it safe to
  bet on for the long term.
- Ease of use: Users of Torch-MLIR should feel that it "Just Works", or that
  when it fails, it fails in a way that is easy to understand, debug, and fix.
- Development: Torch-MLIR should be easy and convenient to develop.

Today, much of the design space and the main problems have been identified, but
larger-scale architectural and cross-project changes are needed to realize the
right long-term design. This will allow us to reach a steady-state that best
meets the goals above.

## The main architectural changes

As described in [architecture.md](architecture.md), Torch-MLIR can be split
into two main parts: the "frontend" and the "backend".

The main sources of brittleness, maintenance cost, and duplicated work across
the ecosystem are:

- The frontend work required to lower TorchScript to the backend contract.
- The irregular support surface area of the large number of PyTorch ops across
  the Linalg, TOSA, and StableHLO backends.

Most of this document describes long-term ecosystem changes that will address
these, drastically improving Torch-MLIR's ability to meet its goals.

## Current API Paths

Currently, there are two main API paths for the Torch-MLIR project:

- The first path is part of the legacy "pt1" code
  (`torch_mlir.torchscript.compile`). It allows users to test the compiler's
  output to the different MLIR dialects (`TORCH`, `TOSA`, `LINALG_ON_TENSORS`,
  `RAW`, and `STABLEHLO`). This path is deprecated and doesn't give access to
  the current generation of work that is being driven via the FX importer. It
  is tied to the old TorchScript path.
- The second path (`torch_mlir.fx.export_and_import`) allows users to import a
  consolidated `torch.export.ExportedProgram` instance of an arbitrary Python
  callable (an `nn.Module`, a function, or a method) and produce a Torch
  dialect MLIR module. This path is aligned with PyTorch's roadmap, but it is
  not fully functional yet.

## Roadmap

### Refactoring the frontend

The primary way to make the frontend more reliable is to leverage new PyTorch
infrastructure that bridges from the PyTorch eager world into compiler-land.
PyTorch has two main projects that together cover almost all user use cases and
provide a technically sound, high quality-of-implementation path from user
programs into the compiler.

- [TorchDynamo](https://github.com/pytorch/torchdynamo) - TorchDynamo uses
  tracing-JIT-like techniques and program slicing to extract traces of tensor
  operations, which can then be passed to lower-level compilers. It works
  seamlessly with unmodified user programs.
- [FuncTorch](https://github.com/pytorch/functorch) - FuncTorch is basically
  JAX for PyTorch. It requires manual program tracing and slicing, but that is
  actually important for users since it gives them direct control over various
  important transformations, such as `grad` and `vmap`.

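To make the two capture styles concrete, here is a hedged sketch (assuming a
PyTorch 2.x install, where FuncTorch ships as `torch.func`): a custom
TorchDynamo backend that receives the traced FX graph, and explicit
`grad`/`vmap` transforms:

```python
import torch
from torch.func import grad, vmap

# TorchDynamo hands a backend the captured FX GraphModule plus example
# inputs; a compiler like Torch-MLIR would lower the graph, but here we
# just record it and fall back to eager execution.
captured = []
def inspecting_backend(gm: torch.fx.GraphModule, example_inputs):
    captured.append(gm)
    return gm.forward

@torch.compile(backend=inspecting_backend)
def f(x):
    return torch.relu(x) + 1

y = f(torch.tensor([-1.0, 2.0]))  # -> tensor([1., 3.])

# FuncTorch-style explicit transforms: per-example gradients of a loss.
per_example_grad = vmap(grad(lambda x: (x ** 2).sum()))
g = per_example_grad(torch.ones(3, 2))  # each gradient is 2 * x
```

In both cases the compiler-facing artifact is a graph of tensor operations,
which is exactly the shape of input Torch-MLIR wants.
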
These are both being heavily invested in by PyTorch core developers, and are
generally seen as the next generation of compiler technology for the project,
blending PyTorch's famous usability with excellent compiler integration
opportunities. Torch-MLIR works with these technologies as they exist today,
but significant work remains before we can wholesale delete the
high-maintenance parts of Torch-MLIR. In the future, we expect the block
diagram of Torch-MLIR to be greatly simplified, as shown in the diagram below.
Note that on the "Future" side, PyTorch directly gives us IR in a form
satisfying the backend contract.

![Roadmap of the frontend](images/roadmap_frontend.png)

The primary functional requirement of Torch-MLIR which remains unaddressed by
today's incarnation of TorchDynamo and FuncTorch is support for dynamic
shapes. PyTorch core devs are
[heavily investing](https://dev-discuss.pytorch.org/t/state-of-symbolic-shapes-branch/777)
in this area, and both TorchDynamo and FuncTorch are being upgraded as PyTorch
rolls out its new symbolic shape infrastructure.

Smaller blockers are related to general API stability and usability of the
various pieces of PyTorch infra.

These blockers are expected to be addressed by the PyTorch core devs over time.
Torch-MLIR's role here is to communicate our requirements to PyTorch core and
align their roadmap and ours. We do this by maintaining connections with the
PyTorch core developers and being "good-citizen power users" of their latest
technology (i.e. trying things out, surfacing bugs, providing feedback, etc.).

Note: Because both TorchDynamo and FuncTorch are TorchFX-based, we could write
a direct TorchFX -> MLIR importer, and delete the TorchScript importer. This
would remove the need for Torch-MLIR to build its own custom Python extension
-- Torch-MLIR would be a pure-Python user of the standard MLIR Python
bindings. There is no immediate rush for this though, since TorchFX can be
converted to TorchScript (this may become lossy as the dynamic shape support in
PyTorch gets more advanced).

### Refactoring the backend

Today in Torch-MLIR, we support 3 backends out of the box: Linalg-on-Tensors,
TOSA, and StableHLO. These backends take IR in the backend contract form (see
[architecture.md](architecture.md)) and lower it to the respective dialects.
Today, each backend is implemented completely independently. This leads to
duplication and irregularity across the backends.

Moving forward, we would like for the backends to share more code and for their
op support to be more aligned with each other. Since the backend contract today
includes "all" of PyTorch's operators, it is very costly to duplicate the
lowering of so many ops across backends. Additionally, there are 3
forward-looking efforts that intersect with this effort:

- [StableHLO](https://github.com/openxla/stablehlo) - this is a dialect
  initially forked from MHLO. MHLO is a fairly complete op set, so it is very
  attractive to have "almost all" models bottleneck through a stable interface
  like StableHLO. StableHLO is currently under relatively early development,
  but already delivers on many of the goals of stability.
- [TCP](https://github.com/llvm/torch-mlir/issues/1366) - this is a dialect
  which could serve a role very similar to MHLO, while providing community
  ownership. TCP is still in early planning phases, but there is strong
  alignment with the StableHLO effort. One byproduct of TCP that is expected
  to be very valuable is to incorporate the robust dynamic shape strategy from
  Linalg into an MHLO-like dialect, and there is a strong desire from
  StableHLO developers to adopt this once proven in TCP.
- [PrimTorch](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577) -
  this is an effort on the PyTorch side to decompose PyTorch operators into a
  smaller set of primitive ops. This effort could significantly reduce the op
  surface area at the Torch-MLIR level, which would make the duplication
  across backends less of an issue. But it still leaves open a lot of
  questions, such as how to control decompositions.

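For a taste of the decomposition mechanism, here is a sketch using PyTorch's
existing `torch._decomp` machinery (an internal API that may change; the
eventual PrimTorch surface is still evolving):

```python
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# Trace addmm while decomposing it into more primitive ops, shrinking
# the op surface area that a backend has to handle.
decomps = get_decompositions([torch.ops.aten.addmm])

def f(bias, a, b):
    return torch.addmm(bias, a, b)

gm = make_fx(f, decomposition_table=decomps)(
    torch.randn(3, 5), torch.randn(3, 4), torch.randn(4, 5))
targets = {str(n.target) for n in gm.graph.nodes
           if n.op == "call_function"}
# `targets` contains primitive ops like aten.mm instead of aten.addmm.
```

The open question the text alludes to is who picks the decomposition table:
different backends will want different (and differently-sized) primitive sets.
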
This is overall less important than the frontend refactor, because it is "just
more work" for us as Torch-MLIR developers to support things in the current
infrastructure, while the frontend refactor directly affects the user
experience.

As the above efforts progress, we will need to make decisions about how to
adopt the various technologies. The main goal is consolidating the bottleneck
point where the O(100s-1000s) of ops in PyTorch are reduced to a more tractable
O(100) ops. There are two main ways to accomplish this:

- Future A: We concentrate the bottleneck step in the "Backend contract ->
  StableHLO/MHLO/TCP" lowering path. This gets us a stable output for most
  things. The cascaded/transitive lowerings then let us do O(100) lowerings
  from then on down. (The exact details are not worked out yet, and depend on
  e.g. TCP adoption.)
- Future B: PrimTorch concentrates the bottleneck step on the PyTorch side.

These two efforts synergize, but the need for cascaded lowerings is much
smaller if PrimTorch solves the decomposition problem on the PyTorch side.

![Roadmap of the backend](images/roadmap_backend.png)

One of the main blockers for doing cascaded lowerings today is the irregular
support for dynamic shapes across TOSA and MHLO. MHLO is much more complete,
but the use of `tensor<Nxindex>` to model shapes results in brittleness of the
system. A dynamic shape model like that being adopted in TCP (and presumably
StableHLO in time) would simplify this. Hence TCP is strategically important
for proving out a design for a "dynamically shaped MHLO-like thing" that
doesn't have this drawback.

### Tools for advanced AoT deployments

PyTorch's future direction is towards TorchDynamo and FuncTorch, which are
tracing-based systems. This means that they inherently struggle to capture
control flow and non-tensor computations. Many deployments, especially
Ahead-of-Time compiled ones such as for edge, require non-tensor computations.
With existing tools, it is extremely costly and error-prone for people
deploying such models to manually stitch together graphs of traced functions
with custom per-model code. We are awaiting movement on this front from the
PyTorch core team. There is some inspiration from systems like
[IREE-JAX](https://github.com/iree-org/iree-jax) in the JAX space for how to do
this, but ultimately this will depend on what the PyTorch core team decides on
for edge deployments. It is our responsibility to stay connected with them and
make sure that what they are building suits our needs.

### Project Governance / Structure

Torch-MLIR is currently an
[LLVM Incubator](https://llvm.org/docs/DeveloperPolicy.html#incubating-new-projects).
This has had the advantage of being organizationally close to MLIR Core.
However, the long-term direction is likely for Torch-MLIR to live under the
PyTorch umbrella, for a few reasons:

- As discussed in the other parts of this document, the long-term direction is
  for Torch-MLIR to be a quite thin component, with much of the code being
  obsoleted by infra in PyTorch core.
- The move towards more stable backend output formats will generally reduce
  variance on the MLIR side. This means that MLIR will be the "more frozen" of
  the two major dependencies (PyTorch and MLIR).
- We would like Torch-MLIR to be hooked into the PyTorch CI systems, and
  generally be more tightly integrated with the PyTorch development process
  (this includes things like packaging as well).

### Co-design

Many users of MLIR are developing advanced hardware or software systems, and
often these require information from the frontend beyond what PyTorch provides
today. Torch-MLIR should always be a "follower" of the features available in
the frontends and backends it connects to. We want to enable co-design, of
course, but new features such as quantization, sparsity, distribution, etc.
should be viewed from the lens of "the frontend can give us X information, the
backend needs Y information -- how do we connect them?".

To satisfy those needs, we want to focus on existing extensibility mechanisms
in the frontend rather than inventing new ones. We intend to explore using
existing frontend concepts, such as
[custom ops](https://github.com/llvm/torch-mlir/issues/1462), to enable this
co-design.

If it proves to be absolutely necessary to add new concepts to the frontend
(e.g. new data types), it should be considered very carefully, since supporting
such features is a major scope increase for the Torch-MLIR project. It is
likely better done in a separate project, with a carefully thought-out
integration with Torch-MLIR that avoids putting the maintenance burden for the
exploratory new frontend concept on Torch-MLIR's side.

### LazyTensorCore support in Torch-MLIR

Today, Torch-MLIR supports LazyTensorCore. But as mentioned
[here](https://dev-discuss.pytorch.org/t/skipping-dispatcher-with-lazytensor/634/2?u=_sean_silva),
on the 1-2yr time horizon LTC will be more of an implementation detail under
TorchDynamo for users that already have compilers written using LTC. That is,
LTC is basically just a way to convert a TorchDynamo FX graph into LTC graphs,
for users that have toolchains written against LTC graphs. But that won't make
much technical sense for Torch-MLIR, because we convert to MLIR in the end no
matter what. That is, in the future, going
`TorchDynamo FX graph -> LTC graph -> MLIR` can just be replaced by the direct
`TorchDynamo FX graph -> MLIR` path. So on the 1-2yr time horizon, LTC will
not make technical sense in Torch-MLIR.

There will still be non-technical blockers, such as end-users having
`device='lazy'` hardcoded into their code. That will require a migration plan
for current LTC-based toolchains onto TorchDynamo. This migration will improve
the end-user experience since TorchDynamo is more seamless, but it is an
end-user-impacting migration nonetheless and we will want to phase it
appropriately with the community.

### End-to-end (E2E) testing

Torch-MLIR currently maintains its own test suite with
[hundreds of end-to-end tests](https://github.com/llvm/torch-mlir/tree/main/python/torch_mlir_e2e_test/test_suite)
that verify the correctness and completeness of our op lowerings.
These tests are tedious to write, and also sometimes hit corners
of PyTorch's API that aren't usually reachable by user code.
PyTorch already has an [end-to-end op test suite](https://github.com/pytorch/pytorch/blob/ead51864622467acd6835b6da86a166c1a32aa55/torch/testing/_internal/common_methods_invocations.py#L1)
and we should just plug into it. Here is [an example](https://github.com/pytorch/pytorch/blob/ead51864622467acd6835b6da86a166c1a32aa55/test/test_proxy_tensor.py#L1573)
of doing so. Even better, it would be great if TorchDynamo/PyTorch 2.0
directly provided a way to plug into this.

Additionally, we can leverage
[`pytorch-jit-paritybench`](https://github.com/jansel/pytorch-jit-paritybench)
to verify our end-to-end correctness on real models.