To debug this, the overall goal is to pinpoint the IR construct that is not
being simplified. This is usually done by cross-referencing the Python code
for the shape function with the IR dumps. The best IR dump to
look at varies, but frequently the IR dump right before `DropShapeCalculations`
is the most useful, because it has already been simplified as much as possible,
making it easy to see what is blocking further simplification. Examples of
issues you might see:
- You might find that there is a loop with a non-constant trip count, but based
on your understanding of the shape function, you would expect it to be
simplified to a constant trip count -- you can then look at the trip count
calculation and see if there is a missing fold or canonicalization.
- You might find that there is a list operation that is not currently understood
by the optimizations. You can then teach the optimizations about that
operation.
- You might find that there is an `Optional` value that you would expect to be
  resolved to either a concrete value or `None`. You can then look at the
  calculation that produces the optional value and see what folds or
  canonicalizations are missing.

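To make the `Optional` case concrete, here is a minimal sketch in the style of the shape library (the op and function name are made up for illustration, not taken from the actual library). When `dim` is statically known to be `None` or a constant, the branch and the loop trip count can both be resolved:

```python
from typing import List, Optional

# Hypothetical shape function (the op and name are invented for this example).
# If `dim` is a statically-known `None` or constant, the `if` below can be
# folded and the loop's trip count becomes a constant, letting the simplifier
# reduce the whole body to a constant shape list.
def example_reduce_dim(self: List[int], dim: Optional[int]) -> List[int]:
    if dim is None:
        return []  # reducing over all dims yields a rank-0 shape
    out: List[int] = []
    for i in range(len(self)):
        if i != dim:
            out.append(self[i])
    return out
```

If the `Optional` is *not* resolved in the IR, neither branch can be eliminated, which is the symptom to look for in the dump.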
See [this video](https://www.youtube.com/watch?v=E5epCJOtrf8) for general
guidance on debugging Torch-MLIR.

As a last resort, you can rewrite the shape function using constructs that
`torch-simplify-shape-functions` can handle (look at other shape functions for
examples; sometimes it requires writing things a little awkwardly).
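As a hypothetical sketch of what "writing things a little awkwardly" can mean (these functions are invented for illustration, not from the actual shape library): a concise slice may need to be rewritten as an explicit loop whose trip count the simplifier can turn into a constant once the rank is known.

```python
from typing import List

# The sliced version is the natural way to write "drop the last dim":
def drop_last_dim_sliced(self: List[int]) -> List[int]:
    return self[:-1]

# The "awkward" rewrite uses an explicit loop over `range(len(self) - 1)`,
# a construct that a simplifier can unroll once `len(self)` is a known
# constant.
def drop_last_dim_loop(self: List[int]) -> List[int]:
    out: List[int] = []
    for i in range(len(self) - 1):
        out.append(self[i])
    return out
```

Both functions compute the same result; the second simply restricts itself to constructs the simplification passes are more likely to understand.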