//===- GlobalizeObjectGraph.cpp ----------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
|
|
|
|
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
|
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
|
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
|
|
|
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
|
|
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
|
|
|
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
|
|
|
#include "llvm/ADT/MapVector.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::NPCOMP;
|
|
|
|
using namespace mlir::NPCOMP::Torch;
|
|
|
|
|
|
|
|
namespace {
// See the pass documentation for `torch-globalize-object-graph`.
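//
// Conceptually, the pass takes an object graph rooted at a single
// torch.nn_module instance of a torch.class_type (with torch.slot,
// torch.attr, and torch.method ops describing its state and methods) and
// flattens it into module-level torch.global_slot ops plus ordinary
// functions, rewriting torch.prim.GetAttr / torch.prim.SetAttr /
// torch.prim.CallMethod into global slot accesses and direct calls.
// (This is only an illustrative summary; see the pass documentation and its
// test cases for the exact op syntax and the precise supported subset.)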
class ObjectGraphGlobalizer {
public:
  ObjectGraphGlobalizer(ModuleOp module);
  LogicalResult globalizeObjectGraph();

private:
  FailureOr<NnModuleOp> findRootNnModule();
  LogicalResult checkSingleInstanceOfEachClass();
  LogicalResult recursivelyTraverseClassType(ClassTypeOp classType);
  LogicalResult populateGlobalSlotInitializer(GlobalSlotOp op,
                                              Value initialValue);
  LogicalResult rewriteMethods();
  void removeObjectGraph();

  ModuleOp module;
  SymbolTable symbolTable;
  OpBuilder globalBuilder;
  // The stack of attribute names we have traversed during our recursive
  // traversal of the class/object hierarchy.
  //
  // Linkage names are calculated based on the set of attribute names
  // traversed from the root class/module in the program.
  SmallVector<std::string> nameStack;

  // Sometimes it is natural to want a map keyed on torch.attr ops or
  // torch.slot ops. However, it is usually better to keep a map keyed on a
  // ClassTypeOp + attr name, since frequently that is all one has access to,
  // and it would be tedious to scan the body of the ClassTypeOp for the
  // torch.attr with the corresponding name.
  using AttrOfClass =
      std::pair</*ClassTypeOp*/ Operation *, /*attr name*/ StringRef>;
  // The initial value associated with an attribute of a class.
  // Since we only allow a single instance of a class, this is equivalent to
  // the initial value of the unique slot corresponding to that attr.
  DenseMap<AttrOfClass, Value> slotInitialValues;
  // The inverse map of `slotInitialValues`.
  // Many attributes can have the same initial value, so the value type
  // is a vector.
  DenseMap<Value, std::vector<AttrOfClass>> slotInitialValuesInverseMap;

  // The torch.global_slot corresponding to each torch.attr/torch.slot.
  DenseMap<AttrOfClass, GlobalSlotOp> globalSlotForAttr;
  // The linkage name (value) for the function with symbol name (key).
  DenseMap<StringRef, std::string> methodLinkageNames;

  // The set of class types that have already been processed.
  // Used for diagnostics.
  // The map value is the path from the root at which we originally found it.
  DenseMap</*ClassTypeOp*/ Operation *, std::string> seenClassTypes;

  // A set of values that we have copied into torch.global_slot initializers,
  // which cannot be used in multiple initializers because their object
  // identity is important.
  DenseSet<Value> objectsWithIdentityAlreadyCopiedIntoInitializers;
};
} // namespace

ObjectGraphGlobalizer::ObjectGraphGlobalizer(ModuleOp module)
    : module(module), symbolTable(module),
      globalBuilder(module.getBodyRegion()) {}

LogicalResult ObjectGraphGlobalizer::globalizeObjectGraph() {
  // We require there to be a unique root !torch.nn.Module.
  FailureOr<NnModuleOp> maybeRootNnModule = findRootNnModule();
  if (failed(maybeRootNnModule))
    return failure();
  NnModuleOp rootNnModule = *maybeRootNnModule;
  if (!rootNnModule)
    return module.emitError()
           << "module does not contain a root torch.nn_module";

  // We require one instance of each class. That is, there is a single
  // torch.nn_module for each torch.class_type.
  if (failed(checkSingleInstanceOfEachClass()))
    return failure();

  // Record the initial value of each slot, keyed on (class type, attr name).
  for (NnModuleOp nnModule : module.getOps<NnModuleOp>()) {
    auto classType = symbolTable.lookup<ClassTypeOp>(nnModule.getClassName());
    for (auto slot : nnModule.getOps<SlotOp>()) {
      AttrOfClass attrOfClass = {classType, slot.name()};
      slotInitialValues[attrOfClass] = slot.value();
      slotInitialValuesInverseMap[slot.value()].push_back(attrOfClass);
    }
  }

  // Recursively traverse the class hierarchy, globalizing slots and
  // tracking linkage names for methods.
  auto rootClassType =
      symbolTable.lookup<ClassTypeOp>(rootNnModule.getClassName());
  if (failed(recursivelyTraverseClassType(rootClassType)))
    return failure();

  // Rewrite torch.prim.GetAttr/torch.prim.SetAttr/torch.prim.CallMethod.
  if (failed(rewriteMethods()))
    return failure();

  // Now that we have finished converting to the new form, remove the
  // original object graph.
  removeObjectGraph();

  return success();
}

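// Find the unique root module: the torch.nn_module that is not a slot of any
// other module (i.e. the one with no uses). Returns failure() if more than
// one such module exists, and a null NnModuleOp if none exists.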
FailureOr<NnModuleOp> ObjectGraphGlobalizer::findRootNnModule() {
  NnModuleOp rootNnModule;
  for (NnModuleOp op : module.getOps<NnModuleOp>()) {
    if (!op.use_empty())
      continue;
    if (rootNnModule) {
      op.emitError()
          .append("found more than one root module (module that is not a "
                  "child of any other module)")
          .attachNote(rootNnModule.getLoc())
          .append("see other root module here");
      return failure();
    }
    rootNnModule = op;
  }
  return rootNnModule;
}

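// Enforce the single-instance invariant: group all torch.nn_module ops by
// their torch.class_type and emit an error if any class has more than one
// instance.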
LogicalResult ObjectGraphGlobalizer::checkSingleInstanceOfEachClass() {
  llvm::MapVector</*ClassTypeOp*/ Operation *, std::vector<NnModuleOp>>
      classInstances;
  for (NnModuleOp op : module.getOps<NnModuleOp>()) {
    auto classType = symbolTable.lookup<ClassTypeOp>(op.getClassName());
    classInstances[classType].push_back(op);
  }
  for (auto &p : classInstances) {
    ClassTypeOp classType = cast<ClassTypeOp>(p.first);
    ArrayRef<NnModuleOp> instances = p.second;
    if (instances.size() > 1) {
      // TODO: Improve this diagnostic based on user use cases.
      // This is a user-facing diagnostic that enforces key invariants of
      // our TorchScript subset.
      auto diag = classType.emitError(
          "class type has more than one instance: the current TorchScript "
          "supported subset only allows single instances");
      for (NnModuleOp instance : instances) {
        diag.attachNote(instance.getLoc()) << "see instance here";
      }
      return failure();
    }
  }
  return success();
}

LogicalResult
ObjectGraphGlobalizer::recursivelyTraverseClassType(ClassTypeOp classType) {
  std::string pathToClassFromRoot = llvm::join(nameStack, ".");
  if (!seenClassTypes.insert({classType, pathToClassFromRoot}).second) {
    return classType.emitError()
           << "reachable by multiple paths from root object: '<root>."
           << seenClassTypes[classType] << "' and '<root>."
           << pathToClassFromRoot << "'";
  }

  // For each attr, create a global slot for it.
  for (auto attr : classType.getOps<AttrOp>()) {
    nameStack.push_back(attr.name().str());
    if (auto type = attr.type().dyn_cast<NnModuleType>()) {
      if (failed(recursivelyTraverseClassType(
              symbolTable.lookup<ClassTypeOp>(type.getClassName()))))
        return failure();
    } else {
      auto linkageName = llvm::join(nameStack, ".");
      auto globalSlot = globalBuilder.create<GlobalSlotOp>(
          attr->getLoc(), linkageName, /*sym_visibility=*/nullptr,
          attr.type());
      if (attr.isPrivate())
        globalSlot.setVisibility(SymbolTable::Visibility::Private);
      AttrOfClass attrOfClass = {classType, attr.name()};
      assert(globalSlotForAttr.find(attrOfClass) == globalSlotForAttr.end());
      globalSlotForAttr[attrOfClass] = globalSlot;
      if (failed(populateGlobalSlotInitializer(
              globalSlot, slotInitialValues[attrOfClass])))
        return failure();
    }
    nameStack.pop_back();
  }
  // For each method, track the linkage name it will eventually have.
  for (auto method : classType.getOps<MethodOp>()) {
    nameStack.push_back(method.name().str());
    auto linkageName = llvm::join(nameStack, ".");
    nameStack.pop_back();
    if (!methodLinkageNames.insert({method.function(), linkageName}).second)
      return method.emitError()
             << "unbound function shared by multiple methods";
  }
  return success();
}

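// Returns true if the type has meaningful object identity, i.e. if cloning
// the defining op of a value of this type into more than one global slot
// initializer could change program semantics. Primitive types (and tensors)
// do not, and so may be freely duplicated across initializers.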
static bool hasMeaningfulObjectIdentity(Type type) {
  return !type.isa<IntegerType, FloatType, Basicpy::BoolType,
                   Basicpy::BytesType, TensorType>();
}

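// Populate the initializer region of `globalSlot` by cloning the transitive
// backward slice of ops that define `initialValue` (sorted into original
// program order so that defs precede uses), then terminating the region with
// a GlobalSlotInitOp of the cloned value. Values with meaningful object
// identity may only be copied into a single initializer; reusing one is
// reported as an error.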
LogicalResult
ObjectGraphGlobalizer::populateGlobalSlotInitializer(GlobalSlotOp globalSlot,
                                                     Value initialValue) {
  OpBuilder builder(globalSlot.getContext());
  builder.createBlock(&globalSlot.getRegion());

  SmallPtrSet<Operation *, 6> needToClone;
  SmallVector<Operation *> worklist = {initialValue.getDefiningOp()};
  while (!worklist.empty()) {
    Operation *op = worklist.pop_back_val();
    if (!needToClone.insert(op).second)
      continue;
    for (Value operand : op->getOperands()) {
      if (auto def = operand.getDefiningOp())
        worklist.push_back(def);
    }
  }
  worklist.assign(needToClone.begin(), needToClone.end());
  llvm::sort(worklist, [](Operation *lhs, Operation *rhs) {
    return lhs->isBeforeInBlock(rhs);
  });
  BlockAndValueMapping mapping;
  for (Operation *op : worklist) {
    builder.clone(*op, mapping);
    for (Value result : op->getResults()) {
      if (!hasMeaningfulObjectIdentity(result.getType()))
        continue;
      if (!objectsWithIdentityAlreadyCopiedIntoInitializers.insert(result)
               .second) {
        return op->emitError() << "potentially-aliased value used to "
                                  "initialize multiple slots";
      }
    }
  }
  builder.create<GlobalSlotInitOp>(globalSlot->getLoc(),
                                   mapping.lookup(initialValue));
  return success();
}

// Verify that a value conforms to the subset of allowed uses for
// !torch.nn.Module<"..."> types.
static LogicalResult verifyNnModuleValueUses(Value value) {
  // Trivially true for non-module types.
  if (!value.getType().isa<NnModuleType>())
    return success();
  for (Operation *op : value.getUsers()) {
    if (isa<CallOp, PrimGetAttrOp, PrimSetAttrOp, PrimCallMethodOp>(op))
      continue;
    // TODO: Improve this based on real user use cases.
    // This is a diagnostic that users will hit if they do not conform to
    // the supported subset of TorchScript.
    return op->emitError() << "unsupported use of a torch.nn.Module. Expected "
                              "only method calls or attribute get/set";
  }
  return success();
}

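// Mangle the name of a function that is not a method reachable from the root
// object. The prefix keeps such free functions from colliding with the
// dotted linkage names assigned to methods (e.g. "helper" becomes
// "__npcomp_priv_fn$helper").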
static std::string getNonMethodMangledFunctionName(StringRef originalName) {
  return "__npcomp_priv_fn$" + originalName.str();
}

// Verify that `func` conforms to the subset of allowable method bodies
// that we can convert.
static LogicalResult verifyFuncConformsToSubset(FuncOp func) {
  auto walkResult = func.walk([](Block *block) {
    for (Value arg : block->getArguments()) {
      if (failed(verifyNnModuleValueUses(arg)))
        return WalkResult::interrupt();
    }
    for (Operation &op : *block) {
      for (Value result : op.getResults()) {
        if (failed(verifyNnModuleValueUses(result)))
          return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  return failure(walkResult.wasInterrupted());
}

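// Rewrite all uses of !torch.nn.Module values inside function bodies:
// prim.SetAttr/prim.GetAttr become global slot accesses, and prim.CallMethod
// and calls to free functions become direct calls with module-typed operands
// dropped. Afterwards, module-typed function arguments are erased and each
// function is renamed to its linkage name (or a mangled private name if it
// is not a method reachable from the root).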
LogicalResult ObjectGraphGlobalizer::rewriteMethods() {
  DenseMap<AttrOfClass, StringRef> linkageNames;
  for (auto classType : module.getOps<ClassTypeOp>()) {
    for (auto method : classType.getOps<MethodOp>()) {
      auto it = methodLinkageNames.find(method.function());
      if (it == methodLinkageNames.end())
        continue;
      linkageNames[{classType, method.name()}] = it->second;
    }
  }
  // We only handle a small subset of ops that conform to the set of
  // assumptions that allow us to globalize the object graph. Anything that
  // tries to treat modules as bona fide objects, rather than just namespaces
  // of methods with a single instance of the corresponding type, gets
  // arbitrarily tricky to rewrite. E.g. what if the user creates a list of
  // modules, or there is an scf.if selecting between modules, etc.
  SmallVector<Operation *> toErase;
  auto rewriteOpWithNnModuleTypeOperand = [&](Operation *op) {
    if (auto primSetAttr = dyn_cast<PrimSetAttrOp>(op)) {
      auto classType = symbolTable.lookup<ClassTypeOp>(
          primSetAttr.receiver().getType().cast<NnModuleType>().getClassName());
      auto globalSlot = globalSlotForAttr[{classType, primSetAttr.name()}];
      OpBuilder(primSetAttr)
          .create<GlobalSlotSetOp>(primSetAttr.getLoc(), globalSlot.sym_name(),
                                   primSetAttr.value());
      toErase.push_back(primSetAttr);
    } else if (auto primGetAttr = dyn_cast<PrimGetAttrOp>(op)) {
      // If the return value is NnModuleType, then we don't need to do
      // anything. Our verification earlier ensured that there are no uses
      // that we won't properly rewrite.
      if (!primGetAttr.getType().isa<NnModuleType>()) {
        auto classType =
            symbolTable.lookup<ClassTypeOp>(primGetAttr.receiver()
                                                .getType()
                                                .cast<NnModuleType>()
                                                .getClassName());
        auto globalSlot = globalSlotForAttr[{classType, primGetAttr.name()}];
        auto globalSlotGet =
            OpBuilder(primGetAttr)
                .create<GlobalSlotGetOp>(primGetAttr.getLoc(),
                                         primGetAttr.getType(),
                                         globalSlot.sym_name());
        primGetAttr.replaceAllUsesWith(globalSlotGet.getOperation());
      }
      toErase.push_back(primGetAttr);
    } else if (auto primCallMethod = dyn_cast<PrimCallMethodOp>(op)) {
      auto classType = symbolTable.lookup<ClassTypeOp>(primCallMethod.receiver()
                                                           .getType()
                                                           .cast<NnModuleType>()
                                                           .getClassName());
      StringRef linkageName = linkageNames[{classType, primCallMethod.name()}];

      auto newOperands = llvm::to_vector<6>(
          llvm::make_filter_range(primCallMethod.operands(), [](Value v) {
            return !v.getType().isa<NnModuleType>();
          }));
      auto call = OpBuilder(primCallMethod)
                      .create<CallOp>(primCallMethod.getLoc(), linkageName,
                                      primCallMethod.getType(), newOperands);
      primCallMethod.replaceAllUsesWith(call);
      toErase.push_back(primCallMethod);
    } else if (auto callOp = dyn_cast<CallOp>(op)) {
      auto newOperands = llvm::to_vector<6>(
          llvm::make_filter_range(callOp.operands(), [](Value v) {
            return !v.getType().isa<NnModuleType>();
          }));
      auto newCallOp = OpBuilder(callOp).create<CallOp>(
          callOp.getLoc(), getNonMethodMangledFunctionName(callOp.callee()),
          callOp.getResultTypes(), newOperands);
      callOp.replaceAllUsesWith(newCallOp);
      toErase.push_back(callOp);
    }
  };

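  // How to adjust the FuncOp backing a method: the linkage name it should be
  // renamed to, and whether the method was private (which determines the
  // symbol visibility it keeps).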
  struct MethodFuncRewrite {
    bool isPrivate;
    std::string linkageName;
  };

  DenseMap<FuncOp, MethodFuncRewrite> methodFuncRewrites;
  for (auto classType : module.getOps<ClassTypeOp>()) {
    for (auto method : classType.getOps<MethodOp>()) {
      auto it = methodLinkageNames.find(method.function());
      if (it == methodLinkageNames.end())
        continue;
      FuncOp func = symbolTable.lookup<FuncOp>(method.function());
      methodFuncRewrites[func] =
          MethodFuncRewrite{method.isPrivate(), it->second};
    }
  }
  for (auto func : module.getOps<FuncOp>()) {
    if (failed(verifyFuncConformsToSubset(func)))
      return failure();
    func.walk(rewriteOpWithNnModuleTypeOperand);
    for (Operation *op : toErase) {
      op->dropAllDefinedValueUses();
      op->erase();
    }
    toErase.clear();
    SmallVector<unsigned> argsToErase;
    for (auto arg : llvm::enumerate(func.getArguments())) {
      if (!arg.value().getType().isa<NnModuleType>())
        continue;
      assert(arg.value().use_empty() && "all uses should have been removed");
      argsToErase.push_back(arg.index());
    }
    func.eraseArguments(argsToErase);
    // No need to handle the results, since we currently don't allow ReturnOp
    // as a user of module types.

    // Adjust the linkage names to adopt the linkage convention of this pass,
    // namely that only objects accessible from a root !torch.nn.Module are
    // possible external linkage candidates, and their linkage name is the
    // dotted path from the root.
    //
    // Any other function gets a prefix to avoid collisions. These other
    // functions correspond to free functions somewhere outside the module
    // hierarchy.
    auto it = methodFuncRewrites.find(func);
    if (it != methodFuncRewrites.end()) {
      if (!it->second.isPrivate)
        func.setVisibility(SymbolTable::Visibility::Public);
      func.setName(it->second.linkageName);
    } else {
      func.setName(getNonMethodMangledFunctionName(func.getName()));
    }
  }

  return success();
}

void ObjectGraphGlobalizer::removeObjectGraph() {
  for (Operation &op : llvm::make_early_inc_range(*module.getBody())) {
    if (!isa<FuncOp, GlobalSlotOp, ModuleTerminatorOp>(op)) {
      op.dropAllDefinedValueUses();
      op.erase();
    }
  }
}

namespace {
class GlobalizeObjectGraphPass
    : public GlobalizeObjectGraphBase<GlobalizeObjectGraphPass> {
  void runOnOperation() override {
    if (failed(ObjectGraphGlobalizer(getOperation()).globalizeObjectGraph()))
      return signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass() {
  return std::make_unique<GlobalizeObjectGraphPass>();
}