# Design sketch: libtorch code generation round-trip

It has been suggested a couple of times that a dynamic fallback to libtorch,
for kernel calls that the compiler does not recognize, would be advantageous.
This is a sketch of how such a facility could work.

## Op background

When programs are imported from Torch (either via acap/driver capture or
from TorchScript), kernel calls are mapped to a `torch.kernel_call` op, which
looks like this:

```mlir
%0 = torch.kernel_call "aten::mm" %arg0, %arg1 :
    (!numpy.ndarray<[2,3]:f32>, !numpy.ndarray<[3,4]:f32>) ->
    !numpy.ndarray<[2,4]:f32>
    {
      sigArgTypes = ["Tensor", "Tensor"],
      sigIsMutable = false,
      sigIsVararg = false,
      sigIsVarret = false,
      sigRetTypes = ["Tensor"]
    }
```

A couple of things to note at this level:

* Tensor operands/results are all represented by mutable `ndarray` types.
* The kernel call name ("aten::mm" above) is the `c10::OperatorName`.
* `sigArgTypes` and `sigRetTypes` correspond to the rest of a signature.
  Together with the kernel name, they are sufficient to find a precise
  `OpHandle` that can be used for making calls.
* The `torch.kernel_call` op implements the `TorchKernelOpInterface`, which
  provides structured access to this metadata.
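
As an illustration of how this name and signature metadata could be flattened
for the runtime lookup function `refbackFindTorchOp` proposed later in this
document, the following C sketch joins the kernel name with `sigArgTypes` and
`sigRetTypes` into a single key. The `formatTorchSignature` helper and its
`name(args)->(rets)` encoding are hypothetical; this design does not fix a
string format.

```c
#include <stddef.h>
#include <string.h>

/* Appends `s` at offset `*used`; returns -1 if the result (including the
 * NUL terminator) would not fit in `bufSize` bytes. */
static int appendStr(char *buf, size_t bufSize, size_t *used, const char *s) {
  size_t len = strlen(s);
  if (*used + len + 1 > bufSize)
    return -1;
  memcpy(buf + *used, s, len + 1); /* copies the terminator too */
  *used += len;
  return 0;
}

/* Appends a parenthesized, comma-separated list of type names. */
static int appendTypeList(char *buf, size_t bufSize, size_t *used,
                          const char *const *types, size_t count) {
  if (appendStr(buf, bufSize, used, "(") != 0)
    return -1;
  for (size_t i = 0; i < count; ++i) {
    if (i > 0 && appendStr(buf, bufSize, used, ",") != 0)
      return -1;
    if (appendStr(buf, bufSize, used, types[i]) != 0)
      return -1;
  }
  return appendStr(buf, bufSize, used, ")");
}

/* Hypothetical encoding "<kernel name>(<arg types>)->(<return types>)",
 * e.g. "aten::mm(Tensor,Tensor)->(Tensor)". Returns 0 on success or -1 if
 * `buf` is too small. */
int formatTorchSignature(char *buf, size_t bufSize, const char *kernelName,
                         const char *const *argTypes, size_t numArgs,
                         const char *const *retTypes, size_t numRets) {
  size_t used = 0;
  if (bufSize == 0)
    return -1;
  buf[0] = '\0';
  if (appendStr(buf, bufSize, &used, kernelName) != 0 ||
      appendTypeList(buf, bufSize, &used, argTypes, numArgs) != 0 ||
      appendStr(buf, bufSize, &used, "->") != 0 ||
      appendTypeList(buf, bufSize, &used, retTypes, numRets) != 0)
    return -1;
  return 0;
}
```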

From here, one typically uses the `aten-recognize-kernels` pass to promote
`torch.kernel_call` ops that the compiler has concretely modeled into the
corresponding `aten` dialect ops. Here is an example of a function containing
the above, with aten kernels recognized:

```mlir
func @mm(%arg0: !numpy.ndarray<[2,3]:f32>, %arg1: !numpy.ndarray<[3,4]:f32>) -> !numpy.ndarray<[2,4]:f32> {
  %0 = numpy.copy_to_tensor %arg0 : (!numpy.ndarray<[2,3]:f32>) -> tensor<2x3xf32>
  %1 = numpy.copy_to_tensor %arg1 : (!numpy.ndarray<[3,4]:f32>) -> tensor<3x4xf32>
  %2 = "aten.mm"(%0, %1) : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
  %3 = numpy.create_array_from_tensor %2 : (tensor<2x4xf32>) -> !numpy.ndarray<[2,4]:f32>
  return %3 : !numpy.ndarray<[2,4]:f32>
}
```

A few things to note about this form:

* These recognized kernels are generated by the `torch_signature_ods_gen.py`
  script, which imposes some mapping policy on them.
* Most kernels are aggressively converted to operate on SSA tensor values via
  `copy_to_tensor`/`create_array_from_tensor` ops, so that the majority of
  `aten` dialect ops, being purely functional, operate only on value-semantic
  types.
* The metadata is stripped off of the originating `torch.kernel_call`, but
  each `aten` op implements `TorchKernelOpInterface`, giving access to the
  name and signature of a Torch kernel that implements the computation and
  matches the op's operands/results.
* There is some information loss here, but enough should be retained to
  perform the correct calculation, if not to execute it exactly as the
  original program specified (e.g. `out=` and other "ergonomic" aliases will
  be rewritten into dedicated stores, etc.).
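
To make the `out=` point concrete, here is a toy C illustration on plain
`float` buffers rather than real tensors; all function names are hypothetical.
`addOut` writes into a caller-provided buffer the way an `out=` kernel would,
while `addOutRewritten` computes a fresh value-semantic result and then
performs a dedicated store. Both produce the same values, but the rewritten
form allocates a temporary and copies it, which is the kind of execution
difference this information loss permits.

```c
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

/* Toy analogue of an `out=` kernel: writes a + b into the caller's buffer. */
void addOut(const float *a, const float *b, size_t n, float *out) {
  for (size_t i = 0; i < n; ++i)
    out[i] = a[i] + b[i];
}

/* Toy analogue of the functional (value-semantic) op: returns a freshly
 * allocated result that the caller must free, or NULL on failure. */
float *addFunctional(const float *a, const float *b, size_t n) {
  float *result = malloc(n * sizeof *result);
  if (result == NULL)
    return NULL;
  for (size_t i = 0; i < n; ++i)
    result[i] = a[i] + b[i];
  return result;
}

/* The rewritten form of addOut: compute functionally, then perform a
 * dedicated store into the destination buffer. Returns 0 on success. */
int addOutRewritten(const float *a, const float *b, size_t n, float *out) {
  if (n == 0)
    return 0; /* nothing to compute or store */
  float *tmp = addFunctional(a, b, n);
  if (tmp == NULL)
    return -1;
  memcpy(out, tmp, n * sizeof *out); /* the dedicated store */
  free(tmp);
  return 0;
}
```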

## General fallback flow

The most straightforward way to execute a `torch.kernel_call` or `aten` op
that supports the `TorchKernelOpInterface` would be to rewrite it into code
that invokes the ATen boxed dispatch mechanism:

* Looking up a corresponding kernel based on a signature known at compile time
  (constructed from `TorchKernelOpInterface` metadata).
* Pushing each operand onto a `Stack` (a list of `IValue`s).
* Invoking `c10::Dispatcher::callBoxed()` with the stack.
* Marshaling the returned results back out of the `Stack`.
* Performing error and type-constraint checking.

The "inside" of such a dispatch function would be somewhat "switchy"
(branching on the types in the signature), but it is not all that complicated.
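
The following is a toy C model of that flow, intended only to show the shape
of boxed calling and the "switchy" marshaling; it is not libtorch's API. The
real mechanism is C++, where a `Stack` holds `c10::IValue`s; here `ToyIValue`,
`ToyStack`, `addIntBoxed` and `callBoxedIntKernel` are hypothetical stand-ins.
Operands are boxed according to their signature type strings (`int` and
`float`, which in TorchScript schemas denote 64-bit integers and doubles), the
kernel pops its arguments and pushes its result, and the caller checks the
result's type before unboxing it.

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Toy stand-in for c10::IValue: a tagged union (the real one has many more
 * tags, including tensors, lists and optionals). */
typedef enum { TAG_INT, TAG_DOUBLE } ToyTag;
typedef struct {
  ToyTag tag;
  union {
    int64_t i;
    double d;
  } u;
} ToyIValue;

/* Toy stand-in for the boxed `Stack`: arguments are pushed in order; the
 * kernel pops them and pushes its results in their place. */
#define TOY_STACK_CAPACITY 16
typedef struct {
  ToyIValue items[TOY_STACK_CAPACITY];
  size_t size;
} ToyStack;

static bool push(ToyStack *s, ToyIValue v) {
  if (s->size == TOY_STACK_CAPACITY)
    return false;
  s->items[s->size++] = v;
  return true;
}

static bool pop(ToyStack *s, ToyIValue *v) {
  if (s->size == 0)
    return false;
  *v = s->items[--s->size];
  return true;
}

/* Every boxed kernel has the same C signature; all I/O goes via the stack. */
typedef bool (*BoxedKernel)(ToyStack *stack);

/* Boxed kernel for a hypothetical "add(int a, int b) -> int". */
bool addIntBoxed(ToyStack *stack) {
  ToyIValue b, a;
  if (!pop(stack, &b) || !pop(stack, &a))
    return false;
  if (a.tag != TAG_INT || b.tag != TAG_INT)
    return false; /* type-constraint check */
  ToyIValue r = {.tag = TAG_INT, .u.i = a.u.i + b.u.i};
  return push(stack, r);
}

/* The "switchy" part: the signature's type string decides how a raw operand
 * is boxed. Unsupported types are reported as errors. */
static bool boxOperand(ToyStack *stack, const char *sigType, const void *raw) {
  ToyIValue v;
  if (strcmp(sigType, "int") == 0) {
    v.tag = TAG_INT;
    v.u.i = *(const int64_t *)raw;
  } else if (strcmp(sigType, "float") == 0) {
    v.tag = TAG_DOUBLE;
    v.u.d = *(const double *)raw;
  } else {
    return false;
  }
  return push(stack, v);
}

/* Generic fallback: box operands, invoke the kernel, unbox one int result. */
bool callBoxedIntKernel(BoxedKernel kernel, const char *const *argTypes,
                        const void *const *args, size_t numArgs,
                        int64_t *result) {
  ToyStack stack = {.size = 0};
  for (size_t i = 0; i < numArgs; ++i)
    if (!boxOperand(&stack, argTypes[i], args[i]))
      return false;
  ToyIValue r;
  if (!kernel(&stack) || !pop(&stack, &r) || stack.size != 0 ||
      r.tag != TAG_INT)
    return false;
  *result = r.u.i;
  return true;
}
```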

## Runtime library

`libtorch` on its own is not particularly amenable to being invoked from such
a low level. It would be better to have a shared library that exposes the
above facility as simple C functions that the compiler could emit calls to.
It would then be trivial to load/link this shared library for JIT'ing, AOT
compilation, etc.

Example:

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/// Looks up a Torch op given a signature.
void *refbackFindTorchOp(const char *signature);
/// Creates a 'call' struct from an op returned by `refbackFindTorchOp`.
/// Must be de-allocated via refbackTorchCallDestroy() when done.
void *refbackCreateTorchCall(void *torchOp);
/// Adds IValues to the call stack.
void refbackTorchCallAddTensor(void *call, void *data, int64_t *sizes, int rank);
void refbackTorchCallAddScalar(void *call, int64_t scalar);
// ...
/// Invokes the kernel.
/// After invocation, results can be read out with the methods below.
bool refbackTorchInvoke(void *call);
/// Gets IValues from the result stack.
bool refbackTorchGetTensor(void *call, size_t index, void **data, int64_t **sizes, int *rank);
bool refbackTorchGetScalar(void *call, size_t index, int64_t *scalar);
/// Frees any resources associated with the call.
void refbackTorchCallDestroy(void *call);
```

## Generating code

A pass could be written to transform ops implementing `TorchKernelOpInterface`
into `llvm` dialect calls to the above functions. The details will be somewhat
involved and depend on the precise representations, but the pass should be
fully generic. It should be possible to prototype the whole thing with nothing
but command line tools and the existing `torch_mlir` paths for extracting
programs.
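
For a feel of what the emitted code would do, the following hand-written C
performs roughly the call sequence such a pass might generate for the
`aten.mm` example above, using the runtime API sketched earlier. Since that
library does not exist yet, the functions this example needs are backed by a
toy stub: it recognizes only the hypothetical signature string
`aten::mm(Tensor,Tensor)->(Tensor)`, assumes contiguous rank-2 `f32` data, and
computes the product naively, where a real implementation would go through
`c10::Dispatcher`. In generated code, the `refbackFindTorchOp` lookup would
presumably be hoisted out of the hot path (e.g. done once at load time), since
the signature is a compile-time constant.

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* ---- Toy stub of the proposed runtime (no libtorch involved). ---- */
typedef struct { float *data; int64_t *sizes; int rank; } ToyTensor;
typedef struct {
  ToyTensor args[2];
  int numArgs;
  ToyTensor result; /* result.data is owned by the call */
  int64_t resultSizes[2];
} ToyCall;

static int toyMmOp; /* its address serves as an opaque op handle */

void *refbackFindTorchOp(const char *signature) {
  return strcmp(signature, "aten::mm(Tensor,Tensor)->(Tensor)") == 0 ? &toyMmOp : NULL;
}

void *refbackCreateTorchCall(void *torchOp) {
  return torchOp == &toyMmOp ? calloc(1, sizeof(ToyCall)) : NULL;
}

void refbackTorchCallAddTensor(void *call, void *data, int64_t *sizes, int rank) {
  ToyCall *c = call;
  if (c->numArgs < 2)
    c->args[c->numArgs++] = (ToyTensor){data, sizes, rank};
}

bool refbackTorchInvoke(void *call) {
  ToyCall *c = call;
  if (c->numArgs != 2 || c->args[0].rank != 2 || c->args[1].rank != 2)
    return false;
  int64_t m = c->args[0].sizes[0], k = c->args[0].sizes[1];
  int64_t n = c->args[1].sizes[1];
  if (c->args[1].sizes[0] != k || m <= 0 || n <= 0)
    return false; /* shape-constraint check */
  float *out = calloc((size_t)(m * n), sizeof *out); /* zero-initialized */
  if (out == NULL)
    return false;
  const float *a = c->args[0].data, *b = c->args[1].data;
  for (int64_t i = 0; i < m; ++i)
    for (int64_t j = 0; j < n; ++j)
      for (int64_t p = 0; p < k; ++p)
        out[i * n + j] += a[i * k + p] * b[p * n + j];
  free(c->result.data); /* in case of repeated invocation */
  c->resultSizes[0] = m;
  c->resultSizes[1] = n;
  c->result = (ToyTensor){out, c->resultSizes, 2};
  return true;
}

bool refbackTorchGetTensor(void *call, size_t index, void **data,
                           int64_t **sizes, int *rank) {
  ToyCall *c = call;
  if (index != 0 || c->result.data == NULL)
    return false;
  *data = c->result.data;
  *sizes = c->result.sizes;
  *rank = c->result.rank;
  return true;
}

void refbackTorchCallDestroy(void *call) {
  ToyCall *c = call;
  if (c != NULL)
    free(c->result.data);
  free(c);
}

/* ---- Roughly what generated code might do for `aten.mm`. ---- */
/* Computes out = a (m x k) times b (k x n); `out` must hold m * n floats.
 * Returns false if the op is not found or the call fails. */
bool mmViaFallback(float *a, int64_t *aSizes, float *b, int64_t *bSizes,
                   float *out) {
  void *op = refbackFindTorchOp("aten::mm(Tensor,Tensor)->(Tensor)");
  void *call = op ? refbackCreateTorchCall(op) : NULL;
  if (call == NULL)
    return false;
  refbackTorchCallAddTensor(call, a, aSizes, 2);
  refbackTorchCallAddTensor(call, b, bSizes, 2);
  void *data;
  int64_t *sizes;
  int rank;
  bool ok = refbackTorchInvoke(call) &&
            refbackTorchGetTensor(call, 0, &data, &sizes, &rank) && rank == 2;
  if (ok) /* copy out before the call (and its result buffer) is destroyed */
    memcpy(out, data, (size_t)(sizes[0] * sizes[1]) * sizeof(float));
  refbackTorchCallDestroy(call);
  return ok;
}
```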

## Code location recommendations

* C runtime library: `frontends/pytorch/csrc/kernelcrt`
* Code generation pass: `include/npcomp/Dialects/Torch/Transforms/TorchKernelToLLVMPass.cpp`

## Gotchas

This facility should work well for Torch kernels that are wholly unknown to
the compiler. However, kernels that the compiler fails to lower completely
(e.g. due to some unsupported dynamism that was not known at the outset) may
end up as `tcf` ops or others that cannot be natively lowered via the
`TorchKernelOpInterface` facility. We can deal with this phase-ordering
problem in a couple of ways:

* When converting into `tcf`, be more precise about when certain dynamic
  constructs are wholly unsupported. This is not likely to scale well unless
  it is used only as a stop-gap. In that case, an early pass that marks ops
  which should not be lowered (because we know we want to retain them at the
  higher level) may be fine.
* Treat `aten` as both a source and a target dialect for `tcf`: implement
  lowerings *to* `aten` that run after the rest of `tcf` has been lowered.
* Implement `TorchKernelOpInterface` on the `tcf` ops (or have some other
  interface for mapping them back).