# Adding Abstract Interpretation Functions

## Overview

As part of adding support for a Torch operator in Torch-MLIR, it is usually necessary to define a shape and dtype function so that the compiler can infer the shapes and dtypes of result tensors for the operator. We use the [abstract interpretation library](abstract_interp_lib.md) for this process.

## Step-by-step guide

We will use the example of adding support for the `torch.aten.tanh` op.

1. First, you need to find the shape and dtype function signatures for the operator you are implementing functions for. These can be found in `include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt`, which is generated by the `build_tools/update_torch_ods.sh` script. That file is the "rosetta stone" that allows translating between, e.g., `torch.aten.tanh`, `AtenTanhOp`, and the shape and dtype function signatures:

   - `def aten〇tanh〡shape(self: List[int]) -> List[int]:`
   - `def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:`

   Note the use of `〇` as a separator, since `.` or `::` aren't legal in a Python identifier.

2. Paste the function signatures into `abstract_interp_lib_gen.py` in an appropriate place (ideally near other functions for similar ops). Note that `abstract_interp_lib_gen.py` will check that these signatures are verbatim identical to the ones given in `JITOperatorRegistryDump.txt`; this ensures that the functions don't get outdated if Torch changes an operator signature.

3. Fill in the body of the function. Ideally this will just be a call into a helper function from [`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1). In general, however, you will need to write the function and test it (see the comments about "Shape, dtype, and decomposition function testing infrastructure" in `testing_framework.py`).
   New shape functions should be added upstream following the example of [this PR](https://github.com/pytorch/pytorch/pull/76889), though it can be useful to iterate locally in `abstract_interp_lib_gen.py` first.

   Similarly, dtype functions should ideally just be a call to the helper `promote_dtypes` defined in `library_generator.py`. However, some ops will require extra logic to calculate the right result types. While dtypes are expressed as `int`s in the arguments of the dtype function, using PyTorch dtypes such as `torch.int` and `torch.float32` in the body of the dtype function is fully supported. Dtype functions are also expected to be fully tested.

4. Re-run the `build_tools/update_abstract_interp_lib.sh` script to update the library. After this step, ideally everything "just works" and the shapes and dtypes of the operator's results are now correctly inferred.

## When things go wrong

It is possible that the refinement pipeline (see [Shape and Dtype Refinement Pipeline Architecture](abstract_interp_lib.md#shape-and-dtype-refinement-pipeline-architecture)) is not able to infer the shape or dtype of a tensor with a given abstract interpretation function. This usually means that there is something about the function which the optimizations in `torch-simplify-shape-functions` and `torch-simplify-dtype-functions` (`lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp`, `lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp`) cannot handle.

To debug this, the overall goal is to pinpoint the IR construct that is not being simplified. This is usually accomplished by a combination of looking at the Python code for the function and at the IR dumps. The best IR dump to look at varies, but frequently the IR dump right before `DropAbstractInterpCalculations` is the most useful, because it has already been simplified as much as possible, making it easy to see what is blocking further simplification.
Examples of issues you might see:

- You might find that there is a loop with a non-constant trip count, but based on your understanding of the function, you would expect it to be simplified to a constant trip count. You can then look at the trip count calculation and see if there is a missing fold or canonicalization.
- You might find that there is a list operation that is not currently understood by the optimizations. You can then teach the optimizations about that operation.
- You might find that there is an `Optional` value that you would expect to be resolved to either a concrete value or `None`. You can then look at the calculation that produces the optional value and see what folds or canonicalizations are missing.

See [this video](https://www.youtube.com/watch?v=E5epCJOtrf8) for general guidance on debugging Torch-MLIR.

As a last resort, you can rewrite the function using constructs that `torch-simplify-shape-functions` and `torch-simplify-dtype-functions` can handle (look at other functions for examples; sometimes it requires writing things a little awkwardly).
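As a concrete reference for step 3 of the step-by-step guide, the sketch below shows what a shape function for `torch.aten.tanh` might look like, using the signature from `JITOperatorRegistryDump.txt`. Because `tanh` is elementwise, its result has the same shape as its input. To keep the block runnable without PyTorch installed, it defines a local copy of the `unary` helper from `torch/jit/_shape_functions.py`. In `abstract_interp_lib_gen.py` you would call the upstream helper directly instead; the module alias `upstream_shape_functions` mentioned in the comments is an assumption, not something this guide specifies. This is a sketch of the shape function only, and the dtype function is omitted.

```python
from typing import List


# Local stand-in for the `unary` helper in PyTorch's
# torch/jit/_shape_functions.py, copied here so this sketch runs without
# importing torch. It returns a fresh list holding the input's dimensions.
def unary(self: List[int]) -> List[int]:
    out: List[int] = []
    for elem in self:
        out.append(elem)
    return out


# Shape function for `torch.aten.tanh`. The signature must match
# JITOperatorRegistryDump.txt verbatim. In abstract_interp_lib_gen.py the
# body would instead delegate to the upstream helper (assumed here to be
# imported as `upstream_shape_functions`, i.e. `upstream_shape_functions.unary(self)`).
def aten〇tanh〡shape(self: List[int]) -> List[int]:
    # tanh is elementwise, so the result shape equals the input shape.
    return unary(self)


if __name__ == "__main__":
    print(aten〇tanh〡shape([2, 3, 4]))
```

Returning a copy rather than the input list itself mirrors the upstream helper, so later code that mutates the result list cannot accidentally alias the input shape.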