torch-mlir/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp

710 lines
27 KiB
C++

//===- GlobalizeObjectGraph.cpp ----------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
static FailureOr<NnModuleOp> findRootNnModule(ModuleOp module) {
NnModuleOp rootNnModule;
for (NnModuleOp op : module.getOps<NnModuleOp>()) {
if (!op.use_empty())
continue;
if (rootNnModule) {
op.emitError()
.append("found more than one root module (module that is not a "
"child of any other module)")
.attachNote(rootNnModule.getLoc())
.append("see other root module here");
return failure();
}
rootNnModule = op;
}
if (!rootNnModule) {
module.emitError() << "module does not contain a root torch.nn_module";
return failure();
}
return rootNnModule;
}
//===----------------------------------------------------------------------===//
// Object graph recursive traversal.
//===----------------------------------------------------------------------===//
namespace {
struct LinkageInfo {
std::string linkageName;
bool isPrivate;
};
} // namespace
namespace {
/// Calculates the linkage names of all the potentially exported objects in the
/// module and also creates GlobalSlotOp's for each SlotOp and tracks their
/// associations.
///
/// The mechanics of both of these tasks involve the same object graph
/// traversal, so it's useful to roll them together.
class ObjectGraphInfo {
public:
ObjectGraphInfo(ModuleOp module)
: globalSlotBuilder(module.getBodyRegion()), symbolTable(module) {}
LogicalResult initialize(NnModuleOp rootNnModule) {
if (failed(collectUsedSlots()))
return failure();
return recursivelyTraverse(rootNnModule);
}
LinkageInfo getSlotLinkageInfo(SlotOp op) {
auto it = slotLinkageInfo.find(op);
assert(it != slotLinkageInfo.end());
return it->second;
}
std::optional<LinkageInfo> getFuncLinkageInfo(NnModuleOp instance,
func::FuncOp methodFunc) {
auto it = funcLinkageInfo.find({instance, methodFunc});
if (it == funcLinkageInfo.end())
return std::nullopt;
return it->second;
}
GlobalSlotOp getGlobalSlotFor(SlotOp slot) {
auto it = slotToGlobalSlot.find(slot);
assert(it != slotToGlobalSlot.end() && "didn't create global slot");
return it->second;
}
llvm::MapVector<StringAttr, Value> &getGlobalSlotInitialValues() {
return globalSlotInitialValues;
}
private:
LogicalResult collectUsedSlots() {
// Collect all the slots in each module.
// moduleClassNameToSlots tracks, for each class, for each attribute, the
// set of slot instances that belong to that attribute. E.g. if there are
// two instances of a class "Foo" with an attribute "a", then there will be
// two SlotOps in the inner vector of moduleClassNameToSlots["Foo"]["a"].
// This is not precise -- in the code below it effectively results in the
// conservative assumption that all instances of a class might reach all
// GetAttr ops on that type.
llvm::StringMap<llvm::StringMap<std::vector<SlotOp>>>
moduleClassNameToSlots;
symbolTable.getOp()->walk([&](NnModuleOp moduleOp) {
llvm::StringMap<SlotOp> nameToSlot;
auto &slotNameToSlots = moduleClassNameToSlots[moduleOp.getClassName()];
for (auto slotOp : moduleOp.getOps<SlotOp>())
slotNameToSlots[slotOp.getName()].push_back(slotOp);
});
// Find all the module slots that are accessed through `PrimGetAttrOp` or
// `PrimSetAttrOp`.
symbolTable.getOp()->walk([&](Operation *op) {
if (!isa<PrimGetAttrOp, PrimSetAttrOp>(op))
return;
Value module;
StringRef slotName;
if (auto getAttrOp = llvm::dyn_cast<PrimGetAttrOp>(op)) {
module = getAttrOp.getReceiver();
slotName = getAttrOp.getName();
} else {
auto setAttrOp = cast<PrimSetAttrOp>(op);
module = setAttrOp.getReceiver();
slotName = setAttrOp.getName();
}
auto moduleType = module.getType().cast<NnModuleType>();
auto slots = moduleClassNameToSlots.find(moduleType.getClassName());
// TODO: Improve verifier so that this can never happen
if (slots == moduleClassNameToSlots.end())
op->emitError() << "Reference to non-existing module type "
<< moduleType.getClassName();
auto &slotNameToSlots = slots->getValue();
auto slotIt = slotNameToSlots.find(slotName);
// TODO: Improve verifier so that this can never happen
if (slotIt == slotNameToSlots.end())
op->emitError() << "Reference to non-existing module slot " << slotName
<< "in " << moduleType.getClassName();
for (SlotOp slotOp : slotIt->getValue())
usedSlots.insert(slotOp);
});
return success();
}
LogicalResult recursivelyTraverse(NnModuleOp nnModule) {
std::string pathToClassFromRoot = llvm::join(nameStack, ".");
if (!seenNnModules.insert({nnModule, pathToClassFromRoot}).second) {
return nnModule.emitError()
<< "reachable by multiple paths from root object: '<root>."
<< seenNnModules[nnModule] << "' and '<root>."
<< pathToClassFromRoot << "'";
}
auto classType = symbolTable.lookup<ClassTypeOp>(
nnModule.getType().cast<NnModuleType>().getClassName());
for (auto t :
llvm::zip(nnModule.getOps<SlotOp>(), classType.getOps<AttrOp>())) {
auto slot = std::get<0>(t);
auto attr = std::get<1>(t);
nameStack.push_back(attr.getName().str());
if (attr.getType().isa<NnModuleType>()) {
if (failed(recursivelyTraverse(
slot.getValue().getDefiningOp<NnModuleOp>())))
return failure();
} else if (usedSlots.find(slot) != usedSlots.end()) {
// Only create the GlobalSlotOp if the slot is used at all.
std::string linkageName = llvm::join(nameStack, ".");
auto globalSlot = globalSlotBuilder.create<GlobalSlotOp>(
slot.getLoc(), linkageName,
/*sym_visibility=*/nullptr, attr.getType());
if (attr.getIsPrivate())
globalSlot.setVisibility(SymbolTable::Visibility::Private);
assert(slotToGlobalSlot.find(slot) == slotToGlobalSlot.end());
slotToGlobalSlot[slot] = globalSlot;
slotLinkageInfo[slot] = LinkageInfo{linkageName, attr.getIsPrivate()};
globalSlotInitialValues[globalSlot.getSymNameAttr()] = slot.getValue();
}
nameStack.pop_back();
}
for (auto method : classType.getOps<MethodOp>()) {
nameStack.push_back(method.getName().str());
funcLinkageInfo[{
nnModule, symbolTable.lookup<func::FuncOp>(method.getFunction())}] =
LinkageInfo{llvm::join(nameStack, "."), method.getIsPrivate()};
nameStack.pop_back();
}
return success();
}
// Builder for creating GlobalSlotOp's in the module.
OpBuilder globalSlotBuilder;
// Symbol table for the module.
SymbolTable symbolTable;
// The set of NnModuleOp's that have already been processed.
// Used for diagnostics.
// The map value is the original path from the root that we found it at.
DenseMap<NnModuleOp, std::string> seenNnModules;
// The stack of attribute names we have traversed during our recursive
// traversal of the class/object hierarchy.
//
// Linkage names are calculated based on the set of attribute names traversed
// from the root class/module in the program.
std::vector<std::string> nameStack;
// Linkage info for each SlotOp in the program.
DenseMap<SlotOp, LinkageInfo> slotLinkageInfo;
// Linkage info for each method in the program. Since we are going to be
// monomorphizing all the functions, we also need to key this off of the
// instance (NnModuleOp) that the func is monomorphized for.
DenseMap<std::pair<NnModuleOp, func::FuncOp>, LinkageInfo> funcLinkageInfo;
// The corresponding GlobalSlotOp for each SlotOp in the program.
DenseMap<SlotOp, GlobalSlotOp> slotToGlobalSlot;
// The initializing value for each GlobalSlotOp.
// This is a MapVector to keep the order deterministic.
llvm::MapVector<StringAttr, Value> globalSlotInitialValues;
// Used to keep track of all the used torch slots so that the restrictions can
// be applied to those slots only.
DenseSet<SlotOp> usedSlots;
};
} // namespace
LogicalResult
createGlobalSlotModuleInitializer(ModuleOp module, SymbolTable &symbolTable,
ObjectGraphInfo &objectGraphInfo) {
auto builder = OpBuilder::atBlockBegin(module.getBody());
auto moduleInitializer =
builder.create<GlobalSlotModuleInitializerOp>(module.getLoc());
Block *body = builder.createBlock(&moduleInitializer.getInitializer());
builder.setInsertionPointToEnd(body);
SmallVector<Operation *> opsToMove;
for (Operation &op : *module.getBody()) {
if (isa<ClassTypeOp, NnModuleOp, GlobalSlotOp, func::FuncOp,
GlobalSlotModuleInitializerOp>(op))
continue;
opsToMove.push_back(&op);
}
IRMapping mapping;
for (Operation *op : opsToMove) {
// The ops are used by `torch.slot` ops in the enclosing module.
// Cloning avoids needing to handle those uses specially.
builder.clone(*op, mapping);
}
SmallVector<Attribute> slotSymNames;
SmallVector<Value> initialValues;
for (auto &kv : objectGraphInfo.getGlobalSlotInitialValues()) {
StringAttr symName = kv.first;
Value initializer = kv.second;
slotSymNames.push_back(FlatSymbolRefAttr::get(symName));
initialValues.push_back(mapping.lookup(initializer));
}
builder.create<InitializeGlobalSlotsOp>(
moduleInitializer.getLoc(),
ArrayAttr::get(module.getContext(), slotSymNames), initialValues);
return success();
}
//===----------------------------------------------------------------------===//
// Monomorphization.
//===----------------------------------------------------------------------===//
namespace {
// When used in an Monomorphization, indicates that the arg at `argIndex` will
// correspond to instance `instance.
struct ArgInstance {
int argIndex;
Value instance; // Result of an NnModuleOp.
};
static llvm::hash_code hash_value(const ArgInstance &argInstance) {
return llvm::hash_combine(argInstance.argIndex, argInstance.instance);
}
static bool operator==(const ArgInstance &lhs, const ArgInstance &rhs) {
return std::make_tuple(lhs.argIndex, lhs.instance) ==
std::make_tuple(rhs.argIndex, rhs.instance);
}
} // namespace
namespace {
// Record indicating that a particular function must be monomorphized for the
// given ArgInstance's, which involves deleting those arguments and specializing
// all their uses to operate on GlobalSlotOp's that we have created for the
// SlotOp's of the NnModuleOp instances.
//
// NOTE: Unlike the more traditional use of monomorphization to mean a single
// *type* is being specialized for, here we are specializing for a specific
// *instance*. This still fits the definition of monomorphization though, albeit
// with each instance being considered to have a maximally refined type which is
// a set with a single element (just this instance). This does not correspond to
// any notion of "type" that we have in the IR, but still fits the formal
// definition.
struct Monomorphization {
func::FuncOp func;
std::vector<ArgInstance> argInstances;
};
} // namespace
template <> struct llvm::DenseMapInfo<Monomorphization> {
static Monomorphization getEmptyKey() {
return Monomorphization{nullptr, {ArgInstance{-1, nullptr}}};
}
static Monomorphization getTombstoneKey() {
return Monomorphization{nullptr, {ArgInstance{-2, nullptr}}};
}
static unsigned getHashValue(Monomorphization val) {
return llvm::hash_combine(val.func.getAsOpaquePointer(),
llvm::hash_combine_range(val.argInstances.begin(),
val.argInstances.end()));
}
static bool isEqual(Monomorphization lhs, Monomorphization rhs) {
return lhs.func == rhs.func &&
std::equal(lhs.argInstances.begin(), lhs.argInstances.end(),
rhs.argInstances.begin(), rhs.argInstances.end());
}
};
// Populate `mapping` such that values of NnModuleType in the function are
// mapped to appropriate global objects of NnModuleType.
//
// This generalizes to a full abstract interpretation of the function, but
// currently only analyzes a subset of ops.
static LogicalResult analyzeInstances(func::FuncOp func,
ArrayRef<ArgInstance> argInstances,
IRMapping &mapping) {
for (auto &argInstance : argInstances)
mapping.map(func.getArgument(argInstance.argIndex), argInstance.instance);
auto walkResult = func.walk([&](PrimGetAttrOp op) {
if (!op.getType().isa<NnModuleType>())
return WalkResult::advance();
auto instance = mapping.lookupOrNull(op.getReceiver());
assert(instance && "verifyFuncConformsToSubset should ensure this");
for (auto slot : instance.getDefiningOp<NnModuleOp>().getOps<SlotOp>()) {
if (slot.getName() == op.getName()) {
mapping.map(op, slot.getValue());
break;
}
}
return WalkResult::advance();
});
return success(!walkResult.wasInterrupted());
}
static FailureOr<Monomorphization>
createMonomorphizationForCall(func::CallOp op, IRMapping &mapping,
SymbolTable &symbolTable) {
auto func = symbolTable.lookup<func::FuncOp>(op.getCallee());
Monomorphization monomorphization;
monomorphization.func = func;
for (auto operand : llvm::enumerate(op->getOperands())) {
if (!operand.value().getType().isa<NnModuleType>())
continue;
Value instance = mapping.lookupOrNull(operand.value());
assert(instance && "verifyFuncConformsToSubset should ensure this");
monomorphization.argInstances.push_back(
ArgInstance{static_cast<int>(operand.index()), instance});
}
return monomorphization;
}
namespace {
class MonomorphizationTracker {
public:
MonomorphizationTracker(ModuleOp module)
: module(module), symbolTable(module) {}
LogicalResult
initialize(DenseMap<ClassTypeOp, std::vector<NnModuleOp>> &instances) {
for (auto func : module.getOps<func::FuncOp>()) {
Monomorphization monomorphization;
monomorphization.func = func;
bool canTriviallyMonomorphize = true;
for (auto arg : llvm::enumerate(func.getArguments())) {
auto type = arg.value().getType().dyn_cast<NnModuleType>();
if (!type)
continue;
auto classType = symbolTable.lookup<ClassTypeOp>(type.getClassName());
auto &classTypeInstances = instances[classType];
if (classTypeInstances.size() != 1) {
canTriviallyMonomorphize = false;
break;
}
monomorphization.argInstances.push_back(
{static_cast<int>(arg.index()), classTypeInstances[0]});
}
if (canTriviallyMonomorphize) {
dirtyMonomorphizations.push_back(monomorphization);
monomorphizations.insert(monomorphization);
}
}
while (!dirtyMonomorphizations.empty()) {
Monomorphization dirty = dirtyMonomorphizations.pop_back_val();
if (failed(generateNewMonomorphizations(dirty)))
return failure();
}
return success();
}
llvm::SetVector<Monomorphization> &getMonomorphizations() {
return monomorphizations;
}
private:
LogicalResult generateNewMonomorphizations(const Monomorphization &m) {
auto func = m.func;
IRMapping mapping;
if (failed(analyzeInstances(func, m.argInstances, mapping)))
return failure();
auto walkResult = func.walk([&](func::CallOp op) {
FailureOr<Monomorphization> maybeMonomorphization =
createMonomorphizationForCall(op, mapping, symbolTable);
if (failed(maybeMonomorphization))
return WalkResult::interrupt();
if (monomorphizations.insert(*maybeMonomorphization))
dirtyMonomorphizations.push_back(*maybeMonomorphization);
return WalkResult::advance();
});
return success(!walkResult.wasInterrupted());
}
ModuleOp module;
SymbolTable symbolTable;
SmallVector<Monomorphization> dirtyMonomorphizations;
llvm::SetVector<Monomorphization> monomorphizations;
};
} // namespace
// Verify that a value conforms to the subset of allowed uses for
// !torch.nn.Module<"..."> types.
static LogicalResult verifyNnModuleValueUses(Value value) {
// Trivially succeed for non-module types.
if (!value.getType().isa<NnModuleType>())
return success();
for (Operation *op : value.getUsers()) {
if (isa<func::CallOp, PrimGetAttrOp>(op))
continue;
// Only allow `value` as the receiver.
if (isa<PrimSetAttrOp>(op) && cast<PrimSetAttrOp>(op).getValue() != value)
continue;
// TODO: Improve this based on real user use cases.
// This is a diagnostic that users will hit if they do not conform to
// the supported subset of TorchScript.
return op->emitError() << "unsupported use of a torch.nn.Module. Expected "
"only method calls or attribute get/set";
}
return success();
}
// Verify that `func` conforms to the subset of allowable method bodies
// that we can convert.
static LogicalResult verifyFuncConformsToSubset(func::FuncOp func) {
auto walkResult = func.walk([&](Block *block) {
for (Value arg : block->getArguments())
if (failed(verifyNnModuleValueUses(arg)))
return WalkResult::interrupt();
for (Operation &op : *block)
for (Value result : op.getResults())
if (failed(verifyNnModuleValueUses(result)))
return WalkResult::interrupt();
return WalkResult::advance();
});
return success(!walkResult.wasInterrupted());
}
static LogicalResult
verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
MonomorphizationTracker &tracker) {
DenseMap<func::FuncOp, int> numMonomorphizations;
for (auto &monomorphization : tracker.getMonomorphizations()) {
numMonomorphizations[monomorphization.func] += 1;
}
bool sawError = false;
for (auto classType : module.getOps<ClassTypeOp>()) {
for (auto method : classType.getOps<MethodOp>()) {
if (!method.getIsPrivate()) {
if (numMonomorphizations[symbolTable.lookup<func::FuncOp>(
method.getFunction())] > 1) {
method.emitError()
<< "public function with multiple monomorphizations";
sawError = true;
}
}
}
}
return success(!sawError);
}
// Rewrite `func`, given that all values of `NnModuleType` have been mapped in
// `mapping` to corresponding global instances.
static LogicalResult rewriteMonomorphizedFuncClone(
func::FuncOp func, IRMapping mapping, SymbolTable &symbolTable,
DenseMap<Monomorphization, func::FuncOp> &newFuncs,
ObjectGraphInfo &objectGraphInfo) {
SmallVector<Operation *> toErase;
auto handlePrimSetAttr = [&](PrimSetAttrOp op) {
auto instance =
mapping.lookup(op.getReceiver()).getDefiningOp<NnModuleOp>();
SlotOp affectedSlot;
for (auto slot : instance.getOps<SlotOp>()) {
if (slot.getName() == op.getName())
affectedSlot = slot;
}
OpBuilder(op).create<GlobalSlotSetOp>(
op.getLoc(),
objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName(),
op.getValue());
toErase.push_back(op);
return WalkResult::advance();
};
auto handlePrimGetAttr = [&](PrimGetAttrOp op) {
if (!op.getType().isa<NnModuleType>()) {
auto instance =
mapping.lookup(op.getReceiver()).getDefiningOp<NnModuleOp>();
SlotOp affectedSlot;
for (auto slot : instance.getOps<SlotOp>()) {
if (slot.getName() == op.getName())
affectedSlot = slot;
}
auto newOp = OpBuilder(op).create<GlobalSlotGetOp>(
op.getLoc(), op.getType(),
objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName());
op.replaceAllUsesWith(&*newOp);
}
toErase.push_back(op);
return WalkResult::advance();
};
auto handleCall = [&](func::CallOp op) {
FailureOr<Monomorphization> maybeMonomorphization =
createMonomorphizationForCall(op, mapping, symbolTable);
if (failed(maybeMonomorphization))
return WalkResult::interrupt();
Monomorphization monomorphization = std::move(*maybeMonomorphization);
auto newArguments = llvm::to_vector<6>(
llvm::make_filter_range(op->getOperands(), [](Value v) {
return !v.getType().isa<NnModuleType>();
}));
assert(newFuncs.find(monomorphization) != newFuncs.end());
auto newOp = OpBuilder(op).create<func::CallOp>(
op.getLoc(), newFuncs[monomorphization], newArguments);
op.replaceAllUsesWith(newOp);
toErase.push_back(op);
return WalkResult::advance();
};
auto walkResult = func.walk([&](Operation *op) {
if (auto primSetAttr = dyn_cast<PrimSetAttrOp>(op))
return handlePrimSetAttr(primSetAttr);
if (auto primGetAttr = dyn_cast<PrimGetAttrOp>(op))
return handlePrimGetAttr(primGetAttr);
if (auto call = dyn_cast<func::CallOp>(op))
return handleCall(call);
return WalkResult::advance();
});
for (auto op : toErase) {
op->dropAllDefinedValueUses();
op->erase();
}
llvm::BitVector argsToErase(func.getNumArguments());
for (auto type : llvm::enumerate(func.getArgumentTypes())) {
if (type.value().isa<NnModuleType>()) {
argsToErase.set(type.index());
}
}
func.eraseArguments(argsToErase);
return success(!walkResult.wasInterrupted());
}
static LogicalResult globalizeObjectGraph(ModuleOp module) {
// Step 1: Traverse object graph and collect information.
FailureOr<NnModuleOp> maybeRootNnModule = findRootNnModule(module);
if (failed(maybeRootNnModule))
return failure();
NnModuleOp rootNnModule = *maybeRootNnModule;
ObjectGraphInfo objectGraphInfo(module);
if (failed(objectGraphInfo.initialize(rootNnModule)))
return failure();
DenseMap<ClassTypeOp, std::vector<NnModuleOp>> instances;
SymbolTable symbolTable(module);
for (auto nnModule : module.getOps<NnModuleOp>()) {
auto classType = nnModule.getClassType(symbolTable);
instances[classType].push_back(nnModule);
}
// Step 2: Create the torch.global_slot.module_initializer op.
if (failed(createGlobalSlotModuleInitializer(module, symbolTable,
objectGraphInfo)))
return failure();
// Step 3: Verify all functions are suitable to be analyzed by our later code.
// This eliminates special handling / error code later.
//
// This is important, because in principle, we can perform arbitrarily complex
// static analysis to discover how to monomorphize th eprogram, including
// tracking instances through control flow, through get/set attr, etc. We
// implement a very simple subset of cases.
for (auto func : module.getOps<func::FuncOp>()) {
if (failed(verifyFuncConformsToSubset(func)))
return failure();
}
// Step 4: Calculate the set of monomorphized functions that need to be
// created. For each call that passes !torch.nn.Module to a function, we need
// to create a specialized version of that function just for that instance (or
// combination of instances in the case of multiple arguments).
//
// At this stage, we only analyze which monomorphizations are needed and
// whether it is possible to monomorphize the program. The actual
// cloning/rewriting mechanics happen later.
//
// This lets us know which GlobalSlotOp we need to reference when we replace
// PrimSetAttrOp/PrimGetAttrOp.
//
// Note that in general there can be mutually recursive functions that
// re-enter themselves with a different set of instances -- the process of
// calculating these monomorphizations is a fixed-point iteration that
// discovers all needed monomorphizations. In practice this yields a
// controllable number.
MonomorphizationTracker tracker(module);
if (failed(tracker.initialize(instances)))
return failure();
if (failed(verifyPublicMonomorphizations(module, symbolTable, tracker))) {
return failure();
}
// Step 5: Clone/rewrite functions to implement the necessary
// monomorphizations.
DenseMap<Monomorphization, func::FuncOp> newFuncs;
int uniquifier = 0;
for (auto &monomorphization : tracker.getMonomorphizations()) {
auto newFunc = cast<func::FuncOp>(monomorphization.func->clone());
newFuncs[monomorphization] = newFunc;
std::optional<LinkageInfo> linkageInfo = std::nullopt;
// If it is potentially a method, check its linkage info.
if (monomorphization.argInstances.size() != 0 &&
monomorphization.argInstances[0].argIndex == 0) {
linkageInfo = objectGraphInfo.getFuncLinkageInfo(
monomorphization.argInstances[0].instance.getDefiningOp<NnModuleOp>(),
monomorphization.func);
}
if (linkageInfo.has_value()) {
// It's a method.
newFunc.setVisibility(linkageInfo->isPrivate
? SymbolTable::Visibility::Private
: SymbolTable::Visibility::Public);
newFunc.setName(linkageInfo->linkageName);
} else {
// It's a free function.
// TODO: Make the name nicer (no suffix in typical case).
newFunc.setName(
(Twine(newFunc.getName()) + "$" + Twine(uniquifier++)).str());
}
module.push_back(newFunc);
}
for (auto &kv : newFuncs) {
IRMapping mapping;
if (failed(analyzeInstances(kv.second, kv.first.argInstances, mapping)))
return failure();
if (failed(rewriteMonomorphizedFuncClone(kv.second, mapping, symbolTable,
newFuncs, objectGraphInfo)))
return failure();
}
// Step 6: Clean up object graph.
DenseSet<func::FuncOp> liveFuncs;
for (auto &kv : newFuncs) {
liveFuncs.insert(kv.second);
}
for (auto &op : llvm::make_early_inc_range(module.getOps())) {
if (isa<GlobalSlotOp, GlobalSlotModuleInitializerOp>(&op))
continue;
if (auto func = dyn_cast<func::FuncOp>(op)) {
if (liveFuncs.contains(func))
continue;
}
op.dropAllDefinedValueUses();
op.dropAllReferences();
op.erase();
}
return success();
}
namespace {
class GlobalizeObjectGraphPass
: public GlobalizeObjectGraphBase<GlobalizeObjectGraphPass> {
void runOnOperation() override {
if (failed(globalizeObjectGraph(getOperation())))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createGlobalizeObjectGraphPass() {
return std::make_unique<GlobalizeObjectGraphPass>();
}