torch-mlir/test/Dialect/Torch/prepare-for-globalize-objec...


[torch-mlir earthmoving (1/N)] C/C++ code movement.

This creates the `external/torch-mlir` directory as an LLVM_EXTERNAL_PROJECTS-compatible project (analogous to `iree-dialects`) and completes movement/rename of all pure MLIR C/C++ compiler code into there. The next step will be to move all the Python code / code that links/includes PyTorch C++ code (which currently lives in `frontends/pytorch`) into a subdirectory here.

I call this "earthmoving" because it is mostly mechanical changes and renames. As a quick summary (we can change this down the road easily):
- C++ `mlir::NPCOMP::Torch -> mlir::torch::Torch`
- CAPI `npcompTorchListTypeGet -> torchMlirTorchListTypeGet`
- preprocessor `#ifndef NPCOMP_ -> #ifndef TORCHMLIR_`
- CMake `NPCOMPFoo -> TorchMLIRFoo`

The goal of this is to create a standalone project creating a center of mass for entry into the MLIR ecosystem from PyTorch, suitable in scope for eventual inclusion/ownership in PyTorch. The idea is that `external/torch-mlir` will some day be pulled out into its own repository, and then npcomp will simply pull it in as a submodule.

Layering-wise, what lives in `torch-mlir` lowers code from PyTorch (currently TorchScript, but TorchFX or pytorch/xla-style tracing are possible extensions) down to what we have been calling the "Torch backend contract", which is cleaned-up IR (inlining, simplification, conversion to value tensors, ...) entirely in the `torch` dialect. This is the branching-off point for further lowering, of which npcomp takes one opinion (outside `torch-mlir` of course!), namely the `TorchConversion` dialect/transforms which lower to IR suitable for IREE and other linalg-on-tensors based lower-level compilers.

Summary of changes:
- move `{include,lib,test}/Dialect/Torch` into `torch-mlir`
- move relevant parts of CAPI into `torch-mlir`.
- leave a few things related to the `torch-mlir` Python build commented out, which should be resolved in a subsequent change.
2021-09-10 03:24:10 +08:00
// RUN: torch-mlir-opt -torch-prepare-for-globalize-object-graph -split-input-file %s | FileCheck %s
Support multiple instances of a class in GlobalizeObjectGraph.

This happens in practice with e.g. ResNet from torchvision (multiple instances of the same BatchNorm class).

The key observation is that for this program, and the expected set of programs, we can convert the program to the same globalized form with a bit more static analysis and effort to suitably monomorphize the program. Though what we are doing here is fairly annoying to implement, it saves any nontrivial later pass from having to do similar analyses (or worse). E.g. shape inference would need to be object-graph aware, mutation/lifetime analyses would have to be aware, etc. Additionally, it would make us front-load what it means to have a !torch.nn.Module type on an ABI boundary, which we are just not ready to handle.

I'm really, really hoping that in practice we can get away with this, otherwise it's going to be really rough designing a representation (and implementing everything to back it) that is convenient to transform and gracefully scales from full object graph (in the most dynamic case) down to a fixed set of global slots like we have here (in the most static case, which we presume a lot of practical programs fall into).

This also involved introducing a `torch-prepare-for-globalize-object-graph` pass that does a minimal set of lowerings to simplify the IR into a more orthogonal and analyzable form, and a `torch-globalize-pipeline` helper.

Recommended review order:
- updated documentation in Passes.td
- new tests in `globalize-object-graph-multiple-instances*.mlir`
- implementation of GlobalizeObjectGraph.cpp
- PrepareForGlobalizeObjectGraph.cpp + prepare-for-globalize-object-graph.mlir
- misc stuff like torch-globalize-pipeline pipeline definition.

With this, we can import, globalize, and inline resnet18 from torchvision: https://gist.github.com/silvasean/821586afc19b67d9fb72030b2e0adeb8
2021-03-10 12:33:21 +08:00
torch.class_type @c {
  torch.method "test_call_method", @test_call_method
  torch.method "test_call_indirect", @test_call_indirect
}
// CHECK-LABEL: func private @test_call_method(
// CHECK-SAME: %[[RECEIVER:.*]]: !torch.nn.Module<"c">,
// CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float {
// CHECK: %[[RET:.*]] = call @test_call_method(%[[RECEIVER]], %[[F]]) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
// CHECK: return %[[RET]] : !torch.float
func private @test_call_method(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float {
  %0 = torch.prim.CallMethod %arg0["test_call_method"] (%arg1) : !torch.nn.Module<"c">, (!torch.float) -> !torch.float
  return %0 : !torch.float
}
// CHECK-LABEL: func private @test_call_indirect(
// CHECK-SAME: %[[RECEIVER:.*]]: !torch.nn.Module<"c">,
// CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float {
// Ensure no std.constant.
// CHECK-NEXT: %[[VAL_2:.*]] = call @test_call_method(%[[RECEIVER]], %[[F]]) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
// CHECK-NEXT: return %[[VAL_2]] : !torch.float
func private @test_call_indirect(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float {
  %0 = constant @test_call_method : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
  %1 = call_indirect %0(%arg0, %arg1) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
  return %1 : !torch.float
}
torch.nn_module {
} : !torch.nn.Module<"c">