2021-02-18 03:28:51 +08:00
|
|
|
//===- GlobalizeObjectGraph.cpp ----------------------------------*- C++-*-===//
|
|
|
|
//
|
|
|
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "PassDetail.h"
|
|
|
|
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
|
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
|
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
|
|
|
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
|
|
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
|
|
|
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
|
|
|
#include "llvm/ADT/MapVector.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::NPCOMP;
|
|
|
|
using namespace mlir::NPCOMP::Torch;
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
// See the pass documentation for `torch-globalize-object-graph`.
|
|
|
|
class ObjectGraphGlobalizer {
|
|
|
|
public:
|
|
|
|
ObjectGraphGlobalizer(ModuleOp module);
|
|
|
|
LogicalResult globalizeObjectGraph();
|
|
|
|
|
|
|
|
private:
|
|
|
|
FailureOr<NnModuleOp> findRootNnModule();
|
|
|
|
LogicalResult checkSingleInstanceOfEachClass();
|
|
|
|
LogicalResult recursivelyTraverseClassType(ClassTypeOp classType);
|
|
|
|
void createInitializerFunc();
|
|
|
|
LogicalResult rewriteMethods();
|
|
|
|
void removeObjectGraph();
|
|
|
|
|
|
|
|
ModuleOp module;
|
|
|
|
SymbolTable symbolTable;
|
|
|
|
OpBuilder globalBuilder;
|
|
|
|
// The stack of attribute names we have traversed during our recursive
|
|
|
|
// traversal of the class/object hierarchy.
|
|
|
|
//
|
|
|
|
// Linkage names are calculated based on the set of attribute names traversed
|
|
|
|
// from the root class/module in the program.
|
|
|
|
SmallVector<std::string> nameStack;
|
|
|
|
|
|
|
|
// Sometimes it is natural to want a map keyed on torch.attr ops or torch.slot
|
|
|
|
// ops. However, usually it is better to keep a map keyed on an ClassTypeOp
|
|
|
|
// + attr name since frequently that is all one has access to and it
|
|
|
|
// would be tedious to scan the body of the ClassTypeOp for the torch.attr
|
|
|
|
// with the corresponding name.
|
|
|
|
using AttrOfClass =
|
|
|
|
std::pair</*ClassTypeOp*/ Operation *, /*attr name*/ StringRef>;
|
|
|
|
// The initial value associated with an attribute of a class.
|
|
|
|
// Since we only allow a single instance of a class, this is equivalent to
|
|
|
|
// the initial value of the unique slot corresponding to that attr.
|
|
|
|
DenseMap<AttrOfClass, Value> slotInitialValues;
|
|
|
|
// The inverse map of `slotInitialValues`.
|
|
|
|
// Many attributes can have the same initial value, so the value type
|
|
|
|
// is a vector.
|
|
|
|
DenseMap<Value, std::vector<AttrOfClass>> slotInitialValuesInverseMap;
|
|
|
|
|
|
|
|
// The torch.global_slot corresponding to each torch.attr/torch.slot.
|
|
|
|
DenseMap<AttrOfClass, GlobalSlotOp> globalSlotForAttr;
|
|
|
|
// The linkage name (value) for the function with symbol name (key).
|
|
|
|
DenseMap<StringRef, std::string> methodLinkageNames;
|
|
|
|
|
|
|
|
// The set of class types that have already been processed.
|
|
|
|
// Used for diagnostics.
|
|
|
|
// The map value is the original path from the root that we found it at.
|
|
|
|
DenseMap</*ClassTypeOp*/ Operation *, std::string> seenClassTypes;
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Binds the globalizer to `module`: builds the symbol table from it and
// constructs the global builder from the module's body region.
ObjectGraphGlobalizer::ObjectGraphGlobalizer(ModuleOp module)
    : module(module),
      symbolTable(module),
      globalBuilder(module.getBodyRegion()) {
}
|
|
|
|
|
|
|
|
// Entry point of the transformation. Runs, in order: root lookup, the
// single-instance check, slot initial value collection, the class hierarchy
// traversal, initializer function creation, method rewriting, and finally
// removal of the original object graph. Stops at the first failing step.
LogicalResult ObjectGraphGlobalizer::globalizeObjectGraph() {
  // There must be exactly one root !torch.nn.Module.
  FailureOr<NnModuleOp> rootOrFailure = findRootNnModule();
  if (failed(rootOrFailure))
    return failure();
  NnModuleOp root = *rootOrFailure;
  if (!root) {
    return module.emitError()
           << "module does not contain a root torch.nn_module";
  }

  // Each class must have one instance, i.e. there is a single
  // torch.nn_module for each torch.class_type.
  if (failed(checkSingleInstanceOfEachClass()))
    return failure();

  // Record the initial value of every slot, in both directions.
  for (NnModuleOp instance : module.getOps<NnModuleOp>()) {
    auto instanceClass =
        symbolTable.lookup<ClassTypeOp>(instance.getClassName());
    for (auto slotOp : instance.getOps<SlotOp>()) {
      AttrOfClass key = {instanceClass, slotOp.name()};
      Value initialValue = slotOp.value();
      slotInitialValues[key] = initialValue;
      slotInitialValuesInverseMap[initialValue].push_back(key);
    }
  }

  // Walk the class hierarchy from the root, creating global slots and
  // recording the linkage name of every method.
  auto rootClass = symbolTable.lookup<ClassTypeOp>(root.getClassName());
  if (failed(recursivelyTraverseClassType(rootClass)))
    return failure();

  // Gather all slot initial values into an initializer func.
  createInitializerFunc();

  // Rewrite torch.prim.GetAttr/torch.prim.SetAttr/torch.prim.CallMethod.
  if (failed(rewriteMethods()))
    return failure();

  // Everything has been converted to the new form, so the original object
  // graph can now be removed.
  removeObjectGraph();

  return success();
}
|
|
|
|
|
|
|
|
// Finds the root module: the torch.nn_module op in the module body that has
// no uses.
//
// Returns failure (after emitting an error with a note pointing at the first
// root) if more than one such op exists. Returns a null NnModuleOp if none
// exists; the caller is responsible for diagnosing that case.
FailureOr<NnModuleOp> ObjectGraphGlobalizer::findRootNnModule() {
  NnModuleOp root;
  for (NnModuleOp candidate : module.getOps<NnModuleOp>()) {
    // A module that is used somewhere is not a root.
    if (!candidate.use_empty())
      continue;

    // First unused module seen: remember it and keep scanning for others.
    if (!root) {
      root = candidate;
      continue;
    }

    // A second unused module: report it, pointing back at the first one.
    InFlightDiagnostic diag = candidate.emitError();
    diag.append("found more than one root module (module that is not a "
                "child of any other module)");
    diag.attachNote(root.getLoc()).append("see other root module here");
    return failure();
  }
  return root;
}
|
|
|
|
|
|
|
|
// Checks that no torch.class_type is instantiated by more than one
// torch.nn_module. On violation, emits an error on the class type with one
// note per instance and returns failure.
LogicalResult ObjectGraphGlobalizer::checkSingleInstanceOfEachClass() {
  // Group the instances by class. MapVector iterates in insertion order.
  llvm::MapVector</*ClassTypeOp*/ Operation *, std::vector<NnModuleOp>>
      instancesByClass;
  for (NnModuleOp instance : module.getOps<NnModuleOp>()) {
    auto instanceClass =
        symbolTable.lookup<ClassTypeOp>(instance.getClassName());
    instancesByClass[instanceClass].push_back(instance);
  }

  for (auto &entry : instancesByClass) {
    ClassTypeOp classType = cast<ClassTypeOp>(entry.first);
    ArrayRef<NnModuleOp> instances = entry.second;
    if (instances.size() <= 1)
      continue;

    // TODO: Improve this diagnostic based on user use cases.
    // This is a user-facing diagnostic that enforces key invariants to
    // our TorchScript subset.
    auto diag = classType.emitError(
        "class type has more than one instance: the current TorchScript "
        "supported subset only allows single instances");
    for (NnModuleOp instance : instances)
      diag.attachNote(instance.getLoc()) << "see instance here";
    return failure();
  }
  return success();
}
|
|
|
|
|
|
|
|
// Visits `classType` and, through its module-typed attributes, every class
// reachable from it.
//
// For each non-module attribute a torch.global_slot is created, named by the
// "."-joined attribute path in `nameStack`, and recorded in
// `globalSlotForAttr`. For each method the linkage name it will get is
// recorded in `methodLinkageNames`.
//
// Fails with a diagnostic if a class is visited twice or if one function
// backs more than one method.
LogicalResult
ObjectGraphGlobalizer::recursivelyTraverseClassType(ClassTypeOp classType) {
  std::string currentPath = llvm::join(nameStack, ".");
  auto seenInsertion = seenClassTypes.insert({classType, currentPath});
  if (!seenInsertion.second) {
    // Already visited: the existing entry holds the path found first.
    return classType.emitError()
           << "reachable by multiple paths from root object: '<root>."
           << seenInsertion.first->second << "' and '<root>."
           << currentPath << "'";
  }

  // Attributes: recurse into submodules, create a global slot for the rest.
  for (auto attr : classType.getOps<AttrOp>()) {
    nameStack.push_back(attr.name().str());
    auto submoduleType = attr.type().dyn_cast<NnModuleType>();
    if (submoduleType) {
      auto submoduleClass =
          symbolTable.lookup<ClassTypeOp>(submoduleType.getClassName());
      if (failed(recursivelyTraverseClassType(submoduleClass)))
        return failure();
    } else {
      std::string slotLinkageName = llvm::join(nameStack, ".");
      auto globalSlot = globalBuilder.create<GlobalSlotOp>(
          attr->getLoc(), slotLinkageName, /*sym_visibility=*/nullptr,
          TypeAttr::get(attr.type()));
      // Private attributes produce private global slots.
      if (attr.isPrivate())
        globalSlot.setVisibility(SymbolTable::Visibility::Private);
      AttrOfClass key = {classType, attr.name()};
      assert(globalSlotForAttr.find(key) == globalSlotForAttr.end());
      globalSlotForAttr[key] = globalSlot;
    }
    nameStack.pop_back();
  }

  // Methods: record the linkage name each one will eventually have.
  for (auto method : classType.getOps<MethodOp>()) {
    nameStack.push_back(method.name().str());
    std::string methodLinkageName = llvm::join(nameStack, ".");
    nameStack.pop_back();

    bool wasInserted =
        methodLinkageNames.insert({method.function(), methodLinkageName})
            .second;
    if (!wasInserted) {
      return method.emitError()
             << "unbound function shared by multiple methods";
    }
  }
  return success();
}
|
|
|
|
|
|
|
|
void ObjectGraphGlobalizer::createInitializerFunc() {
|
|
|
|
auto loc = module.getLoc();
|
|
|
|
auto func = globalBuilder.create<FuncOp>(
|
|
|
|
loc, GlobalSlotOp::getGlobalSlotInitializerFuncName(),
|
|
|
|
globalBuilder.getFunctionType({}, {}));
|
|
|
|
OpBuilder builder(func.getContext());
|
|
|
|
Block *body = builder.createBlock(&func.getBody());
|
|
|
|
|
|
|
|
SmallVector<Operation *> opsToMove;
|
|
|
|
for (Operation &op : llvm::make_early_inc_range(*module.getBody())) {
|
|
|
|
if (isa<ClassTypeOp, NnModuleOp, GlobalSlotOp, FuncOp, ModuleTerminatorOp>(
|
|
|
|
&op))
|
|
|
|
continue;
|
|
|
|
op.moveBefore(body, body->end());
|
|
|
|
for (Value result : llvm::make_early_inc_range(op.getResults())) {
|
|
|
|
auto it = slotInitialValuesInverseMap.find(result);
|
|
|
|
if (it == slotInitialValuesInverseMap.end())
|
|
|
|
continue;
|
|
|
|
for (AttrOfClass attrOfClass : it->second) {
|
|
|
|
GlobalSlotOp globalSlot = globalSlotForAttr[attrOfClass];
|
|
|
|
OpBuilder::atBlockEnd(body).create<GlobalSlotSetOp>(
|
|
|
|
globalSlot.getLoc(), globalSlot.sym_name(), result);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
builder.create<ReturnOp>(loc);
|
|
|
|
}
|
|
|
|
|
2021-02-19 09:10:17 +08:00
|
|
|
// Verify that a value conforms to the subset of allowed uses for
|
|
|
|
// !torch.nn.Module<"..."> types.
|
|
|
|
static LogicalResult verifyNnModuleValueUses(Value value) {
|
|
|
|
// Trivially true for non-module types.
|
|
|
|
if (!value.getType().isa<NnModuleType>())
|
|
|
|
return success();
|
|
|
|
for (Operation *op : value.getUsers()) {
|
|
|
|
if (isa<PrimGetAttrOp, PrimSetAttrOp, PrimCallMethodOp>(op))
|
|
|
|
continue;
|
|
|
|
// TODO: Improve this based on real user use cases.
|
|
|
|
// This is a diagnostic that users will hit if they do not conform to
|
|
|
|
// the supported subset of TorchScript.
|
|
|
|
return op->emitError() << "unsupported use of a torch.nn.Module. Expected "
|
|
|
|
"only method calls or attribute get/set";
|
|
|
|
}
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Verify that `func` conforms to the subset of allowable method bodies
|
|
|
|
// that we can convert.
|
|
|
|
static LogicalResult verifyMethodConformsToSubset(FuncOp func) {
|
|
|
|
auto walkResult = func.walk([](Block *block) {
|
|
|
|
for (Value arg : block->getArguments()) {
|
|
|
|
if (failed(verifyNnModuleValueUses(arg)))
|
|
|
|
return WalkResult::interrupt();
|
|
|
|
}
|
|
|
|
for (Operation &op : *block) {
|
|
|
|
for (Value result : op.getResults()) {
|
|
|
|
if (failed(verifyNnModuleValueUses(result)))
|
|
|
|
return WalkResult::interrupt();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return WalkResult::advance();
|
|
|
|
});
|
|
|
|
return failure(walkResult.wasInterrupted());
|
|
|
|
}
|
|
|
|
|
2021-02-18 03:28:51 +08:00
|
|
|
// Rewrites every function that backs a method with a recorded linkage name.
//
// For each such function, in order: the body is verified against the
// supported subset, the function is made public unless the method is private,
// it is renamed to its linkage name, its torch.prim.SetAttr / GetAttr /
// CallMethod ops are replaced by global slot accesses and direct calls, and
// its module-typed arguments are erased.
//
// Returns failure if any method body does not conform to the subset.
LogicalResult ObjectGraphGlobalizer::rewriteMethods() {
  // (class, method name) -> linkage name. The StringRef values point into
  // `methodLinkageNames`, which is not modified in this function.
  DenseMap<AttrOfClass, StringRef> linkageNames;
  for (auto classType : module.getOps<ClassTypeOp>()) {
    for (auto method : classType.getOps<MethodOp>()) {
      auto found = methodLinkageNames.find(method.function());
      if (found != methodLinkageNames.end())
        linkageNames[{classType, method.name()}] = found->second;
    }
  }

  // Resolves the class type of a module-typed receiver value.
  auto lookupReceiverClass = [&](Value receiver) -> ClassTypeOp {
    return symbolTable.lookup<ClassTypeOp>(
        receiver.getType().cast<NnModuleType>().getClassName());
  };

  // We only handle a small subset of ops that conform with the set of
  // assumptions that allow us to globalize the object graph. Anything that
  // tries to treat modules as bona-fide objects and not just namespaces
  // of methods with a single instance of the corresponding type just gets
  // arbitrarily tricky to rewrite. E.g. what if the user creates a list
  // of modules, or there is an scf.if selecting between modules, etc.
  //
  // Replaced ops are collected here and erased once the walk has finished.
  SmallVector<Operation *> toErase;
  auto rewriteModuleOp = [&](Operation *op) {
    // torch.prim.SetAttr -> set of the corresponding global slot.
    if (auto setAttr = dyn_cast<PrimSetAttrOp>(op)) {
      ClassTypeOp receiverClass = lookupReceiverClass(setAttr.receiver());
      auto globalSlot = globalSlotForAttr[{receiverClass, setAttr.name()}];
      OpBuilder setBuilder(setAttr);
      setBuilder.create<GlobalSlotSetOp>(setAttr.getLoc(),
                                         globalSlot.sym_name(),
                                         setAttr.value());
      toErase.push_back(setAttr);
      return;
    }

    // torch.prim.GetAttr -> get of the corresponding global slot.
    if (auto getAttr = dyn_cast<PrimGetAttrOp>(op)) {
      // If the return value is NnModuleType, then we don't need to do
      // anything. Our verification earlier ensured that there are no uses
      // that we won't properly rewrite.
      if (!getAttr.getType().isa<NnModuleType>()) {
        ClassTypeOp receiverClass = lookupReceiverClass(getAttr.receiver());
        auto globalSlot = globalSlotForAttr[{receiverClass, getAttr.name()}];
        OpBuilder getBuilder(getAttr);
        auto globalSlotGet = getBuilder.create<GlobalSlotGetOp>(
            getAttr.getLoc(), getAttr.getType(), globalSlot.sym_name());
        getAttr.replaceAllUsesWith(globalSlotGet.getOperation());
      }
      toErase.push_back(getAttr);
      return;
    }

    // torch.prim.CallMethod -> direct call of the method's linkage name,
    // with all module-typed operands dropped.
    if (auto callMethod = dyn_cast<PrimCallMethodOp>(op)) {
      ClassTypeOp receiverClass = lookupReceiverClass(callMethod.receiver());
      StringRef calleeName = linkageNames[{receiverClass, callMethod.name()}];

      SmallVector<Value, 6> nonModuleOperands;
      for (Value operand : callMethod.operands()) {
        if (!operand.getType().isa<NnModuleType>())
          nonModuleOperands.push_back(operand);
      }

      OpBuilder callBuilder(callMethod);
      auto call = callBuilder.create<CallOp>(callMethod.getLoc(), calleeName,
                                             callMethod.getType(),
                                             nonModuleOperands);
      callMethod.replaceAllUsesWith(call);
      toErase.push_back(callMethod);
      return;
    }
  };

  for (auto classType : module.getOps<ClassTypeOp>()) {
    for (auto method : classType.getOps<MethodOp>()) {
      // Methods without a recorded linkage name are left alone.
      auto found = methodLinkageNames.find(method.function());
      if (found == methodLinkageNames.end())
        continue;

      FuncOp func = symbolTable.lookup<FuncOp>(method.function());
      if (failed(verifyMethodConformsToSubset(func)))
        return failure();
      if (!method.isPrivate())
        func.setVisibility(SymbolTable::Visibility::Public);
      func.setName(found->second);
      func.walk(rewriteModuleOp);

      // Erase the ops that were replaced during the walk.
      for (Operation *replaced : toErase) {
        replaced->dropAllDefinedValueUses();
        replaced->erase();
      }
      toErase.clear();

      // Drop the module-typed arguments of the function.
      SmallVector<unsigned> moduleArgIndices;
      for (auto indexedArg : llvm::enumerate(func.getArguments())) {
        if (!indexedArg.value().getType().isa<NnModuleType>())
          continue;
        assert(indexedArg.value().use_empty() &&
               "all uses should have been removed");
        moduleArgIndices.push_back(indexedArg.index());
      }
      func.eraseArguments(moduleArgIndices);
      // No need to handle the results, since we currently don't allow ReturnOp
      // as a user of module types.
    }
  }

  return success();
}
|
|
|
|
|
|
|
|
void ObjectGraphGlobalizer::removeObjectGraph() {
|
|
|
|
for (Operation &op : llvm::make_early_inc_range(*module.getBody())) {
|
2021-02-20 08:21:21 +08:00
|
|
|
if (isa<ClassTypeOp, NnModuleOp>(op)) {
|
|
|
|
op.dropAllDefinedValueUses();
|
2021-02-18 03:28:51 +08:00
|
|
|
op.erase();
|
2021-02-20 08:21:21 +08:00
|
|
|
}
|
2021-02-18 03:28:51 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class GlobalizeObjectGraphPass
|
|
|
|
: public GlobalizeObjectGraphBase<GlobalizeObjectGraphPass> {
|
|
|
|
void runOnOperation() override {
|
|
|
|
if (failed(ObjectGraphGlobalizer(getOperation()).globalizeObjectGraph()))
|
|
|
|
return signalPassFailure();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Factory for the globalize-object-graph pass; returns a newly allocated
// GlobalizeObjectGraphPass.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Torch::createGlobalizeObjectGraphPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<GlobalizeObjectGraphPass>();
  return pass;
}
|