# Adding Abstract Interpretation Functions

## Overview

As part of adding support for a Torch operator in Torch-MLIR, it is usually
necessary to define a shape and dtype function so that the compiler can infer
the shapes and dtypes of result tensors for the operator. We use the
[abstract interpretation library](abstract_interp_lib.md) for this process.

## Step-by-step guide

We will use the example of adding support for the `torch.aten.tanh` op.

1. First, you need to find the shape and dtype function signatures for
   the operator you are implementing functions for. These can be
   found in
   `include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt`,
   generated by the `build_tools/update_torch_ods.sh` script. That
   file is the "rosetta stone" that allows translating between
   e.g. `torch.aten.tanh`, `AtenTanhOp`, and the shape and dtype
   function signatures, which for this op are:

   - `def aten〇tanh〡shape(self: List[int]) -> List[int]:`
   - `def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:`

   Note the use of `〇` as a separator since `.` or `::` aren't legal
   in a Python identifier.

2. Paste the function signatures into `abstract_interp_lib_gen.py` in an
   appropriate place (ideally near other, similar functions). Note that
   `abstract_interp_lib_gen.py` will check that these signatures are
   verbatim identical with the ones given in
   `JITOperatorRegistryDump.txt` -- this ensures that the functions
   don't get outdated if Torch changes an operator signature.

3. Fill in the body of the function. Ideally this will just be a call
   into a helper function from
   [`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1).
   But in general, you will need to write the function and test it
   (see the comments about "Shape, dtype, and decomposition function
   testing infrastructure" in `testing_framework.py`). New shape
   functions should be added upstream following the example of [this PR](https://github.com/pytorch/pytorch/pull/76889),
   though it can be useful to iterate locally in `abstract_interp_lib_gen.py`
   first.

   Similarly, dtype functions should ideally just be a call to the helper
   `promote_dtypes` defined in `library_generator.py`. However, some ops will
   require some extra logic to calculate the right result types. While dtypes
   are expressed as `int`s in the arguments of the dtype function, using PyTorch
   dtypes, such as `torch.int` and `torch.float32`, in the body of the dtype
   function is fully supported. Dtype functions are also expected to be fully
   tested.

4. Re-run the `build_tools/update_abstract_interp_lib.sh` script to
   update the library. After this step happens, ideally everything
   "just works" and the shapes and dtypes are now correctly inferred for the
   operator.

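To make the steps above concrete, here is a minimal, self-contained sketch of what the two `tanh` functions might look like. It is an illustration only, not the code in `abstract_interp_lib_gen.py`: `mangle` and `unary` are hypothetical stand-ins defined locally, the integer dtype codes are assumed to follow PyTorch's `ScalarType` numbering, and the promotion rule (non-floating inputs become `float32`) is an assumption about `tanh` that ignores complex dtypes.

```python
from typing import List, Tuple

# Integer dtype codes, assumed here to follow PyTorch's ScalarType numbering.
INT64 = 4
FLOAT16 = 5
FLOAT32 = 6
FLOAT64 = 7
BFLOAT16 = 15
_FLOAT_DTYPES = (FLOAT16, FLOAT32, FLOAT64, BFLOAT16)


def mangle(op_name: str, kind: str) -> str:
    """Hypothetical helper: ('aten::tanh', 'shape') -> 'aten〇tanh〡shape'."""
    return op_name.replace("::", "〇").replace(".", "〇") + "〡" + kind


def unary(self: List[int]) -> List[int]:
    """Local stand-in for an elementwise helper: the result copies the input shape."""
    return list(self)


def aten〇tanh〡shape(self: List[int]) -> List[int]:
    # tanh is elementwise, so the result shape equals the operand shape.
    return unary(self)


def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    # Assumed rule: floating-point inputs keep their dtype, others become float32.
    if self_dtype in _FLOAT_DTYPES:
        return self_dtype
    return FLOAT32
```

For example, `aten〇tanh〡shape([2, 3])` returns a copy of the input shape, and `aten〇tanh〡dtype((2, INT64))` returns the `float32` code. In the real library the bodies would normally delegate to upstream helpers, and the signatures must match `JITOperatorRegistryDump.txt` verbatim.
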
## When things go wrong

It is possible that the refinement pipeline (see [Shape and Dtype Refinement Pipeline Architecture](abstract_interp_lib.md#shape-and-dtype-refinement-pipeline-architecture))
is not able to infer the shape or dtype of a tensor with a given
abstract interpretation function. This usually means that there is something
about the function which the optimizations in
`torch-simplify-shape-functions` and `torch-simplify-dtype-functions`
(`lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp`,
`lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp`)
cannot handle.

To debug this, the overall goal is to pinpoint the IR construct that is not
being simplified. This is usually accomplished by a combination of looking at
the Python code for the function and the IR dumps. The best IR dump to look at
varies, but frequently the IR dump right before `DropAbstractInterpCalculations`
is the most useful, because it has already been simplified as much as possible,
making it easy to see what is blocking further simplification. Examples of
issues you might see:

- You might find that there is a loop with a non-constant trip count,
  but based on your understanding of the function, you would expect it
  to be simplified to a constant trip count -- you can then look at
  the trip count calculation and see if there is a missing fold or
  canonicalization.

- You might find that there is a list operation that is not currently understood
  by the optimizations. You can then teach the optimizations about that
  operation.

- You might find that there is an `Optional` value that you would
  expect to be resolved to either a concrete value or `None`. You can
  then look at the calculation that produces the optional value and
  see what folds or canonicalizations are missing.

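As a rough illustration of where such IR can come from, the sketch below shows a hypothetical shape function (`pad_to_rank` is an invented name, not an operator in the library) that contains both an `Optional` argument and a loop whose trip count depends on other values. The comments mark the points that the simplification passes would presumably need to resolve. It does not show actual IR output.

```python
from typing import List, Optional


def pad_to_rank(self: List[int], rank: Optional[int]) -> List[int]:
    """Hypothetical shape function: left-pad `self` with 1s up to `rank`."""
    # An Optional that has to resolve to a concrete value or None
    # before anything downstream can be simplified.
    if rank is None:
        rank = len(self)
    out: List[int] = []
    # The trip count is `rank - len(self)`: it is constant only once both
    # the optional value and the rank of `self` are known.
    for _ in range(rank - len(self)):
        out.append(1)
    for dim in self:
        out.append(dim)
    return out
```

If the result of a function like this stays unrefined, the subtraction feeding the first loop and the `is None` check are natural places to start looking in the IR dump.
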
See [this video](https://www.youtube.com/watch?v=E5epCJOtrf8) for general
guidance on debugging Torch-MLIR.

As a last resort, you can rewrite the function using constructs that
`torch-simplify-shape-functions` and `torch-simplify-dtype-functions` can handle
(look at other functions for examples; sometimes this requires writing things a
little awkwardly).

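The following sketch shows what such a rewrite might look like in spirit. Both functions are hypothetical and compute the same result for `0 <= dim < len(self)`. Which form the passes actually simplify is an assumption not verified here; the point is only that a compact expression can often be restated with simpler list operations.

```python
from typing import List


def drop_dim_sliced(self: List[int], dim: int) -> List[int]:
    """Compact form: relies on list slicing and concatenation."""
    return self[:dim] + self[dim + 1:]


def drop_dim_explicit(self: List[int], dim: int) -> List[int]:
    """Same result using only len, indexing, comparison, and append."""
    out: List[int] = []
    for i in range(len(self)):
        if i != dim:
            out.append(self[i])
    return out
```

The explicit version is longer, but every step is a basic operation, and the loop's trip count is `len(self)`, which is known as soon as the rank of the input is known.
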