# Adding a Shape Function

## Overview

As part of adding support for a Torch operator in Torch-MLIR, it is usually
necessary to define a shape function so that the compiler can infer the shapes
of result tensors for the operator. We use the [shape library](shape_lib.md) for this process.
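To make the idea concrete: a shape function is ordinary Python, written in the TorchScript-compatible subset, that receives the operand shapes as `List[int]` values and returns the result shape. The function below is a hypothetical, standalone sketch for a 2-D matrix multiply. Its name is made up for illustration, and it is not the shape function Torch-MLIR actually uses for any op.

```python
from typing import List


def hypothetical_mm_shape(self: List[int], mat2: List[int]) -> List[int]:
    """Result shape of a 2-D matrix multiply: [n, k] x [k, m] -> [n, m]."""
    assert len(self) == 2, "self must be 2-D"
    assert len(mat2) == 2, "mat2 must be 2-D"
    assert self[1] == mat2[0], "inner dimensions must match"
    return [self[0], mat2[1]]


# A [3, 4] tensor times a [4, 5] tensor yields a [3, 5] tensor.
print(hypothetical_mm_shape([3, 4], [4, 5]))  # [3, 5]
```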

## Step-by-step guide

We will use the example of adding support for the `torch.aten.tanh` op.

1. First, find the shape function signature for the operator. It can be found
in `include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt`, which is
generated by the `build_tools/update_torch_ods.sh` script. That file is the
"rosetta stone" that allows translating between e.g. `torch.aten.tanh`,
`AtenTanhOp`, and the shape function signature
`def aten〇tanh(self: List[int]) -> List[int]:`. Note the use of `〇` as a
separator, since `.` and `::` aren't legal in a Python identifier.
2. Paste the shape function signature into `shape_lib_gen.py` in an appropriate
place (ideally near other functions with a similar shape function). Note that
`shape_lib_gen.py` will check that this signature is verbatim identical with
the one given in `JITOperatorRegistryDump.txt` -- this ensures that the shape
functions don't get outdated if Torch changes an operator signature.
3. Fill in the body of the shape function. Ideally this will just be a call into
a helper function from `upstream_shape_helpers.py`. But in general, you will
need to write the shape function and test it (see the comments about "Shape
function testing infrastructure" in `shape_lib_gen.py`).
4. Re-run the `build_tools/update_shape_lib.sh` script to update the shape
library. After this step, ideally everything "just works" and the shape is
now correctly inferred for the operator.
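The steps above can be sketched in plain Python. This is a standalone illustration under a few assumptions, not code from `shape_lib_gen.py`: `shape_function_name` is a hypothetical helper that reproduces the naming scheme from step 1 (drop the `torch.` prefix, join the remaining components with `〇`), and `unary` is a local stand-in for an elementwise helper. The real helpers live in `upstream_shape_helpers.py`, so check their names and signatures there. `hypothetical_broadcast` shows what a hand-written shape function might look like when no helper fits, and the trailing `assert`s are only a quick local check, not the testing infrastructure that `shape_lib_gen.py` provides.

```python
from typing import List


def shape_function_name(op_name: str) -> str:
    """Hypothetical helper: `torch.aten.tanh` -> `aten〇tanh`.

    Drops the `torch.` prefix and joins the remaining components with `〇`,
    because `.` is not legal inside a Python identifier.
    """
    prefix = "torch."
    if op_name.startswith(prefix):
        op_name = op_name[len(prefix):]
    return "〇".join(op_name.split("."))


def unary(self: List[int]) -> List[int]:
    """Stand-in elementwise helper: the result shape equals the input shape."""
    out: List[int] = []
    for elem in self:
        out.append(elem)
    return out


def aten〇tanh(self: List[int]) -> List[int]:
    # tanh is elementwise, so the shape function just forwards to the helper.
    return unary(self)


def hypothetical_broadcast(a: List[int], b: List[int]) -> List[int]:
    """Hand-written NumPy-style broadcast of two shapes, aligned from the right."""
    rank = max(len(a), len(b))
    out: List[int] = []
    for i in range(rank):
        # Distance of output dimension i from the last dimension.
        offset = rank - 1 - i
        # A dimension missing from the shorter shape behaves like size 1.
        dim_a = a[len(a) - 1 - offset] if offset < len(a) else 1
        dim_b = b[len(b) - 1 - offset] if offset < len(b) else 1
        assert dim_a == dim_b or dim_a == 1 or dim_b == 1, "shapes do not broadcast"
        out.append(dim_b if dim_a == 1 else dim_a)
    return out


# Quick local sanity checks, not the shape library's testing infrastructure.
assert aten〇tanh([2, 3, 4]) == [2, 3, 4]
assert hypothetical_broadcast([3, 1], [4]) == [3, 4]
assert hypothetical_broadcast([2, 3, 1], [3, 5]) == [2, 3, 5]

print(shape_function_name("torch.aten.tanh"))  # aten〇tanh
```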

## When things go wrong

It is possible that the shape refinement pipeline (see
[Shape Refinement Pipeline Architecture](shape_lib.md#shape-refinement-pipeline-architecture))
is not able to infer the shape of a tensor with a given shape function. This
usually means that there is something about the shape function which the
optimizations in `torch-simplify-shape-functions`
(`lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp`) cannot handle.

To debug this, the overall goal is to pinpoint the IR construct that is not
being simplified. This is usually accomplished by a combination of looking at
the Python code for the shape function and the IR dumps. The best IR dump to
look at varies, but frequently the IR dump right before `DropShapeCalculations`
is the most useful, because it has already been simplified as much as possible,
making it easy to see what is blocking further simplification. Examples of
issues you might see:

- You might find that there is a loop with a non-constant trip count, but based
on your understanding of the shape function, you would expect it to be
simplified to a constant trip count -- you can then look at the trip count
calculation and see if there is a missing fold or canonicalization.
- You might find that there is a list operation that is not currently understood
by the optimizations. You can then teach the optimizations about that
operation.
- You might find that there is an `Optional` value that you would expect to be
resolved to either a concrete value or `None`. You can then look at the calculation that produces the optional value and see what folds or canonicalizations are missing.
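When reading the IR dump, it can help to know which values in the Python source the simplifications could, in principle, resolve. In the hypothetical reduction-style shape function below, the loop's trip count is `len(self)`, the rank of the input, and `dim` is an `Optional[int]`. Assuming the input rank is known, the trip count should become a constant, and if `dim` is a constant `None`, the `dim is None` test should fold away. If either is still unresolved in the IR right before `DropShapeCalculations`, the computation feeding it is the place to look for a missing fold or canonicalization. The function assumes `dim` is already non-negative.

```python
from typing import List, Optional


def hypothetical_reduce_shape(
    self: List[int], dim: Optional[int], keepdim: bool
) -> List[int]:
    """Shape after reducing over `dim`, or over every dimension if `dim` is None.

    Assumes `dim`, when given, is already non-negative.
    """
    out: List[int] = []
    # Trip count is len(self), the rank: a constant whenever the rank is known.
    for i in range(len(self)):
        # If `dim` is a constant None, `dim is None` is true on every iteration.
        if dim is None or i == dim:
            # A reduced dimension is dropped, or kept as size 1 with keepdim.
            if keepdim:
                out.append(1)
        else:
            out.append(self[i])
    return out


print(hypothetical_reduce_shape([2, 3, 4], 1, keepdim=False))  # [2, 4]
print(hypothetical_reduce_shape([2, 3, 4], None, keepdim=True))  # [1, 1, 1]
```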

See [this video](https://www.youtube.com/watch?v=E5epCJOtrf8) for general
guidance on debugging Torch-MLIR.

As a last resort, you can rewrite the shape function using constructs that
`torch-simplify-shape-functions` can handle (look at other shape functions for
examples; sometimes this requires writing things a little awkwardly).
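As an illustration of "writing things a little awkwardly", here are two behaviorally equivalent, hypothetical ways to compute the shape left after removing one dimension. The first is compact and relies on list slicing and concatenation. The second spells out the same result with a loop bounded by the rank, indexing, and `append`, a style that many existing shape functions tend to follow. Which formulation `torch-simplify-shape-functions` handles better is not claimed here; the IR dump is how to find out for a particular case.

```python
from typing import List


def drop_dim_sliced(self: List[int], dim: int) -> List[int]:
    """Compact form: normalize `dim`, then slice around it."""
    if dim < 0:
        dim += len(self)
    return self[:dim] + self[dim + 1:]


def drop_dim_loop(self: List[int], dim: int) -> List[int]:
    """Same result using only a rank-bounded loop, indexing, and `append`."""
    if dim < 0:
        dim += len(self)
    out: List[int] = []
    for i in range(len(self)):
        if i != dim:
            out.append(self[i])
    return out


print(drop_dim_sliced([2, 3, 4], -1))  # [2, 3]
print(drop_dim_loop([2, 3, 4], -1))    # [2, 3]
```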