mirror of https://github.com/llvm/torch-mlir
6431b0f11f
The current implementation is just sufficient to handle a unary `aten.tanh` from the e2e spike, and only applies some local rewrite patterns. I've sketched out a fuller explanation of where this pass eventually needs to go in the pass docs.

Adding this required adding `numpy.tensor_static_info_cast`, which is the tensor analog of `numpy.static_info_cast`. This op encapsulates the same numpy-specific "no runtime code" casting semantics, in particular the interpretation of `!numpy.any_dtype`. The `numpy.tensor_static_info_cast` ops I see in practice now are "information erasing" and will be removed by a later pass that exploits the fact that aten ops are agnostic to the static info in their operand types (so substituting a type with more static info is fine).

Side note: we *need* to do dtype and rank inference before aten->tcf (which will eventually mostly be aten->linalg+guards), because each aten op is idiosyncratically overloaded based on dtype and rank. Without copying that idiosyncratic overloading into lower layers (a layering violation), we cannot really lower aten ops to anything until we do that inference.
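To make the casting semantics concrete, here is a minimal IR sketch of an "information erasing" cast feeding a unary aten op. The exact op syntax and the `aten.tanh` spelling are assumptions based on the description above, not verbatim from the repo:

```mlir
// Hypothetical sketch (op syntax assumed, not copied from the repo).
func @f(%arg0: tensor<2x3xf32>) -> tensor<*x!numpy.any_dtype> {
  // "Information erasing" cast: drops the static shape and dtype with no
  // runtime code, yielding a maximally unrefined tensor type.
  %0 = numpy.tensor_static_info_cast %arg0
      : tensor<2x3xf32> to tensor<*x!numpy.any_dtype>
  // aten ops are agnostic to the static info in their operand types, so a
  // later refinement pass may substitute a more refined type here and
  // delete the cast entirely.
  %1 = "aten.tanh"(%0)
      : (tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype>
  return %1 : tensor<*x!numpy.any_dtype>
}
```

Under this sketch, once rank and dtype inference runs, the operand of `aten.tanh` could be refined back to `tensor<2x3xf32>` and the cast folded away, which is exactly the "substituting a type with more static info is fine" property noted above.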
Directories:

- ATen
- Basicpy
- Numpy
- Refback
- Refbackrt
- TCF
- TCP
- Torch