torch-mlir/test/Dialect/Torch/GlobalizeObjectGraph/multiple-instances.mlir


[torch-mlir earthmoving (1/N)] C/C++ code movement. This creates the `external/torch-mlir` directory as an LLVM_EXTERNAL_PROJECTS-compatible project (analogous to `iree-dialects`) and completes movement/rename of all pure MLIR C/C++ compiler code into there. The next step will be to move all the Python code / code that links/includes PyTorch C++ code (which currently lives in `frontends/pytorch`) into a subdirectory here. I call this "earthmoving" because it is mostly mechanical changes and renames. As a quick summary (we can change this down the road easily):
- C++ `mlir::NPCOMP::Torch -> mlir::torch::Torch`
- CAPI `npcompTorchListTypeGet -> torchMlirTorchListTypeGet`
- preprocessor `#ifndef NPCOMP_ -> #ifndef TORCHMLIR_`
- CMake `NPCOMPFoo -> TorchMLIRFoo`

The goal of this is to create a standalone project creating a center of mass for entry into the MLIR ecosystem from PyTorch, suitable in scope for eventual inclusion/ownership in PyTorch. The idea is that `external/torch-mlir` will some day be pulled out into its own repository, and then npcomp will simply pull it in as a submodule. Layering-wise, what lives in `torch-mlir` lowers code from PyTorch (currently TorchScript, but TorchFX or pytorch/xla-style tracing are possible extensions) down to what we have been calling the "Torch backend contract" which is cleaned up IR (inlining, simplification, conversion to value tensors, ...) entirely in the `torch` dialect. This is the branching off point for further lowering, of which npcomp takes one opinion (outside `torch-mlir` of course!), namely the `TorchConversion` dialect/transforms which lower to IR suitable for IREE and other linalg-on-tensors based lower-level compilers.

Summary of changes:
- move `{include,lib,test}/Dialect/Torch` into `torch-mlir`
- move relevant parts of CAPI into `torch-mlir`.
- leave a few things related to the `torch-mlir` Python build commented out, which should be resolved in a subsequent change.
2021-09-10 03:24:10 +08:00
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
Support multiple instances of a class in GlobalizeObjectGraph. This happens in practice with e.g. ResNet from torchvision (multiple instances of the same BatchNorm class). The key observation is that for this program, and the expected set of programs, we can convert the program to the same globalized form with a bit more static analysis and effort to suitably monomorphize the program. Though what we are doing here is fairly annoying to implement, it saves any nontrivial later pass from having to do similar analyses (or worse). E.g. shape inference would need to be object-graph aware, mutation/lifetime analyses would have to be aware, etc. Additionally, it would make us front-load what it means to have a !torch.nn.Module type on an ABI boundary, which we are just not ready to handle. I'm really, really hoping that in practice we can get away with this, otherwise it's going to be really rough designing a representation (and implementing everything to back it) that is convenient to transform and gracefully scales from full object graph (in the most dynamic case) down to a fixed set of global slots like we have here (in the most static case, which we presume a lot of practical programs fall into). This also involved introducing a `torch-prepare-for-globalize-object-graph` pass that does a minimal set of lowerings to simplify the IR into a more orthogonal and analyzable form, and a `torch-globalize-pipeline` helper.

Recommended review order:
- updated documentation in Passes.td
- new tests in `globalize-object-graph-multiple-instances*.mlir`
- implementation of GlobalizeObjectGraph.cpp
- PrepareForGlobalizeObjectGraph.cpp + prepare-for-globalize-object-graph.mlir
- misc stuff like torch-globalize-pipeline pipeline definition.

With this, we can import, globalize, and inline resnet18 from torchvision: https://gist.github.com/silvasean/821586afc19b67d9fb72030b2e0adeb8
2021-03-10 12:33:21 +08:00
torch.class_type @__torch__.TestModule {
  torch.attr private "s1" : !torch.nn.Module<"__torch__.Submodule">
  torch.attr private "s2" : !torch.nn.Module<"__torch__.Submodule">
  torch.method "forward", @__torch__.TestModule.forward
}
torch.class_type @__torch__.Submodule {
  torch.attr private "n" : !torch.int
  torch.method private "forward", @__torch__.Submodule.forward
}
Rework how global slot initializers work. Rather than a per-global-slot initializer region, we now have one for the whole module. For example, it might look like this:

```
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list : !torch.list<tensor>
torch.global_slot.module_initializer {
  %0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
  %1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
  torch.initialize.global_slots [
    @tensor(%0 : !torch.tensor)
    @list(%1 : !torch.list<tensor>)
  ]
}
```

This new structure allows GlobalizeObjectGraph to create the initializer in a much simpler way, avoiding the need to reason about whether different slots alias each other. Reasoning about whether slots alias each other now is the responsibility of InlineGlobalSlots, which has to do a much more complicated analysis, implemented using MLIR's dataflow analysis framework.

Recommended review order:
- Check out the new IR constructs in the .mlir files of various passes
- Op definitions (*.td)
- Changes to GlobalizeObjectGraph pass.
- InlineGlobalSlots pass (~total rewrite)
- Misc changes:
  - Moving torchMlirAdjustStaticInformation for sharing with C++ code.
  - EraseModuleInitializer pass

To make this a bit nicer, it would be good to have a `torch.module` op with an initializer region attached. That would be more invasive though.

This change has highlighted certain aspects of our project layering which are worth calling out. None of our backends can handle global slots, so we enforce that there are no global slots before backend lowering. At an earlier stage in the project, we had aspirations of transparently handling mutable global state and such, but for reasons described below, that is no longer a goal. So really global slots should be seen as a progressive lowering step as part of inlining all the IValues in the original program (GlobalizeObjectGraph is also one such step).

Over time, with insights from work like IREE-JAX, it has become clear that there isn't a reliable programming model we can compile for users where we just transparently handle mutable global state (and some other things, like lists and dictionaries). There is a need for an "outer program" that orchestrates more restricted subroutines of the kind we can handle in our compile flow here. The benefit of that is that it decouples considerations like shapes, dtypes, etc. from the program constructs used in the outer program. As long as the outer program can efficiently invoke (pipelining/async/etc.) high-performance data-parallel numerical subroutines of the kind we compile in our flow here, then there is a complete programming model. This is also consistent with the direction of upstream PyTorch which is becoming more tracing-based (which inherently loses a lot of program structure, which then has to be applied back with an "outer program" orchestrating the traced subroutines).
2022-07-14 02:45:56 +08:00
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: torch.initialize.global_slots [
// CHECK: @s1.n(%[[INT1]] : !torch.int)
// CHECK: @s2.n(%[[INT2]] : !torch.int)
// CHECK: ]
// CHECK: }
// CHECK-LABEL: torch.global_slot "private" @s1.n : !torch.int
// CHECK-LABEL: torch.global_slot "private" @s2.n : !torch.int
%int1 = torch.constant.int 1
%s1 = torch.nn_module {
  torch.slot "n", %int1 : !torch.int
} : !torch.nn.Module<"__torch__.Submodule">
%int2 = torch.constant.int 2
%s2 = torch.nn_module {
  torch.slot "n", %int2 : !torch.int
} : !torch.nn.Module<"__torch__.Submodule">
%3 = torch.nn_module {
  torch.slot "s1", %s1 : !torch.nn.Module<"__torch__.Submodule">
  torch.slot "s2", %s2 : !torch.nn.Module<"__torch__.Submodule">
} : !torch.nn.Module<"__torch__.TestModule">
// CHECK-LABEL: func.func @forward() {
// CHECK: call @s1.forward() : () -> ()
// CHECK: call @s2.forward() : () -> ()
// CHECK: return
func.func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) {
  %4 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
  %5 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
  call @__torch__.Submodule.forward(%4) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
  call @__torch__.Submodule.forward(%5) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
  return
}
// CHECK-LABEL: func.func private @s1.forward() {
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[N:.*]] = torch.global_slot.get @s1.n : !torch.int
// CHECK: %[[NEWVAL:.*]] = torch.aten.add.int %[[N]], %[[C3]] : !torch.int, !torch.int -> !torch.int
// CHECK: torch.global_slot.set @s1.n = %[[NEWVAL]] : !torch.int
// CHECK: return
// CHECK-LABEL: func.func private @s2.forward() {
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[N:.*]] = torch.global_slot.get @s2.n : !torch.int
// CHECK: %[[NEWVAL:.*]] = torch.aten.add.int %[[N]], %[[C3]] : !torch.int, !torch.int -> !torch.int
// CHECK: torch.global_slot.set @s2.n = %[[NEWVAL]] : !torch.int
// CHECK: return
func.func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) {
  %int3 = torch.constant.int 3
  %5 = torch.prim.GetAttr %arg0["n"] : !torch.nn.Module<"__torch__.Submodule"> -> !torch.int
  %6 = torch.aten.add.int %5, %int3 : !torch.int, !torch.int -> !torch.int
  torch.prim.SetAttr %arg0["n"] = %6 : !torch.nn.Module<"__torch__.Submodule">, !torch.int
  return
}