diff --git a/docs/Torch-MLIR_Architecture.png b/docs/Torch-MLIR_Architecture.png
new file mode 100644
index 000000000..4a39d2419
Binary files /dev/null and b/docs/Torch-MLIR_Architecture.png differ
diff --git a/docs/architecture.md b/docs/architecture.md
new file mode 100644
index 000000000..a814edc42
--- /dev/null
+++ b/docs/architecture.md
@@ -0,0 +1,443 @@
+# Torch-MLIR Architecture
+
+## Introduction
+
+The Torch-MLIR project provides core infrastructure for bridging the PyTorch
+ecosystem and the MLIR ecosystem. For example, Torch-MLIR enables PyTorch
+models to be lowered to a few different MLIR dialects. Torch-MLIR does not
+attempt to provide a production end-to-end flow for PyTorch programs by
+itself, but is a useful component for constructing one.
+
+## Overview
+
+Torch-MLIR has two parts, which we call the "frontend" and "backend". These
+two halves interface at an abstraction layer that we call the "backend
+contract", which is a subset of the `torch` dialect with certain properties
+that make it appealing for backends to lower from.
+
+![Torch-MLIR Architecture](Torch-MLIR_Architecture.png)
+
+The frontend of Torch-MLIR is concerned with interfacing to PyTorch itself and
+then normalizing the program to the "backend contract". This part involves
+significant build system complexity and exposure to PyTorch APIs in order to
+get the program into the MLIR `torch` dialect. When we interface with
+TorchScript, we additionally have a large amount of lowering and
+simplification to do within MLIR on the `torch` dialect.
+
+The "backend" of Torch-MLIR takes IR in the "backend contract" form and lowers
+it to various target dialects of interest to the MLIR ecosystem (various
+"backends"). In particular, right now we support lowering to:
+
+- Linalg-on-Tensors (+ `arith`, `tensor`, etc.)
+- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
+- [MHLO](https://github.com/tensorflow/mlir-hlo)
+
+The terms "frontend" and "backend" are highly overloaded in any compiler
+project, but in Torch-MLIR they usually carry the meaning described above.
+Sometimes "frontend" can mean something even further up the stack, such as
+something in PyTorch itself; when there is ambiguity, we will refer to that as
+"at the PyTorch level". Similarly, "backend" can sometimes refer to something
+sitting below Linalg-on-Tensors, TOSA, or MHLO.
+
+## The `torch` dialect
+
+See [include/torch-mlir/Dialect/Torch/IR](https://github.com/llvm/torch-mlir/tree/main/include/torch-mlir/Dialect/Torch/IR)
+
+The central MLIR abstraction in the Torch-MLIR project is the `torch` dialect.
+This dialect supports progressive lowering from the raw imported PyTorch
+programs that various PyTorch integration points provide, all the way down to
+the backend contract.
+
+The `torch` dialect must be versatile enough to represent programs imported
+from any program capture mechanism in PyTorch -- this could be TorchDynamo,
+`torch.fx`, LazyTensorCore, TorchScript, `torch.jit.trace`, etc. Thankfully,
+PyTorch is factored such that we can handle this with one core import path,
+which goes through the PyTorch
+"[JIT IR](https://github.com/pytorch/pytorch/blob/78c8a0d75220bdd4955415b5f81509e005af4232/torch/csrc/jit/OVERVIEW.md)".
+The importer lives in
+[torch-mlir/python/torch_mlir/dialects/torch/importer/jit_ir](https://github.com/llvm/torch-mlir/tree/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir).
+The JIT IR is a highly principled IR that faithfully models a Python subset (+
+tensors, the PyTorch op registry, and a few other things). All the other
+PyTorch program representations can eventually bottom out on the JIT IR via
+some path provided by PyTorch. The `torch` dialect is almost entirely in 1:1
+correspondence with the JIT IR -- this allows the importer to be extremely
+small (the core is
+[under 500 lines of code](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp#L1)).
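+
+To make the import path concrete, the following sketch (illustrative only; it
+needs nothing beyond a stock PyTorch install, and `AddMul` is a made-up
+example module) prints the JIT IR that the importer consumes. Each `aten::*`
+node in this graph corresponds roughly 1:1 to a `torch.aten.*` op in the
+`torch` dialect.
+
+```python
+import torch
+
+class AddMul(torch.nn.Module):
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return a * b + b
+
+scripted = torch.jit.script(AddMul())
+# Prints PyTorch's JIT IR for the method body, with nodes such as
+# aten::mul.Tensor and aten::add.Tensor, which the importer maps to
+# torch.aten.mul.Tensor and torch.aten.add.Tensor ops respectively.
+print(scripted.graph)
+```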
+
+### Ops
+
+See [TorchOps.td](https://github.com/llvm/torch-mlir/blob/114f48e96c578ee76a6f83b3aa4aa229a8d5b76e/include/torch-mlir/Dialect/Torch/IR/TorchOps.td#L1)
+
+The ops in the `torch` dialect are almost entirely generated from the PyTorch
+JIT IR operator registry by the script
+[torch_ods_gen.py](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py#L1)
+(invoked via [update_torch_ods.sh](https://github.com/llvm/torch-mlir/blob/main/build_tools/update_torch_ods.sh)).
+This script queries the registry and generates MLIR
+[ODS](https://mlir.llvm.org/docs/OpDefinitions/) in
+[GeneratedTorchOps.td](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td#L1).
+We have a guide for
+[adding a new op end-to-end](https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation).
+
+There are also some manually implemented ops in the following categories (see
+[TorchOps.td](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/TorchOps.td#L1)):
+
+- Ops used for modeling PyTorch IValue object graphs (e.g. `torch.nn_module`,
+  `torch.class_type`).
+- `torch.global_slot` and related ops, which are used to model an incremental
+  lowering of the IValue object graphs.
+- Ops that are supported directly in the JIT interpreter and so don't have a
+  corresponding op in the registry (e.g. `torch.prim.If`,
+  `torch.prim.ListConstruct`, `torch.constant.*`).
+- `torch.operator`, which is used to represent ops from the registry that
+  haven't been generated by `torch_ods_gen.py`.
+
+### Types
+
+See [TorchTypes.td](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td#L1)
+
+The `torch` dialect has a complete set of types modeling the PyTorch type
+system, which itself is a strongly typed subset of the Python type system (+
+tensors). These types are almost all 1:1 with the corresponding
+[PyTorch types](https://github.com/pytorch/pytorch/blob/c54d18dbc7bb2f9fdd83c5de529702e5a02295c3/aten/src/ATen/core/jit_type.h#L1).
+
+The one place where a significant amount of design work has been done in
+Torch-MLIR is the handling of tensors. Torch-MLIR's tensor types allow
+progressive lowering from raw imported IR, which may be missing shapes,
+dtypes, and value semantics, into the backend contract, which provides all
+three. Torch-MLIR has two tensor types, `ValueTensorType` (`!torch.vtensor`)
+and `NonValueTensorType` (`!torch.tensor`), sharing most of their definition
+in
+[TorchTypes.td](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td#L58).
+The `NonValueTensorType` models a `torch.Tensor` including mutation, aliasing,
+etc., while the `ValueTensorType` has value semantics. That is,
+`ValueTensorType` is immutable and non-aliased. These types have a common C++
+base class
+[`BaseTensorType`](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h#L40)
+which permits abstracting across them. Both `ValueTensorType` and
+`NonValueTensorType` have an optional list of optional sizes and an optional
+dtype.
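+
+As a rough illustration of how this shows up in practice, the sketch below
+(a minimal example assuming a built `torch_mlir` Python package; the exact
+printed form may differ, and `Mul2` is a made-up module) compiles a module to
+the `torch` dialect, where shapes and dtypes taken from the example input
+appear directly in the tensor types:
+
+```python
+import torch
+import torch_mlir
+
+class Mul2(torch.nn.Module):
+    def forward(self, x):
+        return x * 2.0
+
+# Shapes/dtypes come from the example input. After lowering to the backend
+# contract, the argument is typed as a value-semantic tensor with known sizes
+# and dtype, e.g. !torch.vtensor<[2,3],f32>.
+module = torch_mlir.compile(Mul2(), torch.ones(2, 3), output_type="torch")
+print(module)
+```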
+
+## The "backend contract"
+
+See [satisfiesBackendContract](https://github.com/llvm/torch-mlir/blob/114f48e96c578ee76a6f83b3aa4aa229a8d5b76e/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp#L151)
+
+The backend contract is a normalized form of the `torch` dialect with a set of
+properties that make it easy to lower into various forms such as
+Linalg-on-Tensors, TOSA, MHLO, or other forms that we don't provide out of the
+box. The primary guarantees that we provide to Torch-MLIR's backends are:
+
+- All tensors have been converted to value semantics.
+- All tensors have at least a known number of dimensions (i.e. rank), and
+  ideally also a precise size for each dimension.
+- All tensors have a known dtype.
+- Certain ops have been decomposed to make them easier to handle (this is
+  configurable).
+
+See the extensive comments in the function `satisfiesBackendContract` (and its
+callees) in the `LowerToBackendContract` pass for an extended rationale for
+these decisions, and a precise definition of the backend contract.
+
+## The Frontends
+
+Torch-MLIR provides two main frontends:
+
+- LazyTensorCore - a frontend based around intercepting PyTorch dispatcher
+  calls and creating a graph that is lazily evaluated on demand.
+- TorchScript - a frontend based around importing TorchScript functions or
+  modules. Such modules or functions can be obtained via `torch.jit.script`,
+  `torch.jit.trace`, or a few other methods in the PyTorch ecosystem.
+
+Internally, these share a lot of the core import code.
+
+### LazyTensorCore
+
+Docs: https://github.com/llvm/torch-mlir/blob/main/docs/ltc_backend.md
+
+LazyTensorCore (LTC) is a program capture method provided by PyTorch that does
+device-level tracing. This low-level interception point sits below gradient
+calculations and is thus a good choice for training flows. The downside of LTC
+is that it depends on having the whole PyTorch runtime available, so it cannot
+be used for ahead-of-time compilation or for capturing standalone program
+artifacts.
+
+From an implementation perspective, the JIT IR produced by LazyTensorCore has
+already had a number of transformations performed on it; in particular, after
+importing from JIT IR to MLIR, the backend contract is trivially satisfied. So
+the Torch-MLIR implementation complexity for LazyTensorCore is restricted to
+build system and PyTorch integration, rather than actual MLIR compiler passes.
+
+### TorchScript (`torch.jit.script`)
+
+[TorchScript](https://pytorch.org/docs/stable/jit.html) is a strict Python
+subset which is modeled faithfully in the JIT IR. Additionally, TorchScript
+can represent a full `torch.nn.Module` object graph (hierarchy). This means
+the frontend must do a significant amount of work to lower it to the backend
+contract:
+
+- The `torch.nn.Module` hierarchy must be lowered to the backend contract,
+  which does not allow any program state.
+- The program must be converted to value semantics (functionalized).
+- Shapes and dtypes must be inferred.
+- Many "Python-isms" must be simplified away, such as list appends, string
+  operations, etc.
+
+Because TorchScript does not naturally carry shapes or dtypes, we usually
+require the user to annotate the expected shapes and dtypes of the arguments,
+which we then propagate throughout the program.
+
+`torch.jit.trace` produces JIT IR that already has shapes and dtypes, but no
+value semantics. And often users want to erase the shapes in the trace to
+allow dynamic shapes. Additionally, the Python-level data structures and APIs
+are largely parallel between `torch.jit.script` and `torch.jit.trace`, so we
+consider both of those the same from the perspective of the compiler's
+responsibilities. Both are accessed via the `torch_mlir.compile` Python API,
+as sketched below.
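+
+A minimal sketch of how those shape/dtype annotations are supplied (this
+assumes the `torch_mlir.compile` and `torch_mlir.TensorPlaceholder` Python
+APIs as of this writing; details may differ, and `MatMul` is a made-up
+module):
+
+```python
+import torch
+import torch_mlir
+
+class MatMul(torch.nn.Module):
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+# Fully static shapes/dtypes, taken directly from the example tensors.
+static = torch_mlir.compile(
+    MatMul(), [torch.ones(4, 8), torch.ones(8, 16)], output_type="torch")
+
+# The leading dimension of `lhs` is left dynamic via a placeholder (-1).
+lhs = torch_mlir.TensorPlaceholder([-1, 8], torch.float32)
+rhs = torch_mlir.TensorPlaceholder([8, 16], torch.float32)
+dynamic = torch_mlir.compile(MatMul(), [lhs, rhs], output_type="torch")
+```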
+
+### Modeling the `torch.nn.Module` object (`IValue`) hierarchy for TorchScript
+
+PyTorch consistently models a subset of Python objects with its concept of
+[`IValue`](https://github.com/pytorch/pytorch/blob/1ee9eb52b612f5fb4b63bbda832e44c8902edb64/aten/src/ATen/core/ivalue.h#L171)
+(interpreter value). These are used throughout PyTorch to represent Python
+values. When one calls `torch.jit.script` on a `torch.nn.Module`, the result
+is actually an `IValue` that represents the module, with a hierarchy of
+children `IValue`'s. Strictly speaking, JIT IR `torch::jit::Graph`'s are only
+used to represent the bodies of methods on the modules. So in addition to
+importing the JIT IR, we also need to import the `IValue`'s. This happens
+inside
+[ivalue_importer.cpp](https://github.com/llvm/torch-mlir/blob/fde390c7669e29362b18388448ef2b188713383f/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp#L1).
+
+Most of the IValue modeling can reuse `torch` dialect ops that already exist,
+such as `torch.constant.int` to represent an int in the object graph. However,
+special IR constructs are needed for modeling the `torch.nn.Module`'s
+themselves.
+
+An example is:
+
+```mlir
+torch.class_type @c {
+  torch.attr "b" : !torch.bool
+  torch.attr "i" : !torch.int
+  torch.attr "f" : !torch.float
+  torch.attr "t" : !torch.tensor
+  torch.method "get_tensor", @get_tensor
+}
+func.func private @get_tensor(%arg0: !torch.nn.Module<"c">) -> !torch.tensor {
+  %2 = torch.prim.GetAttr %arg0["t"] : !torch.nn.Module<"c"> -> !torch.tensor
+  return %2 : !torch.tensor
+}
+
+%true = torch.constant.bool true
+%int3 = torch.constant.int 3
+%float4.250000e01 = torch.constant.float 4.250000e+01
+%0 = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
+%1 = torch.nn_module {
+  torch.slot "b", %true : !torch.bool
+  torch.slot "i", %int3 : !torch.int
+  torch.slot "f", %float4.250000e01 : !torch.float
+  torch.slot "t", %0 : !torch.tensor
+} : !torch.nn.Module<"c">
+```
+
+See the documentation for the ops for more information on the semantics of
+this form.
+
+### Lowering TorchScript to the backend contract
+
+The `torchscript-module-to-torch-backend-pipeline` contains the set of
+simplifications used to convert TorchScript to the backend contract. At a
+high level, it consists of the following transformations:
+
+1. GlobalizeObjectGraph: This takes the `IValue` object graph and converts it
+   into a flat list of globals (see `torch.global_slot` and related ops).
+1. LowerToBackendContract: This pass iteratively applies a simplification
+   pipeline until the backend contract is reached. The simplification
+   pipeline consists of:
+   - Standard canonicalization.
+   - Shape refinement. See
+     [shape_lib.md](https://github.com/llvm/torch-mlir/blob/main/docs/shape_lib.md)
+     for details.
+   - DType refinement. See `RefineTypes`.
+   - Decomposing ops into more primitive ops. See `DecomposeComplexOps` (the
+     sketch after this list illustrates the kind of rewrite it performs).
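+
+As a rough illustration of decomposition, expressed here in PyTorch terms for
+readability (the actual pass rewrites `torch` dialect IR in C++, and the
+exact decomposition used for any given op may differ), a complex op like
+`aten.softmax` can be expanded into more primitive ops:
+
+```python
+import torch
+
+def softmax_decomposed(x: torch.Tensor, dim: int) -> torch.Tensor:
+    # Numerically-stable softmax in terms of more primitive ops:
+    # subtract the max, exponentiate, then normalize by the sum.
+    unnorm = torch.exp(x - torch.amax(x, dim, keepdim=True))
+    return unnorm / torch.sum(unnorm, dim, keepdim=True)
+
+x = torch.randn(2, 3)
+assert torch.allclose(softmax_decomposed(x, 1), torch.softmax(x, dim=1),
+                      atol=1e-6)
+```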
+
+### Layering of the PyTorch Dependency
+
+One of the core principles of our Torch-MLIR <-> PyTorch interop is that
+anything that links against PyTorch must interact with MLIR through
+[the Torch-MLIR C API](https://github.com/llvm/torch-mlir/tree/main/include/torch-mlir-c).
+This avoids a number of very complex dependency and shared library issues.
+
+Additionally, we maintain the invariant that the core MLIR compiler code (in
+`lib/` and `include/`) never has a build dependency on PyTorch itself. This
+strict isolation avoids a number of complex dependency issues and ensures that
+`torch-mlir-opt` and similar debugging tools always provide the excellent
+development and debugging experience that MLIR developers expect. Sometimes,
+certain highly stable enums and related logic must be shared with upstream
+PyTorch, and for those we copy code from PyTorch into
+[TorchUpstream.h](https://github.com/llvm/torch-mlir/blob/fde390c7669e29362b18388448ef2b188713383f/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L13).
+
+## The Backends
+
+Torch-MLIR provides three built-in backends, which take the backend contract
+IR and lower it to the requirements of each backend:
+
+- [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+
+  `arith`, `tensor`, etc.)
+- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
+- [MHLO](https://github.com/tensorflow/mlir-hlo)
+
+### The Linalg Backend (Linalg-on-Tensors)
+
+Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToLinalg
+
+The Linalg-on-Tensors backend was the first backend that we added, and it is
+still the most complete. It fully supports dynamic shapes (known number of
+dimensions but arbitrary dynamic dimension sizes). Since linalg was originally
+designed as a dialect for transformations, it can be too low-level for certain
+consumers.
+
+### The TOSA Backend
+
+Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToTosa
+
+The TOSA backend was the second backend that we added. It remains preferred by
+many users (especially "hardware" or "hardware-adjacent" folks). Some of its
+characteristics are:
+
+- It is tied to a [spec](https://www.mlplatform.org/tosa/tosa_spec.html) with
+  a really clear "ISA-like" expository style that resonates with a lot of
+  folks.
+- The coarse-grained named-op approach is a good match for the many compilers
+  that are designed that way.
+- It has really good support for quantization / integer data types.
+- It has clear versioning/stability guarantees on the op semantics.
+- It is extremely solid with static shapes (and many of its users only care
+  about static shapes, so that's fine).
+
+### The MHLO Backend
+
+Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToMhlo
+
+The MHLO backend was the third backend that we added, and it offers a
+reasonable blend of the benefits of the other two:
+
+- It is a coarse-grained named-op approach.
+- It has a pretty clear spec for most of the ops (with a bit of mental
+  translation and hoping that MHLO is the same as HLO):
+  https://www.tensorflow.org/xla/operation_semantics
+- It functionally supports dynamic shapes (though not as coherently and
+  consistently as Linalg-on-Tensors, and the dynamic shape support falls
+  outside the wonderful HLO docs above).
+- It appears to be pretty tied to HLO (which is highly mature), so most of the
+  op surface area doesn't change too much.
+- It has a different set of principles than TOSA, which tends to make it more
+  expressive at the cost of a larger abstraction gap from hardware. For
+  example, TOSA limits (for highly considered reasons) the number of
+  dimensions that certain operators can handle to 1D-4D, when from a purely
+  algebraic perspective there isn't a good reason not to be more general.
+  Similarly, more general forms of reduction and scatter also fit into MHLO
+  nicely, while TOSA's principles tend to bias it away from that.
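+
+From the user's perspective, selecting among these backends is a one-line
+change in the `torch_mlir.compile` call. The sketch below assumes the output
+type names accepted by that API as of this writing (they may differ), and
+`AddModule` is a made-up example module:
+
+```python
+import torch
+import torch_mlir
+
+class AddModule(torch.nn.Module):
+    def forward(self, a, b):
+        return a + b
+
+example_args = [torch.ones(2, 3), torch.ones(2, 3)]
+# The same backend contract IR can be lowered to any of the three backends.
+linalg_module = torch_mlir.compile(AddModule(), example_args,
+                                   output_type="linalg-on-tensors")
+tosa_module = torch_mlir.compile(AddModule(), example_args,
+                                 output_type="tosa")
+mhlo_module = torch_mlir.compile(AddModule(), example_args,
+                                 output_type="mhlo")
+```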
+
+### Backend Implementation
+
+All the backends are implemented using the MLIR
+[Dialect Conversion infrastructure](https://mlir.llvm.org/docs/DialectConversion/).
+This involves converting the `torch` dialect types to other types, so we
+closely follow the principles from the "Type Conversions the Not-So-Hard Way"
+talk
+([slides](https://drive.google.com/file/d/1FVbzCXxZzS9LBLuvpPNLWJD-XDkt54ky/view?usp=sharing),
+[recording](https://drive.google.com/file/d/1VfVajitgf8ZPnd-HRkJvaJiFLhBsluXN/view?usp=sharing)).
+We follow the standard `{include,lib}/Conversion/TorchTo*` convention used in
+MLIR for conversion passes.
+
+For type conversion, we provide
+[BackendTypeConversion.cpp](https://github.com/llvm/torch-mlir/blob/57681f794764a34c34e2be7f07f7dfbcafa683c1/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp#L1)
+and
+[BackendTypeConversionPasses.cpp](https://github.com/llvm/torch-mlir/blob/57681f794764a34c34e2be7f07f7dfbcafa683c1/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp#L1),
+which provide a default conversion from `torch` dialect types to the builtin
+`tensor` type and scalar integer/float types. These are not the right choice
+for all backends, but backends can copy and adapt them. These files closely
+follow the "Type Conversions the Not-So-Hard Way" talk.
+
+## Testing
+
+See
+[development.md](https://github.com/llvm/torch-mlir/blob/9c8b96272057f4f8210de5842b6952228434cfa2/development.md#testing)
+for more details on running tests.
+
+Torch-MLIR has two types of tests:
+
+1. End-to-end execution tests. These compile and run a program and check the
+   result against the expected output from execution on native Torch. These
+   use a homegrown testing framework (see
+   [framework.py](https://github.com/llvm/torch-mlir/blob/7d4a0d0e2b65c7ce8de19993f3b10ad5344fe32b/python/torch_mlir_e2e_test/torchscript/framework.py#L6)),
+   and the test suite lives at `python/torch_mlir_e2e_test/test_suite`. A
+   sketch of such a test appears below.
+
+2. Compiler and Python API unit tests. These use LLVM's `lit` testing
+   framework. For example, these might involve using `torch-mlir-opt` to run
+   a pass and check the output with `FileCheck`. `lit` is flexible enough to
+   unit test various Python pieces, importers, and LTC this way as well.
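+
+For illustration, here is a minimal sketch of an end-to-end test (the module
+paths and decorators assume the layout under `python/torch_mlir_e2e_test`
+linked above; details may differ, and `ElementwiseAddModule` is a made-up
+example):
+
+```python
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+class ElementwiseAddModule(torch.nn.Module):
+    @export
+    @annotate_args([
+        None,  # self
+        ([-1, -1], torch.float32, True),  # dynamic 2D float32 value tensor
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a, b):
+        return a + b
+
+@register_test_case(module_factory=lambda: ElementwiseAddModule())
+def ElementwiseAddModule_basic(module, tu: TestUtils):
+    # The framework runs this on native Torch and on the compiled artifact,
+    # then compares the results.
+    module.forward(tu.rand(3, 4), tu.rand(3, 4))
+```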
+
+### Why so much end-to-end testing?
+
+Torch-MLIR places a heavy emphasis on end-to-end testing for the following
+reasons:
+
+Reason 1: Even if a compiler pass produces the output IR that the author
+expected, that output IR may not correctly implement the semantics of the op.
+This is especially true for the complex, often poorly specified deep learning
+operators that Torch-MLIR is mainly concerned with. It is critical to run
+these against the source of truth to ensure a correct implementation.
+
+Reason 2: There are many patterns in Torch-MLIR's backends that really just
+expand one op into other ops without any real logic. When we started
+Torch-MLIR, we were very religious about always having `.mlir` unit tests,
+even for these "macro expansion" patterns, but we found that these tests
+1) never caught a bug and 2) interfered with refactoring / caused spurious
+extra work (changing op syntax, etc.). There is not much point to having a
+bunch of tests like this, which basically just rewrite the builder calls in
+a different syntax:
+
+```
+// MyPass.cpp
+b.create<FooOp>(...)
+b.create<BarOp>(...)
+
+// test.mlir
+// CHECK: foo
+// CHECK: bar
+```
+
+Such a test is simply checking that the implementation of an op is the way it
+is. There is no way to change the implementation and still have the test
+pass. So the test is fully redundant with the implementation.
+
+Because of this, many Torch-MLIR patches adding support for new ops have no
+`.mlir` unit tests, and only include end-to-end test(s). We generally make
+sure that our end-to-end tests are as targeted as possible. As a result, when
+debugging end-to-end test failures, the resulting reproducers (which our test
+framework automatically produces for failures) are usually already fully
+reduced test cases.
+
+### Do's and Don'ts for unit vs end-to-end testing
+
+DO use an end-to-end test if you are implementing a new op or extending the
+support for an existing op.
+
+DO use a unit test if your lowering for an op has multiple cases / logic.
+This also helps future maintainers of the lowering to see in one place all
+the different edge cases of the op that you had to handle (these can easily
+be reduced out of all the end-to-end tests you added).
+
+DON'T use a unit test if your lowering pattern could be described as a
+trivial "macro expansion" of one op into another op or set of ops. That is,
+if you feel like your unit test is just rewriting `b.create<...>(...)` into
+`CHECK: ...`, then it is probably not a useful unit test.
+
+DON'T add a unit test for trivial changes to RefineTypes.
+
+With the exceptions above, all changes should include appropriate unit tests,
+as is standard in the LLVM and MLIR community. This includes full coverage of
+all canonicalizations, pretty printing, passes, errors, and diagnostics.
+
+### The RefBackend (Reference Backend)
+
+In order to run end-to-end tests, Torch-MLIR needs an end-to-end flow.
+Thankfully, upstream MLIR has just enough pieces to precariously assemble one
+that suffices for testing.
+
+The RefBackend consists of a few minor
+[C++ passes](https://github.com/llvm/torch-mlir/blob/114f48e96c578ee76a6f83b3aa4aa229a8d5b76e/include/torch-mlir/RefBackend/Passes.td#L1)
+filling in some corners missing upstream and
+[Python glue logic](https://github.com/llvm/torch-mlir/blob/114f48e96c578ee76a6f83b3aa4aa229a8d5b76e/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py#L1)
+to pull together upstream functionality into a working system.
+
+The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes
+the ops and lowers them to loops. Note that TOSA and MHLO support lowering to
+Linalg-on-Tensors, so all our end-to-end testing bottoms out on the
+RefBackend.
+
+The RefBackend is absolutely not suitable for any production use case. It
+leaks memory, doesn't support any error handling, performs no optimizations,
+and probably does a bunch of other horrible things. We are patiently waiting
+for the upstream MLIR community to produce a viable end-to-end flow with
+better characteristics.
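+
+For reference, a minimal sketch of driving the RefBackend from Python
+(assuming the class and module names in the glue logic linked above; the
+exact API may differ, and `AddModule` is a made-up example):
+
+```python
+import torch
+import torch_mlir
+from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
+    RefBackendLinalgOnTensorsBackend,
+)
+
+class AddModule(torch.nn.Module):
+    def forward(self, a, b):
+        return a + b
+
+# Lower to Linalg-on-Tensors, then compile and load with the RefBackend.
+module = torch_mlir.compile(AddModule(),
+                            [torch.ones(2, 3), torch.ones(2, 3)],
+                            output_type="linalg-on-tensors")
+backend = RefBackendLinalgOnTensorsBackend()
+loaded = backend.load(backend.compile(module))
+# The loaded artifact exposes the module's methods, taking numpy arrays.
+print(loaded.forward(torch.ones(2, 3).numpy(), torch.ones(2, 3).numpy()))
+```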
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index aa820a320..ae9dd4249 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -78,11 +78,8 @@ class AnyTorchTensorType If the type is `!torch.vtensor` then the tensor is restricted to operations that have value semantics ("v" = "value semantics"). This helps to maintain a strict separation between the value-semantic and potentially-mutating - worlds, as one of our main jobs in the compiler is to isolate the mutating - parts as much as possible because most lower levels of the compiler stack - are expected to require value semantics. E.g. many backend contracts - mostly use linalg-on-tensor for compute-heavy ops, which require - a conversion to the builtin `tensor` type which has value semantics. + worlds, as one of our main jobs in Torch-MLIR is to normalize the program + into a form with value semantics. Some notes about value semantics: - Using the type system described in PEP 483 (which TorchScript and other Python systems follow), `!torch.tensor` is a subtype of