# Adding a Shape Function

## Overview

As part of adding support for a Torch operator in Torch-MLIR, it is usually
necessary to define a shape function so that the compiler can infer the shapes
of result tensors for the operator. We use the [shape library](shape_lib.md)
for this process.
## Step-by-step guide

We will use the example of adding support for the `torch.aten.tanh` op.
1. First, you need to find the shape function signature for the operator you
   are implementing. It can be found in
   `include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt`, which is
   generated by the `build_tools/update_torch_ods.sh` script. That file is the
   "Rosetta Stone" that allows translating between e.g. `torch.aten.tanh`,
   `AtenTanhOp`, and the shape function signature
   `def aten〇tanh(self: List[int]) -> List[int]:`. Note the use of `〇` as a
   separator, since `.` and `::` aren't legal in a Python identifier.

2. Paste the shape function signature into `shape_lib_gen.py` in an appropriate
   place (ideally near operators with similar shape functions). Note that
   `shape_lib_gen.py` will check that this signature is verbatim identical to
   the one given in `JITOperatorRegistryDump.txt` -- this ensures that the
   shape functions don't become outdated if Torch changes an operator
   signature.

3. Fill in the body of the shape function. Ideally this will just be a call
   into a helper function from `upstream_shape_helpers.py`, but in general you
   will need to write the shape function and test it (see the comments about
   "Shape function testing infrastructure" in `shape_lib_gen.py`). A minimal
   sketch for `torch.aten.tanh` is shown after this list.

4. Re-run the `build_tools/update_shape_lib.sh` script to update the shape
   library. After this step, ideally everything "just works" and the shape is
   now correctly inferred for the operator.
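To make steps 2 and 3 concrete, here is a minimal sketch of what the
`torch.aten.tanh` entry in `shape_lib_gen.py` could look like. The signature is
copied verbatim from `JITOperatorRegistryDump.txt`; the body below is written
by hand for clarity, and ideally it would instead be a single call into the
matching helper from `upstream_shape_helpers.py` (check that file for the exact
helper name).

```python
from typing import List  # already imported in shape_lib_gen.py

# Signature copied verbatim from JITOperatorRegistryDump.txt. `tanh` is
# elementwise, so the result shape is simply a copy of the input shape.
# Ideally this body would be one call into a helper from
# upstream_shape_helpers.py (e.g. a unary/elementwise-style helper, if one
# exists there).
def aten〇tanh(self: List[int]) -> List[int]:
    out: List[int] = []
    for dim in self:
        out.append(dim)
    return out
```

Keeping the body this simple also makes it easy for the shape refinement
pipeline, discussed in the next section, to simplify it away.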
## When things go wrong

It is possible that the shape refinement pipeline (see
[Shape Refinement Pipeline Architecture](shape_lib.md#shape-refinement-pipeline-architecture))
is not able to infer the shape of a tensor with a given shape function. This
usually means that there is something about the shape function that the
optimizations in `torch-simplify-shape-calculations`
(`lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp`) cannot handle.
To debug this, the overall goal is to pinpoint the IR construct that is not
being simplified. This is usually accomplished by a combination of looking at
the Python code for the shape function and at the IR dumps. The best IR dump to
look at varies, but frequently the IR dump right before `DropShapeCalculations`
is the most useful, because it has already been simplified as much as possible,
making it easy to see what is blocking further simplification. Examples of
issues you might see:
- You might find that there is a loop with a non-constant trip count, but based
  on your understanding of the shape function, you would expect it to be
  simplified to a constant trip count -- you can then look at the trip count
  calculation and see if there is a missing fold or canonicalization.

- You might find that there is a list operation that is not currently
  understood by the optimizations. You can then teach the optimizations about
  that operation.

- You might find that there is an `Optional` value that you would expect to be
  resolved to either a concrete value or `None`. You can then look at the
  calculation that produces the optional value and see what folds or
  canonicalizations are missing. (For concreteness, a sketch of a shape
  function that takes such an `Optional` argument follows this list.)
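For concreteness, here is a hypothetical shape function that takes an
`Optional` argument (the op name and signature are invented for illustration
and do not correspond to a real Torch operator). Writing the `Optional`
handling as an explicit `is None` branch gives the simplifications a clear
construct to resolve once the actual argument at the call site is known.

```python
from typing import List, Optional

# Hypothetical reduction-like op, for illustration only: `dim` is None to
# reduce over all dimensions, or the index of the dimension to drop.
def example〇reduce(self: List[int], dim: Optional[int] = None) -> List[int]:
    if dim is None:
        # Reducing over everything yields a rank-0 (scalar) shape.
        return []
    out: List[int] = []
    for i in range(len(self)):
        if i != dim:
            out.append(self[i])
    return out
```

If an `Optional` like this is not getting resolved in the IR, the fix is
usually in the compiler (a missing fold or canonicalization on the calculation
producing the value) rather than in the shape function itself.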
See [this video](https://www.youtube.com/watch?v=E5epCJOtrf8) for general
guidance on debugging Torch-MLIR.
As a last resort, you can rewrite the shape function using constructs that
`torch-simplify-shape-calculations` can handle (look at other shape functions
for examples; sometimes this requires writing things a little awkwardly). A
hypothetical before/after sketch of such a rewrite is shown below.
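As a purely hypothetical illustration of such a rewrite (the op name is
invented, and exactly which constructs the simplifications handle is something
to verify against the current code), the sketch below contrasts a loop whose
trip count is threaded through a mutable counter with one that iterates
directly over the input rank. Once the rank is known, the second form has a
constant trip count, which is the kind of construct the simplifications are
expected to unfold.

```python
from typing import List

# Before: hypothetical shape function written with a while-loop and a mutable
# counter. The trip count is only implied by the loop condition, which can be
# harder for the simplifications to reduce to a constant.
def example〇copy_shape(self: List[int]) -> List[int]:
    out: List[int] = []
    i = 0
    while i < len(self):
        out.append(self[i])
        i = i + 1
    return out

# After: the same computation written as a for-loop over range(len(self)).
# Once the input rank is known, the trip count is a constant.
def example〇copy_shape_rewritten(self: List[int]) -> List[int]:
    out: List[int] = []
    for i in range(len(self)):
        out.append(self[i])
    return out
```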