# Design sketch: libtorch code generation round-trip

It has been brought up a couple of times that having a dynamic fallback to
libtorch for kernel calls that the compiler does not recognize could be
advantageous. This is a sketch of how such a facility could work.

## Op background

When programs are imported from Torch (either via acap/driver capture or
from TorchScript), kernel calls are mapped to a `torch.kernel_call` op, which
it is useful to visualize:

```mlir
%0 = torch.kernel_call "aten::mm" %arg0, %arg1 :
    (!numpy.ndarray<[2,3]:f32>, !numpy.ndarray<[3,4]:f32>) ->
    !numpy.ndarray<[2,4]:f32>
    {
      sigArgTypes = ["Tensor", "Tensor"],
      sigIsMutable = false,
      sigIsVararg = false,
      sigIsVarret = false,
      sigRetTypes = ["Tensor"]
    }
```

A couple of things to note at this level:

* Tensor operands/results are all represented by mutable `ndarray` types.
* The kernel call name ("aten::mm" above) is the `c10::OperatorName`.
* `sigArgTypes` and `sigRetTypes` correspond to the rest of a signature.
  Together with the kernel name, this is sufficient to find a precise
  `OpHandle` that can be used for making calls.
* The `torch.kernel_call` op implements the `TorchKernelOpInterface`, which
  provides structured access to this metadata.

From here, one typically uses the pass `aten-recognize-kernels` to promote
`torch.kernel_call` ops that the compiler has concretely modeled into
corresponding `aten` dialect ops. Here is an example of a function containing
the above, with aten kernels recognized:

```mlir
func @mm(%arg0: !numpy.ndarray<[2,3]:f32>, %arg1: !numpy.ndarray<[3,4]:f32>) -> !numpy.ndarray<[2,4]:f32> {
  %0 = numpy.copy_to_tensor %arg0 : (!numpy.ndarray<[2,3]:f32>) -> tensor<2x3xf32>
  %1 = numpy.copy_to_tensor %arg1 : (!numpy.ndarray<[3,4]:f32>) -> tensor<3x4xf32>
  %2 = "aten.mm"(%0, %1) : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
  %3 = numpy.create_array_from_tensor %2 : (tensor<2x4xf32>) -> !numpy.ndarray<[2,4]:f32>
  return %3 : !numpy.ndarray<[2,4]:f32>
}
```

A few things to note about this form:

* These recognized kernels are generated from the `torch_signature_ods_gen.py`
  script, which imposes some mapping policy on them.
* Most kernels are aggressively converted to operate on SSA tensor values via
  `copy_to_tensor`/`create_array_from_tensor` ops, so that the majority of
  `aten` dialect ops are purely functional and operate only on value-semantic
  types.
* The metadata is stripped off of the originating `torch.kernel_call`, but each
  `aten` op implements `TorchKernelOpInterface`, giving it access to the kernel
  name and a signature, matching its operands/results, of a Torch kernel that
  implements the computation.
* There is some information loss here, but enough should be retained to
  perform the correct calculation, even if not exactly as the original
  program specified (i.e. `out=` and other "ergonomic" aliases will be
  rewritten into dedicated stores, etc.).

## General fallback flow

The most straightforward way to execute a `torch.kernel_call` or `aten` op
supporting the `TorchKernelOpInterface` would be to rewrite it into code that
invokes the ATen boxed dispatch mechanism:

* Looking up a corresponding kernel based on a signature known at compile time
  (constructed from `TorchKernelOpInterface` metadata).
* Scribbling each operand into a `Stack` (of `IValue`) list.
* Invoking `c10::Dispatcher::callBoxed()` with the stack.
* Marshaling returned results back out of the return `Stack`.
* Performing error and type constraint checking.

The "inside" of such a dispatch function would be somewhat "switchy" but is
not all that complicated.
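
For concreteness, here is a minimal C++ sketch of that flow for the
`aten::mm` example, with the operand and result handling hardcoded rather
than driven by metadata; a generic implementation would instead switch on the
`sigArgTypes`/`sigRetTypes` entries. The helper name is hypothetical:

```c++
#include <vector>

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Minimal sketch of a boxed-dispatch fallback call for "aten::mm".
at::Tensor callMmFallback(at::Tensor lhs, at::Tensor rhs) {
  // Look up the kernel from the OperatorName recorded on the op
  // ("aten::mm", with an empty overload name).
  auto op = c10::Dispatcher::singleton().findSchema(
      c10::OperatorName("aten::mm", ""));
  TORCH_CHECK(op.has_value(), "no kernel registered for aten::mm");

  // Scribble the operands into a Stack (of IValue).
  std::vector<c10::IValue> stack;
  stack.emplace_back(std::move(lhs));
  stack.emplace_back(std::move(rhs));

  // Invoke the boxed dispatch; the results replace the arguments on the
  // stack.
  c10::Dispatcher::singleton().callBoxed(*op, &stack);

  // Marshal the single tensor result back out of the return stack.
  return stack.back().toTensor();
}
```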

## Runtime library

`libtorch` on its own is not particularly amenable to being invoked from such
a low level. It would be better if there were a shared library that provided
the above facility as simple C functions that the compiler could emit calls
to. It would then be trivial to load/link this shared library in for JIT'ing,
AOT compilation, etc.

Example:

```c
/// Looks up a Torch op given a signature.
void *refbackFindTorchOp(const char *signature);

/// Creates a 'call' struct from an op returned by `refbackFindTorchOp`.
/// Must be de-allocated via refbackTorchCallDestroy() when done.
void *refbackCreateTorchCall(void *torchOp);

/// Adds IValues to the call stack.
void refbackTorchCallAddTensor(void *call, void *data, int64_t *sizes, int rank);
void refbackTorchCallAddScalar(void *call, int64_t scalar);
// ...

/// Invokes the kernel.
/// After invocation, results can be read out with the methods below.
bool refbackTorchInvoke(void *call);

/// Gets IValues from the result stack.
bool refbackTorchGetTensor(void *call, size_t index, void **data, int64_t **sizes, int *rank);
bool refbackTorchGetScalar(void *call, size_t index, int64_t *scalar);

/// Frees any resources associated with the call.
void refbackTorchCallDestroy(void *call);
```
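
To make the intended usage concrete, here is a hand-written sketch of the
call sequence the compiler would emit for the `aten::mm` example, assuming
the declarations above; the exact signature string format and the ownership
of the result buffer are assumptions of this sketch:

```c++
#include <stdint.h>

// Hand-written equivalent of compiler-emitted code for the aten::mm example.
// The signature string format is an assumption of this sketch.
void runMm(float *lhsData, float *rhsData) {
  void *op = refbackFindTorchOp("aten::mm(Tensor, Tensor) -> (Tensor)");
  void *call = refbackCreateTorchCall(op);

  // Push the two operand tensors onto the call stack.
  int64_t lhsSizes[2] = {2, 3};
  int64_t rhsSizes[2] = {3, 4};
  refbackTorchCallAddTensor(call, lhsData, lhsSizes, /*rank=*/2);
  refbackTorchCallAddTensor(call, rhsData, rhsSizes, /*rank=*/2);

  // Invoke the kernel and read the single tensor result back out.
  if (refbackTorchInvoke(call)) {
    void *resultData;
    int64_t *resultSizes;
    int resultRank;
    refbackTorchGetTensor(call, /*index=*/0, &resultData, &resultSizes,
                          &resultRank);
    // ... consume the 2x4 result buffer ...
  }
  refbackTorchCallDestroy(call);
}
```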

## Generating code

A pass could be written to transform ops implementing `TorchKernelOpInterface`
into `llvm` dialect calls to the above functions. The details will be a bit
thick and depend on precise representations, but it should be fully generic.
It should be possible to prototype the whole thing with nothing but command
line tools and the existing `torch_mlir` paths for extracting programs.
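
One small concrete piece of such a pass is rendering the
`TorchKernelOpInterface` metadata into the signature string handed to
`refbackFindTorchOp`. A minimal sketch, assuming the
`name(argTypes) -> (retTypes)` format used in the usage example above (the
helper name and format are hypothetical):

```c++
#include <sstream>
#include <string>
#include <vector>

// Builds the signature string for refbackFindTorchOp from the kernel name
// and the sigArgTypes/sigRetTypes metadata. The exact format the runtime
// library accepts is an assumption of this sketch.
std::string buildSignature(const std::string &kernelName,
                           const std::vector<std::string> &sigArgTypes,
                           const std::vector<std::string> &sigRetTypes) {
  std::ostringstream os;
  os << kernelName << "(";
  for (size_t i = 0; i < sigArgTypes.size(); ++i)
    os << (i == 0 ? "" : ", ") << sigArgTypes[i];
  os << ") -> (";
  for (size_t i = 0; i < sigRetTypes.size(); ++i)
    os << (i == 0 ? "" : ", ") << sigRetTypes[i];
  os << ")";
  return os.str();
}

// buildSignature("aten::mm", {"Tensor", "Tensor"}, {"Tensor"})
//   == "aten::mm(Tensor, Tensor) -> (Tensor)"
```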

## Code location recommendations

* C-runtime library: `frontends/pytorch/csrc/kernelcrt`
* Code generation pass: `include/npcomp/Dialects/Torch/Transforms/TorchKernelToLLVMPass.cpp`

## Gotchas

This facility should work well for Torch kernels that are wholly unknown to
the compiler. However, kernels that the compiler fails to lower completely
(e.g. due to some unsupported dynamism that was unknown at the outset) may
end up as `tcf` ops or others that cannot be natively lowered via the
`TorchKernelOpInterface` facility. We can deal with this phase ordering in
a couple of ways:

* When converting into `tcf`, be more precise about when certain dynamic
  constructs are wholly unsupported. This is not likely to scale well unless
  it is just being used as a stop-gap. In that case, an early pass that marks
  ops not to lower (because we know we want to retain them at the higher
  level) may be fine.
* Treat `aten` as both a source and a target dialect for `tcf`: implement
  lowerings *to* `aten` that run after the rest of `tcf` has been lowered.
* Implement `TorchKernelOpInterface` on the `tcf` ops (or have some other
  interface for mapping them back).