
# Adding Abstract Interpretation Functions
## Overview
As part of adding support for a Torch operator in Torch-MLIR, it is usually
necessary to define a shape and dtype function so that the compiler can infer
the shapes and dtypes of result tensors for the operator. We use the
[abstract interpretation library](abstract_interp_lib.md) for this process.
## Step-by-step guide
We will use the example of adding support for the `torch.aten.tanh` op.
1. First, you need to find the shape and dtype function signatures for
the operator you are implementing functions for. These can be
found in
`include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt`,
generated by the `build_tools/update_torch_ods.sh` script. That
file is the "rosetta stone" that allows translating between
e.g. `torch.aten.tanh`, `AtenTanhOp`, and the shape and dtype
function signatures, which for this op are:
- `def aten〇tanh〡shape(self: List[int]) -> List[int]:`
- `def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:`
Note the use of `〇` as a separator since `.` or `::` aren't legal
in a Python identifier.
2. Paste the function signatures into `abstract_interp_lib_gen.py` in an
appropriate place (ideally near the functions for similar
operators). Note that `abstract_interp_lib_gen.py` will check that
these signatures are verbatim identical with the ones given in
`JITOperatorRegistryDump.txt` -- this ensures that the functions
don't get outdated if Torch changes an operator signature.
3. Fill in the body of the function. Ideally this will just be a call
into a helper function from
[`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1).
But in general, you will need to write the function and test it
(see the comments about "Shape, dtype, and decomposition function
testing infrastructure" in `testing_framework.py`). New shape
functions should be added upstream following the example of [this PR](https://github.com/pytorch/pytorch/pull/76889),
though it can be useful to iterate locally in `abstract_interp_lib_gen.py`
first.
Similarly, dtype functions should ideally just be a call to the helper
`promote_dtypes` defined in `library_generator.py`. However, some ops will
require some extra logic to calculate the right result types. While dtypes
are expressed as `int`s in the arguments of the dtype function, using PyTorch
dtypes, such as `torch.int` and `torch.float32`, in the body of the dtype
function is fully supported. Dtype functions are also expected to be fully
tested.
4. Re-run the `build_tools/update_abstract_interp_lib.sh` script to
update the library. After this step, ideally everything
"just works" and the shapes and dtypes are now correctly inferred for the
operator.
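
The steps above can be sketched in plain Python for the `tanh` example. This is a minimal, self-contained illustration rather than the code that lives in `abstract_interp_lib_gen.py`: the local `unary` function is a stand-in for the upstream helper of the same name, the integer dtype codes are placeholders chosen for this sketch, and the promotion rule in the dtype function is an assumption made for illustration.

```python
from typing import List, Tuple


# Stand-in for the upstream `unary` helper (assumption for this sketch:
# an elementwise op returns a copy of its input shape).
def unary(self: List[int]) -> List[int]:
    out: List[int] = []
    for dim in self:
        out.append(dim)
    return out


# Placeholder integer dtype codes, hypothetical names used only in this sketch.
# Real dtype functions can use PyTorch dtypes such as `torch.float32` directly.
INT32 = 3
FLOAT32 = 6
FLOAT64 = 7
FLOAT_DTYPES = [FLOAT32, FLOAT64]


def aten〇tanh〡shape(self: List[int]) -> List[int]:
    # The result of an elementwise op has the same shape as its operand.
    return unary(self)


def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    # The argument packs the operand's rank and dtype into one tuple.
    self_rank, self_dtype = self_rank_dtype
    # Assumed rule for this sketch: floating-point inputs keep their dtype,
    # and every other input dtype produces float32.
    if self_dtype in FLOAT_DTYPES:
        return self_dtype
    return FLOAT32
```

Calling `aten〇tanh〡shape([2, 3])` returns a new list equal to the input shape, and `aten〇tanh〡dtype((2, INT32))` returns the `FLOAT32` placeholder. In the real library, checks like these belong in the testing infrastructure described in `testing_framework.py` rather than in ad hoc calls.
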
## When things go wrong
It is possible that the refinement pipeline (see [Shape and Dtype Refinement Pipeline Architecture](abstract_interp_lib.md#shape-and-dtype-refinement-pipeline-architecture))
is not able to infer the shape or dtype of a tensor with a given
abstract interpretation function. This usually means that there is something
about the function which the optimizations in
`torch-simplify-shape-functions` and `torch-simplify-dtype-functions`
(`lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp`,
`lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp`)
cannot handle.
To debug this, the overall goal is to pinpoint the IR construct that is not
being simplified. This is usually accomplished by a combination of looking at
the Python code for the function and the IR dumps. The best IR dump to look at
varies, but frequently the IR dump right before `DropAbstractInterpCalculations`
is the most useful, because it has already been simplified as much as possible,
making it easy to see what is blocking further simplification. Examples of
issues you might see:
- You might find that there is a loop with a non-constant trip count,
but based on your understanding of the function, you would expect it
to be simplified to a constant trip count -- you can then look at
the trip count calculation and see if there is a missing fold or
canonicalization.
- You might find that there is a list operation that is not currently understood
by the optimizations. You can then teach the optimizations about that
operation.
- You might find that there is an `Optional` value that you would
expect to be resolved to either a concrete value or `None`. You can
then look at the calculation that produces the optional value and
see what folds or canonicalizations are missing.
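
As an illustration of the loop trip count case, the sketch below contrasts two hypothetical functions (neither is part of the library). It assumes that the rank of the operand is known during refinement while individual dimension sizes may not be, which is an assumption about typical inputs rather than a guarantee.

```python
from typing import List


def hypothetical_copy_by_rank(self: List[int]) -> List[int]:
    # Trip count is len(self), i.e. the rank of the operand. Under the
    # assumption above, this can become a constant once the rank is known.
    out: List[int] = []
    for i in range(len(self)):
        out.append(self[i])
    return out


def hypothetical_expand_by_value(self: List[int]) -> List[int]:
    # Trip count is self[0], i.e. the *size* of a dimension. If that size
    # is not statically known, the trip count stays non-constant.
    out: List[int] = []
    for _ in range(self[0]):
        out.append(1)
    return out
```

If a loop of the first kind still shows a non-constant trip count in the IR dump, the computation of `len(self)` is a natural place to start looking for a missing fold.
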
See [this video](https://www.youtube.com/watch?v=E5epCJOtrf8) for general
guidance on debugging Torch-MLIR.
As a last resort, you can rewrite the function using constructs that
`torch-simplify-shape-functions` and `torch-simplify-dtype-functions` can handle.
Look at other functions for examples; this sometimes requires writing things a
little awkwardly.
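
A sketch of what such a rewrite might look like is shown below. Both functions are hypothetical and compute the same broadcast shape. The assumption, not confirmed by this guide, is that a single indexed loop with explicit `append` calls is closer to the style of existing shape functions and therefore more likely to simplify than padding, `zip`, and a comprehension. Error checking for incompatible sizes is omitted to keep the sketch short.

```python
from typing import List


def hypothetical_broadcast_compact(a: List[int], b: List[int]) -> List[int]:
    # Compact style: left-pad both shapes with 1s, then combine pairwise.
    rank = max(len(a), len(b))
    a_pad = [1] * (rank - len(a)) + a
    b_pad = [1] * (rank - len(b)) + b
    return [y if x == 1 else x for x, y in zip(a_pad, b_pad)]


def hypothetical_broadcast_plain(a: List[int], b: List[int]) -> List[int]:
    # Plain style: one loop whose trip count is the larger of the two ranks.
    rank = max(len(a), len(b))
    out: List[int] = []
    for i in range(rank):
        # Distance of output position i from the trailing dimension.
        offset = rank - 1 - i
        ia = len(a) - 1 - offset
        ib = len(b) - 1 - offset
        # A missing leading dimension behaves like a dimension of size 1.
        size_a = a[ia] if ia >= 0 else 1
        size_b = b[ib] if ib >= 0 else 1
        if size_a == 1:
            out.append(size_b)
        else:
            out.append(size_a)
    return out
```

Keeping the original version around while iterating makes it easy to check that the rewritten function still returns the same shapes for a few representative inputs.
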