torch-mlir/test/Dialect/Torch/inline-global-slots-transfo...

Rework how global slot initializers work. Rather than a per-global-slot initializer region, we now have one for the whole module. For example, it might look like this:

```
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list : !torch.list<tensor>
torch.global_slot.module_initializer {
  %0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
  %1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
  torch.initialize.global_slots [
    @tensor(%0 : !torch.tensor)
    @list(%1 : !torch.list<tensor>)
  ]
}
```

This new structure allows GlobalizeObjectGraph to create the initializer in a much simpler way, avoiding the need to reason about whether different slots alias each other. Reasoning about whether slots alias each other is now the responsibility of InlineGlobalSlots, which has to do a much more complicated analysis, implemented using MLIR's dataflow analysis framework.

Recommended review order:
- Check out the new IR constructs in the .mlir files of various passes
- Op definitions (*.td)
- Changes to the GlobalizeObjectGraph pass
- InlineGlobalSlots pass (~total rewrite)
- Misc changes:
  - Moving torchMlirAdjustStaticInformation for sharing with C++ code
  - EraseModuleInitializer pass

To make this a bit nicer, it would be good to have a `torch.module` op with an initializer region attached. That would be more invasive, though.

This change has highlighted certain aspects of our project layering which are worth calling out. None of our backends can handle global slots, so we enforce that there are no global slots before backend lowering. At an earlier stage in the project, we had aspirations of transparently handling mutable global state and such, but for reasons described below, that is no longer a goal. So really global slots should be seen as a progressive lowering step as part of inlining all the IValue's in the original program (GlobalizeObjectGraph is also one such step).

Over time, with insights from work like IREE-JAX, it has become clear that there isn't a reliable programming model we can compile for users where we just transparently handle mutable global state (and some other things, like lists and dictionaries). There is a need for an "outer program" that orchestrates more restricted subroutines of the kind we can handle in our compile flow here. The benefit of that is that it decouples considerations like shapes, dtypes, etc. from the program constructs used in the outer program. As long as the outer program can efficiently invoke (pipelining/async/etc.) high-performance data-parallel numerical subroutines of the kind we compile in our flow here, then there is a complete programming model. This is also consistent with the direction of upstream PyTorch, which is becoming more tracing-based (which inherently loses a lot of program structure, which then has to be applied back with an "outer program" orchestrating the traced subroutines).
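As a rough sketch of the other direction (illustrative only, pieced together from the test cases below rather than from any single real pass output): after InlineGlobalSlots runs, the slots are deleted, the `torch.initialize.global_slots` list becomes empty (the now-trivial initializer is cleaned up later by the separate EraseModuleInitializer pass), and each reader gets its own clone of the initializer subgraph:

```
torch.global_slot.module_initializer {
  torch.initialize.global_slots [
  ]
}
func.func @forward() {
  // Each `torch.global_slot.get` is replaced by a clone of the slot's
  // initializer subgraph, so the literal appears once per use.
  %0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
  %1 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
  %2 = torch.prim.ListConstruct %1 : (!torch.tensor) -> !torch.list<tensor>
  return
}
```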
// RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
// Transform aspect of the pass.
// Test case: Most basic case that can be inlined.
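// The slot is private, is initialized to a constant, and is only ever read,
// so the pass replaces the `torch.global_slot.get` with the constant, deletes
// the slot, and leaves an empty `torch.initialize.global_slots` list behind.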
// CHECK-NOT: @slot0
torch.global_slot "private" @slot0 : !torch.int
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: torch.initialize.global_slots [
// CHECK-NEXT: ]
torch.global_slot.module_initializer {
  %0 = torch.constant.int 1
  torch.initialize.global_slots [
    @slot0(%0 : !torch.int)
  ]
}
// CHECK-LABEL: func.func @forward() {
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: return
func.func @forward() {
  %0 = torch.global_slot.get @slot0 : !torch.int
  return
}
// -----
// Test case: Shared objects in object graph shared between two initial values.
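// Note in the CHECK lines below that the tensor literal is materialized twice:
// once for the read of @tensor and once as the element of the inlined list.
// Inlining clones the initializer subgraph at each use rather than trying to
// preserve the sharing between the two slots' initial values.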
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list_of_tensor : !torch.list<tensor>
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: torch.initialize.global_slots [
// CHECK-NEXT: ]
torch.global_slot.module_initializer {
  %0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
  %1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
  torch.initialize.global_slots [
    @tensor(%0 : !torch.tensor)
    @list_of_tensor(%1 : !torch.list<tensor>)
  ]
}
// CHECK-LABEL: func.func @forward() {
// CHECK: %[[T0:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<f32>) : !torch.tensor
// CHECK: %[[T1:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<f32>) : !torch.tensor
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[T1]] : (!torch.tensor) -> !torch.list<tensor>
// CHECK: return
func.func @forward() {
  %0 = torch.global_slot.get @tensor : !torch.tensor
  %1 = torch.global_slot.get @list_of_tensor : !torch.list<tensor>
  return
}
// -----
// Test case: Adjusting static info.
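// The initializer produces a value with refined static info
// (!torch.tensor<[],f32>), while the slot and its readers use the unrefined
// !torch.tensor, so inlining inserts a torch.tensor_static_info_cast to
// reconcile the two types.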
// CHECK-NOT: @tensor
torch.global_slot "private" @tensor : !torch.tensor
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: torch.initialize.global_slots [
// CHECK-NEXT: ]
torch.global_slot.module_initializer {
  %0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor<[],f32>
  torch.initialize.global_slots [
    @tensor(%0 : !torch.tensor<[],f32>)
  ]
}
// CHECK-LABEL: func.func @forward() {
// CHECK: %[[T:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<f32>) : !torch.tensor<[],f32>
// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.tensor<[],f32> to !torch.tensor
func.func @forward() {
  %0 = torch.global_slot.get @tensor : !torch.tensor
  return
}