# Adding Abstract Interpretation Functions

## Overview

As part of adding support for a Torch operator in Torch-MLIR, it is usually
necessary to define a shape and dtype function so that the compiler can infer
the shapes and dtypes of result tensors for the operator. We use the
[abstract interpretation library](abstract_interp_lib.md) for this process.

## Step-by-step guide

We will use the example of adding support for the `torch.aten.tanh` op.

1. First, you need to find the shape and dtype function signatures for
   the operator you are implementing functions for. These can be found in
   `include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt`,
   which is generated by the `build_tools/update_torch_ods.sh` script. That
   file is the "rosetta stone" that allows translating between, e.g.,
   `torch.aten.tanh`, `AtenTanhOp`, and the shape and dtype function
   signatures. For `torch.aten.tanh`, the signatures are:

   - `def aten〇tanh〡shape(self: List[int]) -> List[int]:`
   - `def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:`

   Note the use of `〇` as a separator, since `.` and `::` are not legal
   in a Python identifier.

2. Paste the function signatures into `abstract_interp_lib_gen.py` in an
   appropriate place (ideally near the functions for similar operators).
   Note that `abstract_interp_lib_gen.py` will check that these signatures
   are verbatim identical to the ones given in `JITOperatorRegistryDump.txt`;
   this ensures that the functions don't get outdated if Torch changes an
   operator signature.

3. Fill in the body of the function. Ideally this will just be a call
   into a helper function from
   [`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1).
   In general, however, you will need to write the function yourself and
   test it (see the comments about "Shape, dtype, and decomposition function
   testing infrastructure" in `testing_framework.py`). New shape functions
   should be added upstream following the example of
   [this PR](https://github.com/pytorch/pytorch/pull/76889), though it can be
   useful to iterate locally in `abstract_interp_lib_gen.py` first.

   Similarly, dtype functions should ideally just be a call to the helper
   `promote_dtypes` defined in `library_generator.py`. However, some ops
   require extra logic to calculate the right result type. While dtypes are
   expressed as `int`s in the arguments of the dtype function, using PyTorch
   dtypes such as `torch.int` and `torch.float32` in the body of the dtype
   function is fully supported. Dtype functions are also expected to be
   fully tested.

4. Re-run the `build_tools/update_abstract_interp_lib.sh` script to
   update the library. After this step, ideally everything "just works"
   and the shapes and dtypes of the operator's results are now correctly
   inferred.

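
The naming scheme in step 1 is mechanical enough to express as code. The helper below is hypothetical and is not part of Torch-MLIR. It applies the rule from the note in step 1: `::` and `.` both become `〇`, and `〡` attaches the function kind, as in `aten〇tanh〡shape` and `aten〇tanh〡dtype`. Treating an overload name such as `add.Tensor` the same way is our assumption, based on that note.

```python
def abstract_interp_fn_name(op_name: str, kind: str) -> str:
    """Hypothetical helper: map a JIT op name to an abstract interp function name.

    op_name: a registry name such as "aten::tanh" or "aten::add.Tensor".
    kind:    "shape" or "dtype".
    """
    if kind not in ("shape", "dtype"):
        raise ValueError(f"unexpected function kind: {kind!r}")
    # Neither "::" nor "." may appear in a Python identifier, so both become "〇".
    base = op_name.replace("::", "〇").replace(".", "〇")
    # "〡" separates the operator name from the kind of function.
    return f"{base}〡{kind}"


name = abstract_interp_fn_name("aten::tanh", "shape")
print(name)                 # aten〇tanh〡shape
print(name.isidentifier())  # True
```

The `isidentifier()` check confirms the point of the separator: both `〇` and `〡` are legal identifier characters, while `.` and `::` are not.
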
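
Step 2 says that `abstract_interp_lib_gen.py` compares each pasted signature verbatim against `JITOperatorRegistryDump.txt`. We don't reproduce the library's actual implementation here. The sketch below only shows one way such a check could be written with the standard-library `inspect` module. It assumes the expected text has exactly the `def name(args) -> ret:` form shown in step 1. The names `render_signature` and `check_signature` are hypothetical.

```python
import inspect
from typing import List


def render_signature(fn) -> str:
    """Render fn's signature as 'def name(args) -> ret:' (hypothetical helper)."""
    return f"def {fn.__name__}{inspect.signature(fn)}:"


def check_signature(fn, expected: str) -> None:
    """Raise ValueError if fn's rendered signature differs from the expected text."""
    actual = render_signature(fn)
    if actual != expected:
        raise ValueError(f"signature mismatch:\n  expected: {expected}\n  actual:   {actual}")


# Placeholder bodies: only the signatures matter for this check.
def aten〇tanh〡shape(self: List[int]) -> List[int]:
    return list(self)


def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:
    return self_dtype


check_signature(aten〇tanh〡shape, "def aten〇tanh〡shape(self: List[int]) -> List[int]:")
check_signature(aten〇tanh〡dtype, "def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:")
```

Because the comparison is on the full rendered text, a renamed parameter, a changed annotation, or a new default value all cause a mismatch. That is what keeps a function from silently drifting out of date.
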
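
To make step 3 concrete, here is a self-contained, plain-Python sketch of what the two `tanh` functions compute. It imports neither PyTorch nor Torch-MLIR. `_copy_shape` is a hypothetical stand-in for the upstream helper that the real shape function would call. The integer dtype codes are illustrative constants; we chose them to follow PyTorch's `ScalarType` numbering, but treat that as an assumption. In `abstract_interp_lib_gen.py` itself you would write `torch.float32` and similar, as step 3 notes.

The dtype rule encodes PyTorch's behavior for `tanh`:

- floating-point and complex inputs keep their dtype;
- integer and boolean inputs produce the default floating-point dtype, assumed here to be `float32`.

```python
from typing import List

# Illustrative dtype codes (assumed to follow PyTorch's ScalarType numbering).
UINT8, INT8, INT16, INT32, INT64 = 0, 1, 2, 3, 4
FLOAT16, FLOAT32, FLOAT64 = 5, 6, 7
COMPLEX32, COMPLEX64, COMPLEX128 = 8, 9, 10
BOOL = 11
BFLOAT16 = 15

FLOATING_DTYPES: List[int] = [FLOAT16, BFLOAT16, FLOAT32, FLOAT64]
COMPLEX_DTYPES: List[int] = [COMPLEX32, COMPLEX64, COMPLEX128]


def _copy_shape(self: List[int]) -> List[int]:
    """Hypothetical stand-in for an upstream helper: return a copy of the shape."""
    out: List[int] = []
    for size in self:
        out.append(size)
    return out


def aten〇tanh〡shape(self: List[int]) -> List[int]:
    # tanh is elementwise, so the result has the same shape as the input.
    return _copy_shape(self)


def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:
    # self_rank is not needed for this op.
    # Floating-point and complex inputs keep their dtype.
    if self_dtype in FLOATING_DTYPES or self_dtype in COMPLEX_DTYPES:
        return self_dtype
    # Integer and bool inputs produce the default float dtype (assumed float32).
    return FLOAT32


print(aten〇tanh〡shape([2, 3]))               # [2, 3]
print(aten〇tanh〡dtype(2, INT64) == FLOAT32)  # True
```

Neither function touches tensor data. The shape function works on a list of sizes and the dtype function on integer codes, so both can be evaluated without running the operator.
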
## When things go wrong

It is possible that the refinement pipeline (see [Shape and Dtype Refinement Pipeline Architecture](abstract_interp_lib.md#shape-and-dtype-refinement-pipeline-architecture))
is not able to infer the shape or dtype of a tensor from a given
abstract interpretation function. This usually means that there is something
about the function that the optimizations in
`torch-simplify-shape-functions` and `torch-simplify-dtype-functions`
(`lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp`,
`lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp`)
cannot handle.

To debug this, the overall goal is to pinpoint the IR construct that is not
being simplified. This is usually accomplished by a combination of looking at
the Python code for the function and at the IR dumps. The best IR dump to look at
varies, but frequently the IR dump right before `DropAbstractInterpCalculations`
is the most useful, because it has already been simplified as much as possible,
making it easy to see what is blocking further simplification. Examples of
issues you might see:

- You might find that there is a loop with a non-constant trip count,
  but based on your understanding of the function, you would expect it
  to be simplified to a constant trip count. You can then look at
  the trip count calculation and see if there is a missing fold or
  canonicalization.

- You might find that there is a list operation that is not currently understood
  by the optimizations. You can then teach the optimizations about that
  operation.

- You might find that there is an `Optional` value that you would
  expect to be resolved to either a concrete value or `None`. You can
  then look at the calculation that produces the optional value and
  see what folds or canonicalizations are missing.

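
When reading an IR dump, it helps to know which parts of the Python function become the loops and `Optional` values described above. The reduction shape function below is hypothetical and is not taken from the library.

- The `for` loop's trip count is `len(self)`, the input's rank, so it can only become a constant once that rank is known.
- The `dim is None` test can only be resolved if `dim` is known to be either `None` or an `int` where the function is called.

If either construct survives in the dump, the calculation feeding it is where to look for a missing fold or canonicalization.

```python
from typing import List, Optional


def reduce_shape_sketch(self: List[int], dim: Optional[int], keepdim: bool) -> List[int]:
    """Hypothetical reduction shape function: reduce over `dim`, or over all dims if None."""
    rank = len(self)
    out: List[int] = []
    # Trip count is the rank: a constant only when the input's rank is known.
    for i in range(rank):
        # This branch can only be simplified if `dim` is known to be None or an int.
        if dim is None:
            reduced = True
        else:
            d = dim + rank if dim < 0 else dim  # normalize a negative dim
            reduced = i == d
        if not reduced:
            out.append(self[i])
        elif keepdim:
            out.append(1)
    return out


print(reduce_shape_sketch([2, 3, 4], 1, False))     # [2, 4]
print(reduce_shape_sketch([2, 3, 4], None, True))   # [1, 1, 1]
```
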
See [this video](https://www.youtube.com/watch?v=E5epCJOtrf8) for general
guidance on debugging Torch-MLIR.

As a last resort, you can rewrite the function using constructs that
`torch-simplify-shape-functions` and `torch-simplify-dtype-functions` can handle
(look at other functions for examples; sometimes this requires writing things a
little awkwardly).
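
As an illustration of such a last-resort rewrite, suppose (hypothetically) that the IR dump showed a list slice or list concatenation that the simplification passes did not handle. The two functions below both drop dimension `dim` from a shape and return the same result. The second one uses only `len`, indexing, a counted loop, and `append`, an explicit style that resembles many upstream shape functions. Which constructs are actually supported depends on the current state of the passes, so check the IR dump rather than relying on this example. Both function names are hypothetical, and `dim` is assumed to be non-negative and in range.

```python
from typing import List


def drop_dim_with_slices(self: List[int], dim: int) -> List[int]:
    # Concise, but relies on list slicing and list concatenation.
    return self[:dim] + self[dim + 1:]


def drop_dim_with_loop(self: List[int], dim: int) -> List[int]:
    # Same result using only len(), indexing, a counted loop, and append.
    out: List[int] = []
    for i in range(len(self)):
        if i != dim:
            out.append(self[i])
    return out


shape = [2, 3, 4]
for d in range(len(shape)):
    # The rewrite must not change behavior, only the constructs used.
    assert drop_dim_with_slices(shape, d) == drop_dim_with_loop(shape, d)
```

Checking the rewritten function against the original, as the loop at the end does, is a cheap way to confirm that an awkward rewrite preserves behavior.
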