# Design sketch: libtorch code generation round-trip
It has been brought up a couple of times that having a dynamic fallback to libtorch for kernel calls that the compiler does not recognize could be advantageous. This is a sketch of how such a facility could work.
## Op background
When programs are imported from Torch (either via acap/driver capture or
from TorchScript), kernel calls are mapped to a `torch.kernel_call` op,
which is useful to visualize:
```mlir
%0 = torch.kernel_call "aten::mm" %arg0, %arg1 :
    (!numpy.ndarray<[2,3]:f32>, !numpy.ndarray<[3,4]:f32>) ->
    !numpy.ndarray<[2,4]:f32>
  {
    sigArgTypes = ["Tensor", "Tensor"],
    sigIsMutable = false,
    sigIsVararg = false,
    sigIsVarret = false,
    sigRetTypes = ["Tensor"]
  }
```
A couple of things to note at this level:
- Tensor operands/results are all represented by mutable `ndarray` types.
- The kernel call name (`"aten::mm"` above) is the `c10::OperatorName`.
  `sigArgTypes` and `sigRetTypes` correspond to the rest of a signature.
  Together with the kernel name, this is sufficient to find a precise
  `OpHandle` that can be used for making calls (see the sketch below).
- The `torch.kernel_call` op implements the `TorchKernelOpInterface`, which
  provides structured access to this metadata.
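
For illustration, the kernel name plus this signature metadata is enough to
resolve a concrete operator handle through libtorch's dispatcher at runtime.
A minimal C++ sketch, assuming the public `c10::Dispatcher` API:

```cpp
#include <ATen/core/dispatch/Dispatcher.h>
#include <iostream>

int main() {
  // "aten::mm" with an empty overload name, as captured in the
  // torch.kernel_call metadata above.
  auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::mm", "");
  // Prints the full schema, e.g.:
  //   aten::mm(Tensor self, Tensor mat2) -> Tensor
  std::cout << op.schema() << "\n";
}
```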
From here, one typically uses the `aten-recognize-kernels` pass to promote
`torch.kernel_call` ops that the compiler has concretely modeled into
corresponding `aten` dialect ops. Here is an example of a function containing
the above, with aten kernels recognized:
```mlir
func @mm(%arg0: !numpy.ndarray<[2,3]:f32>, %arg1: !numpy.ndarray<[3,4]:f32>) -> !numpy.ndarray<[2,4]:f32> {
  %0 = numpy.copy_to_tensor %arg0 : (!numpy.ndarray<[2,3]:f32>) -> tensor<2x3xf32>
  %1 = numpy.copy_to_tensor %arg1 : (!numpy.ndarray<[3,4]:f32>) -> tensor<3x4xf32>
  %2 = "aten.mm"(%0, %1) : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
  %3 = numpy.create_array_from_tensor %2 : (tensor<2x4xf32>) -> !numpy.ndarray<[2,4]:f32>
  return %3 : !numpy.ndarray<[2,4]:f32>
}
```
A few things to note about this form:
- These recognized kernels are generated from the
  `torch_signature_ods_gen.py` script, which imposes some mapping policy on
  them.
- Most kernels are aggressively converted to operate on SSA tensor values via
  `copy_to_tensor`/`create_array_from_tensor` ops, so the majority of ops in
  the `aten` dialect are purely functional and operate on value-semantic
  types.
- The metadata is stripped off of the originating `torch.kernel_call`, but
  each `aten` op implements `TorchKernelOpInterface`, giving it access to the
  kernel name and a signature matching its operands/results of a Torch kernel
  that implements the computation.
- There is some information loss here, but enough should be retained to
  perform the correct calculation, if not execute it exactly as the original
  program specified (i.e. `out=` and other "ergonomic" aliases will be
  rewritten into dedicated stores, etc.).
## General fallback flow
The most straightforward way to execute a `torch.kernel_call` or `aten` op
supporting the `TorchKernelOpInterface` would be to rewrite it into code that
invokes the ATen boxed dispatch mechanism:
- Looking up a corresponding kernel based on a signature known at compile
  time (constructed from `TorchKernelOpInterface` metadata).
- For each operand, scribbling into a `Stack` (a list of `IValue`s).
- Invoking `c10::Dispatcher::callBoxed()` with the stack.
- Marshaling returned results back out of the return `Stack`.
- Performing error and type constraint checking.
The "inside" of such a dispatch function would be somewhat "switchy" but is not all that complicated.
## Runtime library
`libtorch` on its own is not particularly amenable to being invoked from such
a low level. It would be better if there were a shared library that provided
the above facility as simple C functions that the compiler could emit calls
to. It would then be trivial to load/link this shared library in for JIT'ing,
AOT compilation, etc.
Example:
```c
/// Looks up a Torch op given a signature.
void *refbackFindTorchOp(const char *signature);

/// Creates a 'call' struct from an op returned by `refbackFindTorchOp`.
/// Must be de-allocated via `refbackTorchCallDestroy()` when done.
void *refbackCreateTorchCall(void *torchOp);

/// Adds IValues to the call stack.
void refbackTorchCallAddTensor(void *call, void *data, int64_t *sizes, int rank);
void refbackTorchCallAddScalar(void *call, int64_t scalar);
// ...

/// Invokes the kernel.
/// After invocation, results can be read out with the below methods.
bool refbackTorchInvoke(void *call);

/// Gets IValues from the result stack.
bool refbackTorchGetTensor(void *call, size_t index, void **data, int64_t **sizes, int *rank);
bool refbackTorchGetScalar(void *call, size_t index, int64_t *scalar);

/// Frees any resources associated with the call.
void refbackTorchCallDestroy(void *call);
```
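
For a sense of scale, here is a sketch of how a few of these entry points
could be implemented on top of libtorch, assuming the public
`c10::Dispatcher` API. Signature parsing, dtype/stride plumbing, and error
handling are elided (tensors are assumed f32 for brevity):

```cpp
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/stack.h>

namespace {
// A call handle: the resolved operator plus the in-progress IValue stack.
struct TorchCall {
  c10::OperatorHandle op;
  torch::jit::Stack stack;
};
} // namespace

extern "C" void *refbackFindTorchOp(const char *signature) {
  // Assumes `signature` is just the qualified name (e.g. "aten::mm") with an
  // empty overload name; a real implementation would parse the full
  // signature string.
  auto op = c10::Dispatcher::singleton().findSchemaOrThrow(signature, "");
  return new c10::OperatorHandle(op);
}

extern "C" void *refbackCreateTorchCall(void *torchOp) {
  return new TorchCall{*static_cast<c10::OperatorHandle *>(torchOp), {}};
}

extern "C" void refbackTorchCallAddTensor(void *call, void *data,
                                          int64_t *sizes, int rank) {
  auto *c = static_cast<TorchCall *>(call);
  // Wraps the caller's buffer without copying; dtype is assumed f32 here.
  c->stack.push_back(
      at::from_blob(data, c10::IntArrayRef(sizes, rank), at::kFloat));
}

extern "C" bool refbackTorchInvoke(void *call) {
  auto *c = static_cast<TorchCall *>(call);
  c->op.callBoxed(&c->stack);
  return true;
}

extern "C" void refbackTorchCallDestroy(void *call) {
  delete static_cast<TorchCall *>(call);
}
```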
## Generating code
A pass could be written to transform ops implementing
`TorchKernelOpInterface` into `llvm` dialect calls to the above functions.
Details will be a bit thick and depend on precise representations, but it
should be fully generic. It should be possible to prototype the whole thing
with nothing but command line tools and the existing `torch_mlir` paths for
extracting programs.
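
Concretely, for the `aten::mm` example above, the emitted code would amount
to the following sequence of runtime-library calls (written out in C++ for
illustration; buffer ownership and error handling are elided):

```cpp
#include <cstddef>
#include <cstdint>

// Declarations matching the runtime library sketched above.
extern "C" {
void *refbackFindTorchOp(const char *signature);
void *refbackCreateTorchCall(void *torchOp);
void refbackTorchCallAddTensor(void *call, void *data, int64_t *sizes, int rank);
bool refbackTorchInvoke(void *call);
bool refbackTorchGetTensor(void *call, size_t index, void **data,
                           int64_t **sizes, int *rank);
void refbackTorchCallDestroy(void *call);
}

// What the compiler-emitted code for the `aten.mm` example amounts to.
void emittedMm(float *lhs, float *rhs, void **resultData) {
  int64_t lhsSizes[] = {2, 3};
  int64_t rhsSizes[] = {3, 4};
  void *op = refbackFindTorchOp("aten::mm");
  void *call = refbackCreateTorchCall(op);
  refbackTorchCallAddTensor(call, lhs, lhsSizes, /*rank=*/2);
  refbackTorchCallAddTensor(call, rhs, rhsSizes, /*rank=*/2);
  refbackTorchInvoke(call);
  int64_t *resultSizes;
  int resultRank;
  refbackTorchGetTensor(call, /*index=*/0, resultData, &resultSizes, &resultRank);
  refbackTorchCallDestroy(call);
}
```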
Code location recommendations:

- C runtime library: `frontends/pytorch/csrc/kernelcrt`
- Code generation pass: `include/npcomp/Dialects/Torch/Transforms/TorchKernelToLLVMPass.cpp`
## Gotchas
This facility should work well for Torch kernels that are wholly unknown to
the compiler. However, kernels that the compiler fails to lower completely
(i.e. due to some unsupported dynamism, unknown at the outset) may end up as
`tcf` ops or others that cannot be natively lowered via the
`TorchKernelOpInterface` facility. We can deal with this phase ordering in
a couple of ways:
- When converting into `tcf`, be more precise about when certain dynamic
  constructs are wholly unsupported. This is not likely to scale well unless
  it is just being used as a stop-gap; in that case, possibly having an early
  pass that marks ops not to lower (because we know we want to retain them at
  the higher level) may be fine.
- Treat `aten` as both a source and a target dialect for `tcf`: implement
  lowerings to `aten` that run after the rest of `tcf` has been lowered.
- Implement `TorchKernelOpInterface` on the `tcf` ops (or have some other
  interface for mapping them back).