Adding Abstract Interpretation Functions

Overview

As part of adding support for a Torch operator in Torch-MLIR, it is usually necessary to define a shape and dtype function so that the compiler can infer the shapes and dtypes of result tensors for the operator. We use the abstract interpretation library for this process.

Step-by-step guide

We will use the example of adding support for the torch.aten.tanh op.

  1. First, find the shape and dtype function signatures for the operator you are implementing functions for. These can be found in include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt, which is generated by the build_tools/update_torch_ods.sh script. That file is the "rosetta stone" that allows translating between e.g. torch.aten.tanh, AtenTanhOp, and the shape and dtype function signatures. For torch.aten.tanh, the signatures are:

    • def aten〇tanh〡shape(self: List[int]) -> List[int]:
    • def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:

    Note the use of 〇 as a separator since . or :: aren't legal in a Python identifier.

  2. Paste the function signatures into abstract_interp_lib_gen.py in an appropriate place (ideally near other functions with similar behavior). Note that abstract_interp_lib_gen.py will check that these signatures are verbatim identical with the ones given in JITOperatorRegistryDump.txt -- this ensures that the functions don't get outdated if Torch changes an operator signature.

  3. Fill in the body of the function. Ideally this will just be a call into a helper function from torch/jit/_shape_functions.py. But in general, you will need to write the function and test it (see the comments about "Shape, dtype, and decomposition function testing infrastructure" in testing_framework.py). New shape functions should be added upstream following the example of this PR, though it can be useful to iterate locally in abstract_interp_lib_gen.py first.

    Similarly, dtype functions should ideally just be a call to the helper promote_dtypes defined in library_generator.py. However, some ops will require some extra logic to calculate the right result types. While dtypes are expressed as ints in the arguments of the dtype function, using PyTorch dtypes, such as torch.int and torch.float32, in the body of the dtype function is fully supported. Dtype functions are also expected to be fully tested.

  4. Re-run the build_tools/update_abstract_interp_lib.sh script to update the library. After this step, ideally everything "just works" and the shapes and dtypes are now correctly inferred for the operator.
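
As a rough sketch of what steps 1 to 3 could produce for torch.aten.tanh, the block below writes the two functions in plain Python, using 〇 in place of the dot and 〡 before the function kind. The bodies are assumptions made for illustration: tanh is elementwise, so the shape function copies the input shape, and the dtype function passes floating-point dtypes through and maps everything else to float32. The integer dtype codes and the helper `mangle` are hypothetical and not part of abstract_interp_lib_gen.py; a real implementation would normally call a helper from torch/jit/_shape_functions.py and can use PyTorch dtypes directly.

```python
from typing import List

# Illustrative integer dtype codes (an assumption for this sketch; real
# dtype functions can use torch.int64, torch.float32, and so on).
INT64 = 4
FLOAT32 = 6
FLOAT64 = 7
FLOAT_DTYPES = [FLOAT32, FLOAT64]


def mangle(op_name: str, kind: str) -> str:
    """Hypothetical helper: build a Python-legal function name for an op."""
    # "." is not legal in an identifier, so it is replaced by the separator.
    return op_name.replace(".", "〇") + "〡" + kind


def aten〇tanh〡shape(self: List[int]) -> List[int]:
    # tanh is elementwise, so the result shape equals the input shape.
    result: List[int] = []
    for dim in self:
        result.append(dim)
    return result


def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:
    # Assumed rule: floating-point inputs keep their dtype,
    # all other inputs produce float32.
    if self_dtype in FLOAT_DTYPES:
        return self_dtype
    return FLOAT32
```

Calling `mangle("aten.tanh", "shape")` yields the name of the shape function above, which is a quick way to check a hand-typed name against the separator convention before running the generator script.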

When things go wrong

It is possible that the refinement pipeline (see Shape and Dtype Refinement Pipeline Architecture) is not able to infer the shape or dtype of a tensor with a given abstract interpretation function. This usually means that there is something about the function which the optimizations in torch-simplify-shape-functions and torch-simplify-dtype-functions (lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp, lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp) cannot handle.

To debug this, the overall goal is to pinpoint the IR construct that is not being simplified. This is usually accomplished by a combination of looking at the Python code for the function and the IR dumps. The best IR dump to look at varies, but frequently the IR dump right before DropAbstractInterpCalculations is the most useful, because it has already been simplified as much as possible, making it easy to see what is blocking further simplification. Examples of issues you might see:

  • You might find that there is a loop with a non-constant trip count, but based on your understanding of the function, you would expect it to be simplified to a constant trip count -- you can then look at the trip count calculation and see if there is a missing fold or canonicalization.

  • You might find that there is a list operation that is not currently understood by the optimizations. You can then teach the optimizations about that operation.

  • You might find that there is an Optional value that you would expect to be resolved to either a concrete value or None. You can then look at the calculation that produces the optional value and see what folds or canonicalizations are missing.

See this video for general guidance on debugging Torch-MLIR.

As a last resort, you can rewrite the function using constructs that torch-simplify-shape-functions and torch-simplify-dtype-functions can handle. Look at other functions for examples; sometimes this requires writing things a little awkwardly.
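
A minimal sketch of what such a rewrite might look like, assuming the simplification passes cope better with an explicit loop whose trip count depends only on the input ranks, plain asserts, and list appends than with compact idioms such as zip or comprehensions. That assumption and the function name `broadcast_shape_sketch` are illustrative only; check existing functions in abstract_interp_lib_gen.py for the patterns that are known to simplify.

```python
from typing import List


def broadcast_shape_sketch(a: List[int], b: List[int]) -> List[int]:
    """Hypothetical broadcasting shape function in a deliberately plain style."""
    rank = max(len(a), len(b))
    result: List[int] = []
    # The trip count depends only on the ranks, so it becomes a constant
    # once the ranks of both inputs are known.
    for i in range(rank):
        # Align shapes at the trailing dimension; missing dims count as 1.
        dim_a = i - (rank - len(a))
        dim_b = i - (rank - len(b))
        size_a = a[dim_a] if dim_a >= 0 else 1
        size_b = b[dim_b] if dim_b >= 0 else 1
        assert size_a == size_b or size_a == 1 or size_b == 1, \
            "shapes are not broadcastable"
        if size_a == 1:
            result.append(size_b)
        else:
            result.append(size_a)
    return result
```

The loop is longer than an equivalent one-liner, but every intermediate value is a named integer, which tends to make it easier to match a value in the IR dump back to the line of Python that produced it.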