torch-mlir/test/Dialect/Torch/globalize-object-graph-free...

Properly import the entire torch::jit::CompilationUnit

This primarily unlocks proper handling of free functions (that is, functions that are not methods of any torch.nn.Module).

Recommended review order:
- `ivalue_importer.cpp` + `ivalue_import/functions*.py`
- `GlobalizeObjectGraph.cpp` + test case
- misc other stuff

The `torch::jit::CompilationUnit` is basically a backing store or "context" holding all the possible functions in the program. The previous code never accessed this data structure explicitly; it only imported the `torch::jit::Function`s that it saw attached to methods.

Subtly, any time a TorchScript module calls into a free function, the free function gets incorporated into the `torch::jit::CompilationUnit`, but doesn't show up anywhere when dumping the module, except in the curious pattern:

```
%5 : Function = prim::Constant[name="adaptive_avg_pool2d"]()
%6 : Tensor = prim::CallFunction(%5, %input.1, %4)
```

That is, calls are indirect, going through a function object materialized by `prim::Constant`. Even stranger, the `name` attribute doesn't tell the full story -- it doesn't correspond to anything. It turns out that the `c10::FunctionType` itself holds a pointer directly to the `torch::jit::Function` in the compilation unit (so there is actually no indirection in `prim::CallFunction`, because any two values of the same `FunctionType` call the same function!). E.g. when converting the IR to bytecode, the "name" is ignored ([code link](https://github.com/pytorch/pytorch/blob/1d6bd157902d4b1347a5d03122d02b407658e263/torch/csrc/jit/runtime/interpreter.cpp#L937)).

We nonetheless import `prim::CallFunction` as a `std.call_indirect`, because that is the simplest approach (it gets canonicalized to a direct call easily).
2021-02-27 08:20:35 +08:00
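
To make that last point concrete, here is a minimal MLIR sketch of the import described above (the function name `@some_free_function` and the `f64` signature are invented for illustration): the function object becomes a `std.constant`, the call becomes a `std.call_indirect`, and canonicalization folds the pair into a direct call.

```
// The TorchScript pattern above, after import: the prim::Constant function
// object becomes a std.constant, and prim::CallFunction becomes a
// std.call_indirect.
func private @some_free_function(%arg0: f64) -> f64 {
  return %arg0 : f64
}
func @caller(%arg0: f64) -> f64 {
  %f = constant @some_free_function : (f64) -> f64
  %0 = call_indirect %f(%arg0) : (f64) -> f64
  return %0 : f64
}
// Canonicalization then folds the constant + call_indirect pair into a
// direct call:
//   %0 = call @some_free_function(%arg0) : (f64) -> f64
```
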
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
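
// Tests that a free function (a function that is not a method of any
// torch.nn.Module) is preserved under the `__npcomp_priv_fn$` prefix and that
// calls to it are rewritten to target the renamed copy.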

torch.class_type @c {
  torch.attr "float" : f64
  torch.method "calls_free_function", @calls_free_function
}

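// Note that @free_function is deliberately not listed as a torch.method of @c;
// it is only reachable through the direct call in @calls_free_function below.
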
// CHECK-LABEL: func private @__npcomp_priv_fn$free_function(
// CHECK-SAME: %[[F:.*]]: f64) -> f64 {
// CHECK: return %[[F]] : f64
// CHECK: }
func private @free_function(%arg0: f64) -> f64 {
  return %arg0 : f64
}

// CHECK-LABEL: func @calls_free_function() -> f64 {
// CHECK: %[[F1:.*]] = torch.global_slot.get @float : f64
// CHECK: %[[RET:.*]] = call @__npcomp_priv_fn$free_function(%[[F1]]) : (f64) -> f64
// CHECK: return %[[RET]] : f64
// CHECK: }
func private @calls_free_function(%arg0: !torch.nn.Module<"c">) -> f64 {
  %0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> f64
  %1 = call @free_function(%0) : (f64) -> f64
  return %1 : f64
}

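// A single instance of class @c; the pass turns its "float" slot into the
// @float global slot read by torch.global_slot.get above.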
%c42 = std.constant 42.0 : f64
torch.nn_module {
  torch.slot "float", %c42 : f64
} : !torch.nn.Module<"c">