mirror of https://github.com/llvm/torch-mlir
Support multiple instances of a class in GlobalizeObjectGraph.
This happens in practice with e.g. ResNet from torchvision (multiple instances of the same BatchNorm class). The key observation is that for this program, and the expected set of programs, we can convert the program to the same globalized form with a bit more static analysis and effort to suitably monomorphize the program. Though what we are doing here is fairly annoying to implement, it saves any nontrivial later pass from having to do similar analyses (or worse). E.g. shape inference would need to be object-graph aware, mutation/lifetime analyses would have to be aware, etc. Additionally, it would make us front-load what it means to have a !torch.nn.Module type on an ABI boundary, which we are just not ready to handle. I'm really, really hoping that in practice we can get away with this, otherwise it's going to be really rough designing a representation (and implementing everything to back it) that is convenient to transform and gracefully scales from full object graph (in the most dynamic case) down to a fixed set of global slots like we have here (in the most static case, which we presume a lot of practical programs fall into). This also involved introducing a `torch-prepare-for-globalize-object-graph` pass that does a minimal set of lowerings to simplify the IR into a more orthogonal and analyzable form, and a `torch-globalize-pipeline` helper. Recommended review order: - updated documentation in Passes.td - new tests in `globalize-object-graph-multiple-instances*.mlir` - implementation of GlobalizeObjectGraph.cpp - PrepareForGlobalizeObjectGraph.cpp + prepare-for-globalize-object-graph.mlir - misc stuff like torch-globalize-pipeline pipeline definition. With this, we can import, globalize, and inline resnet18 from torchvision: https://gist.github.com/silvasean/821586afc19b67d9fb72030b2e0adeb8pull/187/head
parent
2750d2084c
commit
58c7030104
|
@ -21,4 +21,52 @@
|
|||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h.inc"
|
||||
|
||||
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::SlotOp> {
|
||||
using SlotOp = ::mlir::NPCOMP::Torch::SlotOp;
|
||||
static SlotOp getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return SlotOp::getFromOpaquePointer(pointer);
|
||||
}
|
||||
static SlotOp getTombstoneKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return SlotOp::getFromOpaquePointer(pointer);
|
||||
}
|
||||
static unsigned getHashValue(SlotOp val) {
|
||||
return hash_value(val.getAsOpaquePointer());
|
||||
}
|
||||
static bool isEqual(SlotOp lhs, SlotOp rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::NnModuleOp> {
|
||||
using NnModuleOp = ::mlir::NPCOMP::Torch::NnModuleOp;
|
||||
static NnModuleOp getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return NnModuleOp::getFromOpaquePointer(pointer);
|
||||
}
|
||||
static NnModuleOp getTombstoneKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return NnModuleOp::getFromOpaquePointer(pointer);
|
||||
}
|
||||
static unsigned getHashValue(NnModuleOp val) {
|
||||
return hash_value(val.getAsOpaquePointer());
|
||||
}
|
||||
static bool isEqual(NnModuleOp lhs, NnModuleOp rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
template <> struct llvm::DenseMapInfo<::mlir::NPCOMP::Torch::ClassTypeOp> {
|
||||
using ClassTypeOp = ::mlir::NPCOMP::Torch::ClassTypeOp;
|
||||
static ClassTypeOp getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return ClassTypeOp::getFromOpaquePointer(pointer);
|
||||
}
|
||||
static ClassTypeOp getTombstoneKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return ClassTypeOp::getFromOpaquePointer(pointer);
|
||||
}
|
||||
static unsigned getHashValue(ClassTypeOp val) {
|
||||
return hash_value(val.getAsOpaquePointer());
|
||||
}
|
||||
static bool isEqual(ClassTypeOp lhs, ClassTypeOp rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
#endif // NPCOMP_DIALECT_TORCH_IR_TORCHOPS_H
|
||||
|
|
|
@ -92,6 +92,9 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
StringRef getClassName() { return getType().getClassName(); }
|
||||
ClassTypeOp getClassType(::mlir::SymbolTable &symbolTable) {
|
||||
return symbolTable.lookup<ClassTypeOp>(getClassName());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,13 @@ namespace Torch {
|
|||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createPrepareForGlobalizeObjectGraphPass();
|
||||
|
||||
/// Creates a pipeline that "globalizes" the given program.
|
||||
/// See the documentation on torch-globalize-object-graph for more details.
|
||||
void createGlobalizePipeline(OpPassManager &pm);
|
||||
|
||||
} // namespace Torch
|
||||
|
||||
/// Registers all Torch transformation passes.
|
||||
|
|
|
@ -38,7 +38,10 @@ def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
|
|||
|
||||
This pass performs a complete change of the externally visible calling
|
||||
convention of the MLIR module for a graph of objects and methods to a
|
||||
fixed set of globals and functions.
|
||||
fixed set of globals and functions. Additionally, method signatures are
|
||||
changed such that all types of !torch.nn.Module are deleted from public
|
||||
interfaces since they are guaranteed to correspond to a unique instance and
|
||||
are thus redundant.
|
||||
|
||||
Of course, only a subset of programs can be transformed, and this pass fails
|
||||
with an error if the conditions are violated.
|
||||
|
@ -49,11 +52,25 @@ def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
|
|||
- Rationale: Allows us to have a notion of a unique "root" op, which is
|
||||
used to define linkage. This also matches how TorchScript imports in
|
||||
practice (`torch.jit.script` imports a single root object).
|
||||
- There must be exactly one instance of each torch.class_type. Equivalently,
|
||||
Every torch.nn_module must have a distinct type.
|
||||
- Rationale: This guarantee precludes things like selecting between
|
||||
multiple modules dynamically at runtime, which would require indirecting
|
||||
between the separate storage of each instance.
|
||||
- Multiple instances of the same class type are allowed, as long as it is
|
||||
possible to monomorphize ("template instantiate") functions so that each
|
||||
argument of !torch.nn.Module type corresponds to a unique instance.
|
||||
In pratice, this limitation is either 1) (fundamental) due to truly
|
||||
dynamic use of modules, such as `m1 if cond() else m2` in Python code,
|
||||
or 2) (incidental) imprecision of the static analysis used in this pass
|
||||
which is used to calculate when a single intance is relevant. In general,
|
||||
this analysis is equivalent to the halting problem, but we can aim to
|
||||
improve this pass such that practical patterns are all handled.
|
||||
- Rationale: The fundamental limitation "1)" guarantees that the
|
||||
program can be lowered to a fixed set of globals without indirection
|
||||
across globals. In the absence of this property, most compiler
|
||||
analyses/transformations are significantly curtailed (or require very
|
||||
sophisticated implementations). For the moment, this restriction
|
||||
is deemed to be sufficiently reasonable to be a pragmatic choice to
|
||||
avoid front-loading the complexity of working with a representation that
|
||||
really does a good job of representing that kind of program.
|
||||
Additionally, it avoids front-loading the handling of programs which
|
||||
have !torch.nn.Module types at external calling convention boundaries.
|
||||
- All torch.nn_module's must be reachable by a unique path from the root
|
||||
- Rationale: Eliminates possibility of potentially exponential number of
|
||||
paths. Or worse, infinite number of paths when considering cyclic
|
||||
|
@ -70,4 +87,18 @@ def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
|
|||
}];
|
||||
}
|
||||
|
||||
def PrepareForGlobalizeObjectGraph
|
||||
: Pass<"torch-prepare-for-globalize-object-graph", "ModuleOp"> {
|
||||
let summary = "Lowering in preparation for globalizing";
|
||||
let constructor = "mlir::NPCOMP::Torch::createPrepareForGlobalizeObjectGraphPass()";
|
||||
let description = [{
|
||||
Establishes and the invariants needed by the
|
||||
torch-globalize-object-graph transformation. Fails if that cannot be
|
||||
accomplished.
|
||||
|
||||
Currently, this just involves ensuring a small set of patterns have been
|
||||
applied.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // NPCOMP_TORCH_PASSES
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
add_npcomp_conversion_library(NPCOMPTorchPasses
|
||||
Passes.cpp
|
||||
GlobalizeObjectGraph.cpp
|
||||
PrepareForGlobalizeObjectGraph.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms
|
||||
|
|
|
@ -16,118 +16,14 @@
|
|||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
namespace {
|
||||
// See the pass documentation for `torch-globalize-object-graph`.
|
||||
class ObjectGraphGlobalizer {
|
||||
public:
|
||||
ObjectGraphGlobalizer(ModuleOp module);
|
||||
LogicalResult globalizeObjectGraph();
|
||||
|
||||
private:
|
||||
FailureOr<NnModuleOp> findRootNnModule();
|
||||
LogicalResult checkSingleInstanceOfEachClass();
|
||||
LogicalResult recursivelyTraverseClassType(ClassTypeOp classType);
|
||||
LogicalResult populateGlobalSlotInitializer(GlobalSlotOp op,
|
||||
Value initialValue);
|
||||
LogicalResult rewriteMethods();
|
||||
void removeObjectGraph();
|
||||
|
||||
ModuleOp module;
|
||||
SymbolTable symbolTable;
|
||||
OpBuilder globalBuilder;
|
||||
// The stack of attribute names we have traversed during our recursive
|
||||
// traversal of the class/object hierarchy.
|
||||
//
|
||||
// Linkage names are calculated based on the set of attribute names traversed
|
||||
// from the root class/module in the program.
|
||||
SmallVector<std::string> nameStack;
|
||||
|
||||
// Sometimes it is natural to want a map keyed on torch.attr ops or torch.slot
|
||||
// ops. However, usually it is better to keep a map keyed on an ClassTypeOp
|
||||
// + attr name since frequently that is all one has access to and it
|
||||
// would be tedious to scan the body of the ClassTypeOp for the torch.attr
|
||||
// with the corresponding name.
|
||||
using AttrOfClass =
|
||||
std::pair</*ClassTypeOp*/ Operation *, /*attr name*/ StringRef>;
|
||||
// The initial value associated with an attribute of a class.
|
||||
// Since we only allow a single instance of a class, this is equivalent to
|
||||
// the initial value of the unique slot corresponding to that attr.
|
||||
DenseMap<AttrOfClass, Value> slotInitialValues;
|
||||
// The inverse map of `slotInitialValues`.
|
||||
// Many attributes can have the same initial value, so the value type
|
||||
// is a vector.
|
||||
DenseMap<Value, std::vector<AttrOfClass>> slotInitialValuesInverseMap;
|
||||
|
||||
// The torch.global_slot corresponding to each torch.attr/torch.slot.
|
||||
DenseMap<AttrOfClass, GlobalSlotOp> globalSlotForAttr;
|
||||
// The linkage name (value) for the function with symbol name (key).
|
||||
DenseMap<StringRef, std::string> methodLinkageNames;
|
||||
|
||||
// The set of class types that have already been processed.
|
||||
// Used for diagnostics.
|
||||
// The map value is the original path from the root that we found it at.
|
||||
DenseMap</*ClassTypeOp*/ Operation *, std::string> seenClassTypes;
|
||||
|
||||
// A set of values that we have copied into torch.global_slot initializers,
|
||||
// which cannot be used in multiple initializers because their object
|
||||
// identity is important.
|
||||
DenseSet<Value> objectsWithIdentityAlreadyCopiedIntoInitializers;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
ObjectGraphGlobalizer::ObjectGraphGlobalizer(ModuleOp module)
|
||||
: module(module), symbolTable(module),
|
||||
globalBuilder(module.getBodyRegion()) {}
|
||||
|
||||
LogicalResult ObjectGraphGlobalizer::globalizeObjectGraph() {
|
||||
// We require there to be a unique root !torch.nn.Module.
|
||||
FailureOr<NnModuleOp> maybeRootNnModule = findRootNnModule();
|
||||
if (failed(maybeRootNnModule))
|
||||
return failure();
|
||||
NnModuleOp rootNnModule = *maybeRootNnModule;
|
||||
if (!rootNnModule)
|
||||
return module.emitError()
|
||||
<< "module does not contain a root torch.nn_module";
|
||||
|
||||
// We require one instance of each class. That is, there is a single
|
||||
// torch.nn_module for each torch.class_type.
|
||||
if (failed(checkSingleInstanceOfEachClass()))
|
||||
return failure();
|
||||
|
||||
for (NnModuleOp nnModule : module.getOps<NnModuleOp>()) {
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(nnModule.getClassName());
|
||||
for (auto slot : nnModule.getOps<SlotOp>()) {
|
||||
AttrOfClass attrOfClass = {classType, slot.name()};
|
||||
slotInitialValues[attrOfClass] = slot.value();
|
||||
slotInitialValuesInverseMap[slot.value()].push_back(attrOfClass);
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively traverse the class hierarchy, globalizing slots and
|
||||
// tracking linkage names for methods.
|
||||
auto rootClassType =
|
||||
symbolTable.lookup<ClassTypeOp>(rootNnModule.getClassName());
|
||||
if (failed(recursivelyTraverseClassType(rootClassType)))
|
||||
return failure();
|
||||
|
||||
// Rewrite torch.prim.GetAttr/torch.prim.SetAttr/torch.prim.CallMethod.
|
||||
if (failed(rewriteMethods()))
|
||||
return failure();
|
||||
|
||||
// Now that all we have finished converting to the new form, remove
|
||||
// the original object graph.
|
||||
removeObjectGraph();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
FailureOr<NnModuleOp> ObjectGraphGlobalizer::findRootNnModule() {
|
||||
static FailureOr<NnModuleOp> findRootNnModule(ModuleOp module) {
|
||||
NnModuleOp rootNnModule;
|
||||
for (NnModuleOp op : module.getOps<NnModuleOp>()) {
|
||||
if (!op.use_empty())
|
||||
|
@ -142,132 +38,354 @@ FailureOr<NnModuleOp> ObjectGraphGlobalizer::findRootNnModule() {
|
|||
}
|
||||
rootNnModule = op;
|
||||
}
|
||||
if (!rootNnModule) {
|
||||
module.emitError() << "module does not contain a root torch.nn_module";
|
||||
return failure();
|
||||
}
|
||||
return rootNnModule;
|
||||
}
|
||||
|
||||
LogicalResult ObjectGraphGlobalizer::checkSingleInstanceOfEachClass() {
|
||||
llvm::MapVector</*ClassTypeOp*/ Operation *, std::vector<NnModuleOp>>
|
||||
classInstances;
|
||||
for (NnModuleOp op : module.getOps<NnModuleOp>()) {
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(op.getClassName());
|
||||
classInstances[classType].push_back(op);
|
||||
}
|
||||
for (auto &p : classInstances) {
|
||||
ClassTypeOp classType = cast<ClassTypeOp>(p.first);
|
||||
ArrayRef<NnModuleOp> instances = p.second;
|
||||
if (instances.size() > 1) {
|
||||
// TODO: Improve this diagnostic based on user use cases.
|
||||
// This is a user-facing diagnostic that enforces key invariants to
|
||||
// our TorchScript subset.
|
||||
auto diag = classType.emitError(
|
||||
"class type has more than one instance: the current TorchScript "
|
||||
"supported subset only allows single instances");
|
||||
for (NnModuleOp instance : instances) {
|
||||
diag.attachNote(instance.getLoc()) << "see instance here";
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
ObjectGraphGlobalizer::recursivelyTraverseClassType(ClassTypeOp classType) {
|
||||
std::string pathToClassFromRoot = llvm::join(nameStack, ".");
|
||||
if (!seenClassTypes.insert({classType, pathToClassFromRoot}).second) {
|
||||
return classType.emitError()
|
||||
<< "reachable by multiple paths from root object: '<root>."
|
||||
<< seenClassTypes[classType] << "' and '<root>."
|
||||
<< pathToClassFromRoot << "'";
|
||||
}
|
||||
|
||||
// For each attr, create a global slot for it.
|
||||
for (auto attr : classType.getOps<AttrOp>()) {
|
||||
nameStack.push_back(attr.name().str());
|
||||
if (auto type = attr.type().dyn_cast<NnModuleType>()) {
|
||||
if (failed(recursivelyTraverseClassType(
|
||||
symbolTable.lookup<ClassTypeOp>(type.getClassName()))))
|
||||
return failure();
|
||||
} else {
|
||||
auto linkageName = llvm::join(nameStack, ".");
|
||||
auto globalSlot = globalBuilder.create<GlobalSlotOp>(
|
||||
attr->getLoc(), linkageName, /*sym_visibility=*/nullptr,
|
||||
attr.type());
|
||||
if (attr.isPrivate())
|
||||
globalSlot.setVisibility(SymbolTable::Visibility::Private);
|
||||
AttrOfClass attrOfClass = {classType, attr.name()};
|
||||
assert(globalSlotForAttr.find(attrOfClass) == globalSlotForAttr.end());
|
||||
globalSlotForAttr[attrOfClass] = globalSlot;
|
||||
if (failed(populateGlobalSlotInitializer(globalSlot,
|
||||
slotInitialValues[attrOfClass])))
|
||||
return failure();
|
||||
}
|
||||
nameStack.pop_back();
|
||||
}
|
||||
// For each method, track the linkage name it will eventually have.
|
||||
for (auto method : classType.getOps<MethodOp>()) {
|
||||
nameStack.push_back(method.name().str());
|
||||
auto linkageName = llvm::join(nameStack, ".");
|
||||
nameStack.pop_back();
|
||||
if (!methodLinkageNames.insert({method.function(), linkageName}).second)
|
||||
return method.emitError()
|
||||
<< "unbound function shared by multiple methods";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool hasMeaningfulObjectIdentity(Type type) {
|
||||
return !type.isa<IntegerType, FloatType, Basicpy::BoolType,
|
||||
Basicpy::BytesType, TensorType>();
|
||||
Basicpy::BytesType, Basicpy::NoneType, TensorType>();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
ObjectGraphGlobalizer::populateGlobalSlotInitializer(GlobalSlotOp globalSlot,
|
||||
Value initialValue) {
|
||||
OpBuilder builder(globalSlot.getContext());
|
||||
builder.createBlock(&globalSlot.getRegion());
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Object graph recursive traversal.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SmallPtrSet<Operation *, 6> needToClone;
|
||||
SmallVector<Operation *> worklist = {initialValue.getDefiningOp()};
|
||||
while (!worklist.empty()) {
|
||||
Operation *op = worklist.pop_back_val();
|
||||
if (!needToClone.insert(op).second)
|
||||
continue;
|
||||
for (Value operand : op->getOperands()) {
|
||||
if (auto def = operand.getDefiningOp())
|
||||
worklist.push_back(def);
|
||||
}
|
||||
namespace {
|
||||
struct LinkageInfo {
|
||||
std::string linkageName;
|
||||
bool isPrivate;
|
||||
};
|
||||
} // namespace
|
||||
namespace {
|
||||
/// Calculates the linkage names of all the potentially exported objects in the
|
||||
/// module and also creates GlobalSlotOp's for each SlotOp and tracks their
|
||||
/// associations.
|
||||
///
|
||||
/// The mechanics of both of these tasks involve the same object graph
|
||||
/// traversal, so it's useful to roll them together.
|
||||
class ObjectGraphInfo {
|
||||
public:
|
||||
ObjectGraphInfo(ModuleOp module)
|
||||
: globalSlotBuilder(module.getBodyRegion()), symbolTable(module) {}
|
||||
|
||||
LogicalResult initialize(NnModuleOp root) {
|
||||
return recursivelyTraverse(root);
|
||||
}
|
||||
worklist.assign(needToClone.begin(), needToClone.end());
|
||||
llvm::sort(worklist, [](Operation *lhs, Operation *rhs) {
|
||||
return lhs->isBeforeInBlock(rhs);
|
||||
});
|
||||
BlockAndValueMapping mapping;
|
||||
for (Operation *op : worklist) {
|
||||
builder.clone(*op, mapping);
|
||||
for (Value result : op->getResults()) {
|
||||
if (!hasMeaningfulObjectIdentity(result.getType()))
|
||||
|
||||
LinkageInfo getSlotLinkageInfo(SlotOp op) {
|
||||
auto it = slotLinkageInfo.find(op);
|
||||
assert(it != slotLinkageInfo.end());
|
||||
return it->second;
|
||||
}
|
||||
Optional<LinkageInfo> getFuncLinkageInfo(NnModuleOp instance,
|
||||
FuncOp methodFunc) {
|
||||
auto it = funcLinkageInfo.find({instance, methodFunc});
|
||||
if (it == funcLinkageInfo.end())
|
||||
return None;
|
||||
return it->second;
|
||||
}
|
||||
|
||||
GlobalSlotOp getGlobalSlotFor(SlotOp slot) {
|
||||
auto it = slotToGlobalSlot.find(slot);
|
||||
assert(it != slotToGlobalSlot.end() && "didn't create global slot");
|
||||
return it->second;
|
||||
}
|
||||
|
||||
private:
|
||||
LogicalResult recursivelyTraverse(NnModuleOp nnModule) {
|
||||
std::string pathToClassFromRoot = llvm::join(nameStack, ".");
|
||||
if (!seenNnModules.insert({nnModule, pathToClassFromRoot}).second) {
|
||||
return nnModule.emitError()
|
||||
<< "reachable by multiple paths from root object: '<root>."
|
||||
<< seenNnModules[nnModule] << "' and '<root>."
|
||||
<< pathToClassFromRoot << "'";
|
||||
}
|
||||
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(
|
||||
nnModule.getType().cast<NnModuleType>().getClassName());
|
||||
for (auto t :
|
||||
llvm::zip(nnModule.getOps<SlotOp>(), classType.getOps<AttrOp>())) {
|
||||
auto slot = std::get<0>(t);
|
||||
auto attr = std::get<1>(t);
|
||||
nameStack.push_back(attr.name().str());
|
||||
if (attr.type().isa<NnModuleType>()) {
|
||||
if (failed(
|
||||
recursivelyTraverse(slot.value().getDefiningOp<NnModuleOp>())))
|
||||
return failure();
|
||||
} else {
|
||||
std::string linkageName = llvm::join(nameStack, ".");
|
||||
auto globalSlot = globalSlotBuilder.create<GlobalSlotOp>(
|
||||
slot.getLoc(), linkageName,
|
||||
/*sym_visibility=*/nullptr, attr.type());
|
||||
if (attr.isPrivate())
|
||||
globalSlot.setVisibility(SymbolTable::Visibility::Private);
|
||||
assert(slotToGlobalSlot.find(slot) == slotToGlobalSlot.end());
|
||||
slotToGlobalSlot[slot] = globalSlot;
|
||||
slotLinkageInfo[slot] = LinkageInfo{linkageName, attr.isPrivate()};
|
||||
if (failed(populateGlobalSlotInitializer(globalSlot, slot.value())))
|
||||
return failure();
|
||||
}
|
||||
nameStack.pop_back();
|
||||
}
|
||||
for (auto method : classType.getOps<MethodOp>()) {
|
||||
nameStack.push_back(method.name().str());
|
||||
funcLinkageInfo[{nnModule,
|
||||
symbolTable.lookup<FuncOp>(method.function())}] =
|
||||
LinkageInfo{llvm::join(nameStack, "."), method.isPrivate()};
|
||||
nameStack.pop_back();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
LogicalResult populateGlobalSlotInitializer(GlobalSlotOp globalSlot,
|
||||
Value initialValue) {
|
||||
OpBuilder builder(globalSlot.getContext());
|
||||
builder.createBlock(&globalSlot.getRegion());
|
||||
|
||||
SmallPtrSet<Operation *, 6> needToClone;
|
||||
SmallVector<Operation *> worklist = {initialValue.getDefiningOp()};
|
||||
while (!worklist.empty()) {
|
||||
Operation *op = worklist.pop_back_val();
|
||||
if (!needToClone.insert(op).second)
|
||||
continue;
|
||||
if (!objectsWithIdentityAlreadyCopiedIntoInitializers.insert(result)
|
||||
.second) {
|
||||
return op->emitError()
|
||||
<< "potentially-aliased value used to initialize multiple slots";
|
||||
for (Value operand : op->getOperands()) {
|
||||
if (auto def = operand.getDefiningOp())
|
||||
worklist.push_back(def);
|
||||
}
|
||||
}
|
||||
worklist.assign(needToClone.begin(), needToClone.end());
|
||||
llvm::sort(worklist, [](Operation *lhs, Operation *rhs) {
|
||||
return lhs->isBeforeInBlock(rhs);
|
||||
});
|
||||
BlockAndValueMapping mapping;
|
||||
for (Operation *op : worklist) {
|
||||
builder.clone(*op, mapping);
|
||||
for (Value result : op->getResults()) {
|
||||
if (!hasMeaningfulObjectIdentity(result.getType()))
|
||||
continue;
|
||||
if (!objectsWithIdentityAlreadyCopiedIntoInitializers.insert(result)
|
||||
.second) {
|
||||
return op->emitError() << "potentially-aliased value used to "
|
||||
"initialize multiple slots";
|
||||
}
|
||||
}
|
||||
}
|
||||
builder.create<GlobalSlotInitOp>(globalSlot->getLoc(),
|
||||
mapping.lookup(initialValue));
|
||||
return success();
|
||||
}
|
||||
builder.create<GlobalSlotInitOp>(globalSlot->getLoc(),
|
||||
mapping.lookup(initialValue));
|
||||
return success();
|
||||
// Builder for creating GlobalSlotOp's in the module.
|
||||
OpBuilder globalSlotBuilder;
|
||||
// Symbol table for the module.
|
||||
SymbolTable symbolTable;
|
||||
// The set of NnModuleOp's that have already been processed.
|
||||
// Used for diagnostics.
|
||||
// The map value is the original path from the root that we found it at.
|
||||
DenseMap<NnModuleOp, std::string> seenNnModules;
|
||||
|
||||
// The stack of attribute names we have traversed during our recursive
|
||||
// traversal of the class/object hierarchy.
|
||||
//
|
||||
// Linkage names are calculated based on the set of attribute names traversed
|
||||
// from the root class/module in the program.
|
||||
std::vector<std::string> nameStack;
|
||||
// Linkage info for each SlotOp in the program.
|
||||
DenseMap<SlotOp, LinkageInfo> slotLinkageInfo;
|
||||
// Linkage info for each method in the program. Since we are going to be
|
||||
// monomorphizing all the functions, we also need to key this off of the
|
||||
// instance (NnModuleOp) that the func is monomorphized for.
|
||||
DenseMap<std::pair<NnModuleOp, FuncOp>, LinkageInfo> funcLinkageInfo;
|
||||
// The corresponding GlobalSlotOp for each SlotOp in the program.
|
||||
DenseMap<SlotOp, GlobalSlotOp> slotToGlobalSlot;
|
||||
// A set of values that we have copied into torch.global_slot initializers,
|
||||
// which cannot be used in multiple initializers because their object
|
||||
// identity is important.
|
||||
DenseSet<Value> objectsWithIdentityAlreadyCopiedIntoInitializers;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Monomorphization.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
// When used in an Monomorphization, indicates that the arg at `argIndex` will
|
||||
// correspond to instance `instance.
|
||||
struct ArgInstance {
|
||||
int argIndex;
|
||||
Value instance; // Result of an NnModuleOp.
|
||||
};
|
||||
static llvm::hash_code hash_value(const ArgInstance &argInstance) {
|
||||
return llvm::hash_combine(argInstance.argIndex, argInstance.instance);
|
||||
}
|
||||
static bool operator==(const ArgInstance &lhs, const ArgInstance &rhs) {
|
||||
return std::make_tuple(lhs.argIndex, lhs.instance) ==
|
||||
std::make_tuple(rhs.argIndex, rhs.instance);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Record indicating that a particular function must be monomorphized for the
|
||||
// given ArgInstance's, which involves deleting those arguments and specializing
|
||||
// all their uses to operate on GlobalSlotOp's that we have created for the
|
||||
// SlotOp's of the NnModuleOp instances.
|
||||
//
|
||||
// NOTE: Unlike the more traditional use of monomorphization to mean a single
|
||||
// *type* is being specialized for, here we are specializing for a specific
|
||||
// *instance*. This still fits the definition of monomorphization though, albeit
|
||||
// with each instance being considered to have a maximally refined type which is
|
||||
// a set with a single element (just this instance). This does not correspond to
|
||||
// any notion of "type" that we have in the IR, but still fits the formal
|
||||
// definition.
|
||||
struct Monomorphization {
|
||||
FuncOp func;
|
||||
std::vector<ArgInstance> argInstances;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
template <> struct llvm::DenseMapInfo<Monomorphization> {
|
||||
static Monomorphization getEmptyKey() {
|
||||
return Monomorphization{nullptr, {ArgInstance{-1, nullptr}}};
|
||||
}
|
||||
static Monomorphization getTombstoneKey() {
|
||||
return Monomorphization{nullptr, {ArgInstance{-2, nullptr}}};
|
||||
}
|
||||
static unsigned getHashValue(Monomorphization val) {
|
||||
return llvm::hash_combine(val.func.getAsOpaquePointer(),
|
||||
llvm::hash_combine_range(val.argInstances.begin(),
|
||||
val.argInstances.end()));
|
||||
}
|
||||
static bool isEqual(Monomorphization lhs, Monomorphization rhs) {
|
||||
return lhs.func == rhs.func &&
|
||||
std::equal(lhs.argInstances.begin(), lhs.argInstances.end(),
|
||||
rhs.argInstances.begin(), rhs.argInstances.end());
|
||||
}
|
||||
};
|
||||
|
||||
// Populate `mapping` such that values of NnModuleType in the function are
|
||||
// mapped to appropriate global objects of NnModuleType.
|
||||
//
|
||||
// This generalizes to a full abstract interpretation of the function, but
|
||||
// currently only analyzes a subset of ops.
|
||||
static LogicalResult analyzeInstances(FuncOp func,
|
||||
ArrayRef<ArgInstance> argInstances,
|
||||
BlockAndValueMapping &mapping) {
|
||||
for (auto &argInstance : argInstances)
|
||||
mapping.map(func.getArgument(argInstance.argIndex), argInstance.instance);
|
||||
auto walkResult = func.walk([&](PrimGetAttrOp op) {
|
||||
if (!op.getType().isa<NnModuleType>())
|
||||
return WalkResult::advance();
|
||||
auto instance = mapping.lookupOrNull(op.receiver());
|
||||
assert(instance && "verifyFuncConformsToSubset should ensure this");
|
||||
for (auto slot : instance.getDefiningOp<NnModuleOp>().getOps<SlotOp>()) {
|
||||
if (slot.name() == op.name()) {
|
||||
mapping.map(op, slot.value());
|
||||
break;
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return success(!walkResult.wasInterrupted());
|
||||
}
|
||||
|
||||
static FailureOr<Monomorphization>
|
||||
createMonomorphizationForCall(CallOp op, BlockAndValueMapping &mapping,
|
||||
SymbolTable &symbolTable) {
|
||||
auto func = symbolTable.lookup<FuncOp>(op.callee());
|
||||
Monomorphization monomorphization;
|
||||
monomorphization.func = func;
|
||||
for (auto operand : llvm::enumerate(op->getOperands())) {
|
||||
if (!operand.value().getType().isa<NnModuleType>())
|
||||
continue;
|
||||
Value instance = mapping.lookupOrNull(operand.value());
|
||||
assert(instance && "verifyFuncConformsToSubset should ensure this");
|
||||
monomorphization.argInstances.push_back(
|
||||
ArgInstance{static_cast<int>(operand.index()), instance});
|
||||
}
|
||||
return monomorphization;
|
||||
}
|
||||
|
||||
namespace {
|
||||
class MonomorphizationTracker {
|
||||
public:
|
||||
MonomorphizationTracker(ModuleOp module)
|
||||
: module(module), symbolTable(module) {}
|
||||
LogicalResult
|
||||
initialize(DenseMap<ClassTypeOp, std::vector<NnModuleOp>> &instances) {
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
Monomorphization monomorphization;
|
||||
monomorphization.func = func;
|
||||
bool canTriviallyMonomorphize = true;
|
||||
for (auto arg : llvm::enumerate(func.getArguments())) {
|
||||
auto type = arg.value().getType().dyn_cast<NnModuleType>();
|
||||
if (!type)
|
||||
continue;
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(type.getClassName());
|
||||
auto &classTypeInstances = instances[classType];
|
||||
if (classTypeInstances.size() != 1) {
|
||||
canTriviallyMonomorphize = false;
|
||||
break;
|
||||
}
|
||||
monomorphization.argInstances.push_back(
|
||||
{static_cast<int>(arg.index()), classTypeInstances[0]});
|
||||
}
|
||||
|
||||
if (canTriviallyMonomorphize) {
|
||||
dirtyMonomorphizations.push_back(monomorphization);
|
||||
monomorphizations.insert(monomorphization);
|
||||
}
|
||||
}
|
||||
while (!dirtyMonomorphizations.empty()) {
|
||||
Monomorphization dirty = dirtyMonomorphizations.pop_back_val();
|
||||
if (failed(generateNewMonomorphizations(dirty)))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
llvm::SetVector<Monomorphization> &getMonomorphizations() {
|
||||
return monomorphizations;
|
||||
}
|
||||
|
||||
private:
|
||||
LogicalResult generateNewMonomorphizations(const Monomorphization &m) {
|
||||
auto func = m.func;
|
||||
BlockAndValueMapping mapping;
|
||||
if (failed(analyzeInstances(func, m.argInstances, mapping)))
|
||||
return failure();
|
||||
auto walkResult = func.walk([&](CallOp op) {
|
||||
FailureOr<Monomorphization> maybeMonomorphization =
|
||||
createMonomorphizationForCall(op, mapping, symbolTable);
|
||||
if (failed(maybeMonomorphization))
|
||||
return WalkResult::interrupt();
|
||||
if (monomorphizations.insert(*maybeMonomorphization))
|
||||
dirtyMonomorphizations.push_back(*maybeMonomorphization);
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return success(!walkResult.wasInterrupted());
|
||||
}
|
||||
|
||||
ModuleOp module;
|
||||
SymbolTable symbolTable;
|
||||
SmallVector<Monomorphization> dirtyMonomorphizations;
|
||||
llvm::SetVector<Monomorphization> monomorphizations;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Verify that a value conforms to the subset of allowed uses for
|
||||
// !torch.nn.Module<"..."> types.
|
||||
static LogicalResult verifyNnModuleValueUses(Value value) {
|
||||
// Trivially true for non-module types.
|
||||
// Trivially succeed for non-module types.
|
||||
if (!value.getType().isa<NnModuleType>())
|
||||
return success();
|
||||
for (Operation *op : value.getUsers()) {
|
||||
if (isa<CallOp, PrimGetAttrOp, PrimSetAttrOp, PrimCallMethodOp>(op))
|
||||
if (isa<CallOp, PrimGetAttrOp>(op))
|
||||
continue;
|
||||
// Only allow `value` as the receiver.
|
||||
if (isa<PrimSetAttrOp>(op) && cast<PrimSetAttrOp>(op).value() != value)
|
||||
continue;
|
||||
// TODO: Improve this based on real user use cases.
|
||||
// This is a diagnostic that users will hit if they do not conform to
|
||||
|
@ -278,173 +396,254 @@ static LogicalResult verifyNnModuleValueUses(Value value) {
|
|||
return success();
|
||||
}
|
||||
|
||||
static std::string getNonMethodMangledFunctionName(StringRef originalName) {
|
||||
return "__npcomp_priv_fn$" + originalName.str();
|
||||
}
|
||||
|
||||
// Verify that `func` conforms to the subset of allowable method bodies
|
||||
// that we can convert.
|
||||
static LogicalResult verifyFuncConformsToSubset(FuncOp func) {
|
||||
auto walkResult = func.walk([](Block *block) {
|
||||
// TODO: Investingate why WalkResult::interrupt() doesn't propagate properly.
|
||||
LogicalResult ret = success();
|
||||
func.walk([&](Block *block) {
|
||||
for (Value arg : block->getArguments()) {
|
||||
if (failed(verifyNnModuleValueUses(arg)))
|
||||
if (failed(verifyNnModuleValueUses(arg))) {
|
||||
ret = failure();
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
for (Operation &op : *block) {
|
||||
for (Value result : op.getResults()) {
|
||||
if (failed(verifyNnModuleValueUses(result)))
|
||||
if (failed(verifyNnModuleValueUses(result))) {
|
||||
ret = failure();
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return failure(walkResult.wasInterrupted());
|
||||
return ret;
|
||||
}
|
||||
|
||||
LogicalResult ObjectGraphGlobalizer::rewriteMethods() {
|
||||
DenseMap<AttrOfClass, StringRef> linkageNames;
|
||||
static LogicalResult
|
||||
verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
|
||||
MonomorphizationTracker &tracker) {
|
||||
DenseMap<FuncOp, int> numMonomorphizations;
|
||||
for (auto &monomorphization : tracker.getMonomorphizations()) {
|
||||
numMonomorphizations[monomorphization.func] += 1;
|
||||
}
|
||||
bool sawError = false;
|
||||
for (auto classType : module.getOps<ClassTypeOp>()) {
|
||||
for (auto method : classType.getOps<MethodOp>()) {
|
||||
auto it = methodLinkageNames.find(method.function());
|
||||
if (it == methodLinkageNames.end())
|
||||
continue;
|
||||
linkageNames[{classType, method.name()}] = it->second;
|
||||
}
|
||||
}
|
||||
// We only handle a small subset of ops that conform with the set of
|
||||
// assumptions that allow us to globalize the object graph. Anything that
|
||||
// tries to treat modules as bona-fide objects and not just namespaces
|
||||
// of methods with a single instance of the corresponding type just gets
|
||||
// arbitrarily tricky to rewrite. E.g. what if the user creates a list
|
||||
// of modules, or there is an scf.if selecting between modules, etc.
|
||||
SmallVector<Operation *> toErase;
|
||||
auto rewriteOpWithNnModuleTypeOperand = [&](Operation *op) {
|
||||
if (auto primSetAttr = dyn_cast<PrimSetAttrOp>(op)) {
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(
|
||||
primSetAttr.receiver().getType().cast<NnModuleType>().getClassName());
|
||||
auto globalSlot = globalSlotForAttr[{classType, primSetAttr.name()}];
|
||||
OpBuilder(primSetAttr)
|
||||
.create<GlobalSlotSetOp>(primSetAttr.getLoc(), globalSlot.sym_name(),
|
||||
primSetAttr.value());
|
||||
toErase.push_back(primSetAttr);
|
||||
} else if (auto primGetAttr = dyn_cast<PrimGetAttrOp>(op)) {
|
||||
// If the return value is NnModuleType, then we don't need to do anything.
|
||||
// Our verification earlier ensured that there are no uses that
|
||||
// we won't properly rewrite.
|
||||
if (!primGetAttr.getType().isa<NnModuleType>()) {
|
||||
auto classType =
|
||||
symbolTable.lookup<ClassTypeOp>(primGetAttr.receiver()
|
||||
.getType()
|
||||
.cast<NnModuleType>()
|
||||
.getClassName());
|
||||
auto globalSlot = globalSlotForAttr[{classType, primGetAttr.name()}];
|
||||
auto globalSlotGet =
|
||||
OpBuilder(primGetAttr)
|
||||
.create<GlobalSlotGetOp>(primGetAttr.getLoc(),
|
||||
primGetAttr.getType(),
|
||||
globalSlot.sym_name());
|
||||
primGetAttr.replaceAllUsesWith(globalSlotGet.getOperation());
|
||||
if (!method.isPrivate()) {
|
||||
if (numMonomorphizations[symbolTable.lookup<FuncOp>(
|
||||
method.function())] > 1) {
|
||||
method.emitError()
|
||||
<< "public function with multiple monomorphizations";
|
||||
sawError = true;
|
||||
}
|
||||
}
|
||||
toErase.push_back(primGetAttr);
|
||||
} else if (auto primCallMethod = dyn_cast<PrimCallMethodOp>(op)) {
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(primCallMethod.receiver()
|
||||
.getType()
|
||||
.cast<NnModuleType>()
|
||||
.getClassName());
|
||||
StringRef linkageName = linkageNames[{classType, primCallMethod.name()}];
|
||||
|
||||
auto newOperands = llvm::to_vector<6>(
|
||||
llvm::make_filter_range(primCallMethod.operands(), [](Value v) {
|
||||
return !v.getType().isa<NnModuleType>();
|
||||
}));
|
||||
auto call = OpBuilder(primCallMethod)
|
||||
.create<CallOp>(primCallMethod.getLoc(), linkageName,
|
||||
primCallMethod.getType(), newOperands);
|
||||
primCallMethod.replaceAllUsesWith(call);
|
||||
toErase.push_back(primCallMethod);
|
||||
} else if (auto callOp = dyn_cast<CallOp>(op)) {
|
||||
auto newOperands = llvm::to_vector<6>(
|
||||
llvm::make_filter_range(callOp.operands(), [](Value v) {
|
||||
return !v.getType().isa<NnModuleType>();
|
||||
}));
|
||||
auto newCallOp = OpBuilder(callOp).create<CallOp>(
|
||||
callOp.getLoc(), getNonMethodMangledFunctionName(callOp.callee()),
|
||||
callOp.getResultTypes(), newOperands);
|
||||
callOp.replaceAllUsesWith(newCallOp);
|
||||
toErase.push_back(callOp);
|
||||
}
|
||||
};
|
||||
struct MethodFuncRewrite {
|
||||
bool isPrivate;
|
||||
std::string linkageName;
|
||||
};
|
||||
|
||||
DenseMap<FuncOp, MethodFuncRewrite> methodFuncRewrites;
|
||||
for (auto classType : module.getOps<ClassTypeOp>()) {
|
||||
for (auto method : classType.getOps<MethodOp>()) {
|
||||
auto it = methodLinkageNames.find(method.function());
|
||||
if (it == methodLinkageNames.end())
|
||||
continue;
|
||||
FuncOp func = symbolTable.lookup<FuncOp>(method.function());
|
||||
methodFuncRewrites[func] =
|
||||
MethodFuncRewrite{method.isPrivate(), it->second};
|
||||
}
|
||||
}
|
||||
return success(!sawError);
|
||||
}
|
||||
|
||||
// Rewrite `func`, given that all values of `NnModuleType` have been mapped in
|
||||
// `mapping` to corresponding global instances.
|
||||
static LogicalResult
|
||||
rewriteMonomorphizedFuncClone(FuncOp func, BlockAndValueMapping mapping,
|
||||
SymbolTable &symbolTable,
|
||||
DenseMap<Monomorphization, FuncOp> &newFuncs,
|
||||
ObjectGraphInfo &objectGraphInfo) {
|
||||
|
||||
SmallVector<Operation *> toErase;
|
||||
auto handlePrimSetAttr = [&](PrimSetAttrOp op) {
|
||||
auto instance = mapping.lookup(op.receiver()).getDefiningOp<NnModuleOp>();
|
||||
SlotOp affectedSlot;
|
||||
for (auto slot : instance.getOps<SlotOp>()) {
|
||||
if (slot.name() == op.name())
|
||||
affectedSlot = slot;
|
||||
}
|
||||
OpBuilder(op).create<GlobalSlotSetOp>(
|
||||
op.getLoc(), objectGraphInfo.getGlobalSlotFor(affectedSlot).sym_name(),
|
||||
op.value());
|
||||
toErase.push_back(op);
|
||||
return WalkResult::advance();
|
||||
};
|
||||
auto handlePrimGetAttr = [&](PrimGetAttrOp op) {
|
||||
if (!op.getType().isa<NnModuleType>()) {
|
||||
auto instance = mapping.lookup(op.receiver()).getDefiningOp<NnModuleOp>();
|
||||
SlotOp affectedSlot;
|
||||
for (auto slot : instance.getOps<SlotOp>()) {
|
||||
if (slot.name() == op.name())
|
||||
affectedSlot = slot;
|
||||
}
|
||||
auto newOp = OpBuilder(op).create<GlobalSlotGetOp>(
|
||||
op.getLoc(), op.getType(),
|
||||
objectGraphInfo.getGlobalSlotFor(affectedSlot).sym_name());
|
||||
op.replaceAllUsesWith(&*newOp);
|
||||
}
|
||||
toErase.push_back(op);
|
||||
return WalkResult::advance();
|
||||
};
|
||||
auto handleCall = [&](CallOp op) {
|
||||
FailureOr<Monomorphization> maybeMonomorphization =
|
||||
createMonomorphizationForCall(op, mapping, symbolTable);
|
||||
if (failed(maybeMonomorphization))
|
||||
return WalkResult::interrupt();
|
||||
Monomorphization monomorphization = std::move(*maybeMonomorphization);
|
||||
auto newArguments = llvm::to_vector<6>(
|
||||
llvm::make_filter_range(op->getOperands(), [](Value v) {
|
||||
return !v.getType().isa<NnModuleType>();
|
||||
}));
|
||||
assert(newFuncs.find(monomorphization) != newFuncs.end());
|
||||
auto newOp = OpBuilder(op).create<CallOp>(
|
||||
op.getLoc(), newFuncs[monomorphization], newArguments);
|
||||
op.replaceAllUsesWith(newOp);
|
||||
toErase.push_back(op);
|
||||
return WalkResult::advance();
|
||||
};
|
||||
auto walkResult = func.walk([&](Operation *op) {
|
||||
if (auto primSetAttr = dyn_cast<PrimSetAttrOp>(op))
|
||||
return handlePrimSetAttr(primSetAttr);
|
||||
if (auto primGetAttr = dyn_cast<PrimGetAttrOp>(op))
|
||||
return handlePrimGetAttr(primGetAttr);
|
||||
if (auto call = dyn_cast<CallOp>(op))
|
||||
return handleCall(call);
|
||||
return WalkResult::advance();
|
||||
});
|
||||
for (auto op : toErase) {
|
||||
op->dropAllDefinedValueUses();
|
||||
op->erase();
|
||||
}
|
||||
SmallVector<unsigned> argsToErase;
|
||||
for (auto type : llvm::enumerate(func.getArgumentTypes())) {
|
||||
if (type.value().isa<NnModuleType>()) {
|
||||
argsToErase.push_back(type.index());
|
||||
}
|
||||
}
|
||||
func.eraseArguments(argsToErase);
|
||||
return success(!walkResult.wasInterrupted());
|
||||
}
|
||||
|
||||
static LogicalResult globalizeObjectGraph(ModuleOp module) {
|
||||
|
||||
// Step 1: Traverse object graph and collect information.
|
||||
|
||||
FailureOr<NnModuleOp> maybeRootNnModule = findRootNnModule(module);
|
||||
if (failed(maybeRootNnModule))
|
||||
return failure();
|
||||
NnModuleOp rootNnModule = *maybeRootNnModule;
|
||||
ObjectGraphInfo objectGraphInfo(module);
|
||||
if (failed(objectGraphInfo.initialize(rootNnModule)))
|
||||
return failure();
|
||||
|
||||
DenseMap<ClassTypeOp, std::vector<NnModuleOp>> instances;
|
||||
SymbolTable symbolTable(module);
|
||||
for (auto nnModule : module.getOps<NnModuleOp>()) {
|
||||
auto classType = nnModule.getClassType(symbolTable);
|
||||
instances[classType].push_back(nnModule);
|
||||
}
|
||||
|
||||
// Step 2: Verify all functions are suitable to be analyzed by our later code.
|
||||
// This eliminates special handling / error code later.
|
||||
//
|
||||
// This is important, because in principle, we can perform arbitrarily complex
|
||||
// static analysis to discover how to monomorphize th eprogram, including
|
||||
// tracking instances through control flow, through get/set attr, etc. We
|
||||
// implement a very simple subset of cases.
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
if (failed(verifyFuncConformsToSubset(func)))
|
||||
return failure();
|
||||
func.walk(rewriteOpWithNnModuleTypeOperand);
|
||||
for (Operation *op : toErase) {
|
||||
op->dropAllDefinedValueUses();
|
||||
op->erase();
|
||||
}
|
||||
toErase.clear();
|
||||
SmallVector<unsigned> argsToErase;
|
||||
for (auto arg : llvm::enumerate(func.getArguments())) {
|
||||
if (!arg.value().getType().isa<NnModuleType>())
|
||||
continue;
|
||||
assert(arg.value().use_empty() && "all uses should have been removed");
|
||||
argsToErase.push_back(arg.index());
|
||||
}
|
||||
func.eraseArguments(argsToErase);
|
||||
// No need to handle the results, since we currently don't allow ReturnOp
|
||||
// as a user of module types.
|
||||
}
|
||||
|
||||
// Adjust the linkage names to adopt the linkage convention of this pass,
|
||||
// namely that only objects accessible from a root !torch.nn.Module are
|
||||
// possible external linkage candidates, and their linkage name is the
|
||||
// dotted path from the root.
|
||||
//
|
||||
// Any other function gets a prefix to avoid collisions. These other
|
||||
// functions correspond to free functions somewhere outside the module
|
||||
// hierarchy.
|
||||
auto it = methodFuncRewrites.find(func);
|
||||
if (it != methodFuncRewrites.end()) {
|
||||
if (!it->second.isPrivate)
|
||||
func.setVisibility(SymbolTable::Visibility::Public);
|
||||
func.setName(it->second.linkageName);
|
||||
} else {
|
||||
func.setName(getNonMethodMangledFunctionName(func.getName()));
|
||||
// Step 3: Calculate the set of monomorphized functions that need to be
|
||||
// created. For each call that passes !torch.nn.Module to a function, we need
|
||||
// to create a specialized version of that function just for that instance (or
|
||||
// combination of instances in the case of multiple arguments).
|
||||
//
|
||||
// At this stage, we only analyze which monomorphizations are needed and
|
||||
// whether it is possible to monomorphize the program. The actual
|
||||
// cloning/rewriting mechanics happen later.
|
||||
//
|
||||
// This lets us know which GlobalSlotOp we need to reference when we replace
|
||||
// PrimSetAttrOp/PrimGetAttrOp.
|
||||
//
|
||||
// Note that in general there can be mutually recursive functions that
|
||||
// re-enter themselves with a different set of instances -- the process of
|
||||
// calculating these monomorphizations is a fixed-point iteration that
|
||||
// discovers all needed monomorphizations. In practice this yields a
|
||||
// controllable number.
|
||||
MonomorphizationTracker tracker(module);
|
||||
if (failed(tracker.initialize(instances)))
|
||||
return failure();
|
||||
|
||||
if (failed(verifyPublicMonomorphizations(module, symbolTable, tracker))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Step 4: Clone/rewrite functions to implement the necessary
|
||||
// monomorphizations.
|
||||
DenseMap<Monomorphization, FuncOp> newFuncs;
|
||||
int uniquifier = 0;
|
||||
for (auto &monomorphization : tracker.getMonomorphizations()) {
|
||||
auto newFunc = cast<FuncOp>(monomorphization.func->clone());
|
||||
newFuncs[monomorphization] = newFunc;
|
||||
Optional<LinkageInfo> linkageInfo = None;
|
||||
// If it is potentially a method, check its linkage info.
|
||||
if (monomorphization.argInstances.size() != 0 &&
|
||||
monomorphization.argInstances[0].argIndex == 0) {
|
||||
linkageInfo = objectGraphInfo.getFuncLinkageInfo(
|
||||
monomorphization.argInstances[0].instance.getDefiningOp<NnModuleOp>(),
|
||||
monomorphization.func);
|
||||
}
|
||||
if (linkageInfo.hasValue()) {
|
||||
// It's a method.
|
||||
newFunc.setVisibility(linkageInfo->isPrivate
|
||||
? SymbolTable::Visibility::Private
|
||||
: SymbolTable::Visibility::Public);
|
||||
newFunc.setName(linkageInfo->linkageName);
|
||||
} else {
|
||||
// It's a free function.
|
||||
// TODO: Make the name nicer (no suffix in typical case).
|
||||
newFunc.setName(
|
||||
(Twine(newFunc.getName()) + "$" + Twine(uniquifier++)).str());
|
||||
}
|
||||
module.push_back(newFunc);
|
||||
}
|
||||
|
||||
|
||||
|
||||
for (auto &kv : newFuncs) {
|
||||
BlockAndValueMapping mapping;
|
||||
if (failed(analyzeInstances(kv.second, kv.first.argInstances, mapping)))
|
||||
return failure();
|
||||
if (failed(rewriteMonomorphizedFuncClone(kv.second, mapping, symbolTable,
|
||||
newFuncs, objectGraphInfo)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Step 5: Clean up object graph.
|
||||
DenseSet<FuncOp> liveFuncs;
|
||||
for (auto &kv : newFuncs) {
|
||||
liveFuncs.insert(kv.second);
|
||||
}
|
||||
for (auto &op : llvm::make_early_inc_range(module.getOps())) {
|
||||
if (isa<GlobalSlotOp, ModuleTerminatorOp>(&op))
|
||||
continue;
|
||||
if (auto func = dyn_cast<FuncOp>(op)) {
|
||||
if (liveFuncs.contains(func))
|
||||
continue;
|
||||
}
|
||||
op.dropAllDefinedValueUses();
|
||||
op.dropAllReferences();
|
||||
op.erase();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void ObjectGraphGlobalizer::removeObjectGraph() {
|
||||
for (Operation &op : llvm::make_early_inc_range(*module.getBody())) {
|
||||
if (!isa<FuncOp, GlobalSlotOp, ModuleTerminatorOp>(op)) {
|
||||
op.dropAllDefinedValueUses();
|
||||
op.erase();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
class GlobalizeObjectGraphPass
|
||||
: public GlobalizeObjectGraphBase<GlobalizeObjectGraphPass> {
|
||||
void runOnOperation() override {
|
||||
if (failed(ObjectGraphGlobalizer(getOperation()).globalizeObjectGraph()))
|
||||
if (failed(globalizeObjectGraph(getOperation())))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass registration
|
||||
|
@ -17,4 +18,14 @@ namespace {
|
|||
#include "npcomp/Dialect/Torch/Transforms/Passes.h.inc"
|
||||
} // end namespace
|
||||
|
||||
void mlir::NPCOMP::registerTorchPasses() { ::registerPasses(); }
|
||||
void mlir::NPCOMP::registerTorchPasses() {
|
||||
::registerPasses();
|
||||
mlir::PassPipelineRegistration<>(
|
||||
"torch-globalize-pipeline", "Globalization pipeline.",
|
||||
mlir::NPCOMP::Torch::createGlobalizePipeline);
|
||||
}
|
||||
|
||||
void mlir::NPCOMP::Torch::createGlobalizePipeline(OpPassManager &pm) {
|
||||
pm.addPass(createPrepareForGlobalizeObjectGraphPass());
|
||||
pm.addPass(createGlobalizeObjectGraphPass());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
//===- PrepareForGlobalizeObjectGraph.cpp ------------------------*- C++-*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
namespace {
|
||||
class ConvertPrimCallMethodToCall : public OpRewritePattern<PrimCallMethodOp> {
|
||||
public:
|
||||
ConvertPrimCallMethodToCall(MLIRContext *context, SymbolTable &symbolTable)
|
||||
: OpRewritePattern(context), symbolTable(symbolTable) {}
|
||||
LogicalResult matchAndRewrite(PrimCallMethodOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto classType = symbolTable.lookup<ClassTypeOp>(
|
||||
op.receiver().getType().cast<NnModuleType>().getClassName());
|
||||
FuncOp func;
|
||||
for (auto method : classType.getOps<MethodOp>()) {
|
||||
if (method.name() == op.name()) {
|
||||
func = symbolTable.lookup<FuncOp>(method.function());
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert(func);
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, func, op->getOperands());
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
SymbolTable &symbolTable;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class EraseUnusedConstantOp : public OpRewritePattern<ConstantOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ConstantOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (op.use_empty()) {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class PrepareForGlobalizeObjectGraphPass
|
||||
: public PrepareForGlobalizeObjectGraphBase<
|
||||
PrepareForGlobalizeObjectGraphPass> {
|
||||
void runOnOperation() override {
|
||||
|
||||
SymbolTable symbolTable(getOperation());
|
||||
|
||||
MLIRContext *context = &getContext();
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ConvertPrimCallMethodToCall>(context, symbolTable);
|
||||
CallIndirectOp::getCanonicalizationPatterns(patterns, context);
|
||||
patterns.insert<EraseUnusedConstantOp>(context);
|
||||
|
||||
// Use applyPatternsAndFoldGreedily because the CallIndirectOp folding
|
||||
// makes the ConstantOp unused, which does not work with the visitation
|
||||
// order of the dialect conversion infrastructure.
|
||||
// TODO: Do this with the dialect conversion infrastructure to avoid doing
|
||||
// folding as part of this. Or avoid folding during greedy pattern
|
||||
// application. See: https://llvm.org/PR49502
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
||||
std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
// Do a dummy full conversion to ensure that the program has been converted
|
||||
// to the form we want.
|
||||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<PrimCallMethodOp>();
|
||||
target.addDynamicallyLegalOp<ConstantOp>([](ConstantOp op) {
|
||||
return !op.getType().isa<FunctionType>();
|
||||
});
|
||||
target.addIllegalOp<CallIndirectOp>();
|
||||
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
||||
|
||||
OwningRewritePatternList dummyPatterns;
|
||||
|
||||
if (failed(applyFullConversion(getOperation(), target,
|
||||
std::move(dummyPatterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::Torch::createPrepareForGlobalizeObjectGraphPass() {
|
||||
return std::make_unique<PrepareForGlobalizeObjectGraphPass>();
|
||||
}
|
|
@ -10,26 +10,6 @@ torch.nn_module {} : !torch.nn.Module<"c2">
|
|||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{class type has more than one instance: the current TorchScript supported subset only allows single instances}}
|
||||
torch.class_type @child {}
|
||||
torch.class_type @parent {
|
||||
torch.attr "m1" : !torch.nn.Module<"child">
|
||||
torch.attr "m2" : !torch.nn.Module<"child">
|
||||
}
|
||||
|
||||
// expected-note @+1 {{see instance here}}
|
||||
%0 = torch.nn_module {} : !torch.nn.Module<"child">
|
||||
// expected-note @+1 {{see instance here}}
|
||||
%1 = torch.nn_module {} : !torch.nn.Module<"child">
|
||||
|
||||
%root = torch.nn_module {
|
||||
torch.slot "m1", %0 : !torch.nn.Module<"child">
|
||||
torch.slot "m2", %1 : !torch.nn.Module<"child">
|
||||
} : !torch.nn.Module<"parent">
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{reachable by multiple paths from root object: '<root>.m' and '<root>.m2'}}
|
||||
torch.class_type @child {
|
||||
torch.attr "float" : f64
|
||||
}
|
||||
|
@ -40,6 +20,7 @@ torch.class_type @parent {
|
|||
}
|
||||
|
||||
%c42 = std.constant 42.0 : f64
|
||||
// expected-error @+1 {{reachable by multiple paths from root object: '<root>.m' and '<root>.m2'}}
|
||||
%child = torch.nn_module {
|
||||
torch.slot "float", %c42 : f64
|
||||
} : !torch.nn.Module<"child">
|
||||
|
|
|
@ -4,7 +4,8 @@ torch.class_type @c {
|
|||
torch.attr "float" : f64
|
||||
torch.method "calls_free_function", @calls_free_function
|
||||
}
|
||||
// CHECK-LABEL: func private @__npcomp_priv_fn$free_function(
|
||||
// CHECK-LABEL: func private
|
||||
// CHECK-SAME: @free_function$[[$MONOMORPHIZE_TAG0:.*]](
|
||||
// CHECK-SAME: %[[F:.*]]: f64) -> f64 {
|
||||
// CHECK: return %[[F]] : f64
|
||||
// CHECK: }
|
||||
|
@ -12,15 +13,26 @@ func private @free_function(%arg0: f64, %arg1: !torch.nn.Module<"c">) -> f64 {
|
|||
return %arg0 : f64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func private
|
||||
// CHECK-SAME: @free_function_no_module_args$[[$MONOMORPHIZE_TAG1:.*]](
|
||||
// CHECK-SAME: %[[F:.*]]: f64) -> f64 {
|
||||
// CHECK: return %[[F]] : f64
|
||||
// CHECK: }
|
||||
func private @free_function_no_module_args(%arg0: f64) -> f64 {
|
||||
return %arg0 : f64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @calls_free_function() -> f64 {
|
||||
// CHECK: %[[F1:.*]] = torch.global_slot.get @float : f64
|
||||
// CHECK: %[[RET:.*]] = call @__npcomp_priv_fn$free_function(%[[F1]]) : (f64) -> f64
|
||||
// CHECK: %[[F2:.*]] = call @free_function$[[$MONOMORPHIZE_TAG0]](%[[F1]]) : (f64) -> f64
|
||||
// CHECK: %[[RET:.*]] = call @free_function_no_module_args$[[$MONOMORPHIZE_TAG1]](%[[F2]]) : (f64) -> f64
|
||||
// CHECK: return %[[RET]] : f64
|
||||
// CHECK: }
|
||||
func private @calls_free_function(%arg0: !torch.nn.Module<"c">) -> f64 {
|
||||
%0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> f64
|
||||
%1 = call @free_function(%0, %arg0) : (f64, !torch.nn.Module<"c">) -> f64
|
||||
return %1 : f64
|
||||
%2 = call @free_function_no_module_args(%1) : (f64) -> f64
|
||||
return %2 : f64
|
||||
}
|
||||
|
||||
%c42 = std.constant 42.0 : f64
|
||||
|
|
|
@ -29,7 +29,7 @@ func private @test_set(%arg0: !torch.nn.Module<"c">, %arg1: f64) {
|
|||
// CHECK: %[[V:.*]] = call @test_call(%[[A]]) : (f64) -> f64
|
||||
// CHECK: return %[[V]] : f64
|
||||
func private @test_call(%arg0: !torch.nn.Module<"c">, %arg1: f64) -> f64 {
|
||||
%0 = torch.prim.CallMethod %arg0["test_call"] (%arg1) : !torch.nn.Module<"c">, (f64) -> f64
|
||||
%0 = call @test_call(%arg0, %arg1) : (!torch.nn.Module<"c">, f64) -> f64
|
||||
return %0 : f64
|
||||
}
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ func private @get_attr_returns_module_type(%arg0: !torch.nn.Module<"parent">) ->
|
|||
|
||||
// CHECK-LABEL: func @module_type_argument(
|
||||
// CHECK-SAME: %[[F:.*]]: f64) -> !basicpy.NoneType {
|
||||
func private @module_type_argument(%arg0: !torch.nn.Module<"parent">, %arg1: f64, %arg2: !torch.nn.Module<"parent">) -> !basicpy.NoneType {
|
||||
func private @module_type_argument(%arg0: !torch.nn.Module<"parent">, %arg1: !torch.nn.Module<"parent">, %arg2: f64, %arg3: !torch.nn.Module<"parent">) -> !basicpy.NoneType {
|
||||
%0 = basicpy.singleton : !basicpy.NoneType
|
||||
return %0 : !basicpy.NoneType
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ func private @method_call(%arg0: !torch.nn.Module<"parent">) -> !basicpy.NoneTyp
|
|||
// CHECK-NEXT: %[[C:.*]] = constant 4.300000e+01 : f64
|
||||
%c = constant 43.0 : f64
|
||||
// CHECK-NEXT: %[[F:.*]] = call @module_type_argument(%[[C]]) : (f64) -> !basicpy.NoneType
|
||||
%0 = torch.prim.CallMethod %arg0["module_type_argument"] (%arg0, %c, %arg0) : !torch.nn.Module<"parent">, (!torch.nn.Module<"parent">, f64, !torch.nn.Module<"parent">) -> (!basicpy.NoneType)
|
||||
%0 = call @module_type_argument(%arg0, %arg0, %c, %arg0) : (!torch.nn.Module<"parent">, !torch.nn.Module<"parent">, f64, !torch.nn.Module<"parent">) -> (!basicpy.NoneType)
|
||||
// CHECK-NEXT: return %[[F]] : !basicpy.NoneType
|
||||
return %0 : !basicpy.NoneType
|
||||
}
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -verify-diagnostics -split-input-file %s
|
||||
|
||||
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) {
|
||||
return
|
||||
}
|
||||
func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) {
|
||||
%5 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
|
||||
%6 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
|
||||
call @__torch__.Submodule.forward(%5, %6) : (!torch.nn.Module<"__torch__.Submodule">, !torch.nn.Module<"__torch__.Submodule">) -> ()
|
||||
call @__torch__.Submodule.forward(%5, %5) : (!torch.nn.Module<"__torch__.Submodule">, !torch.nn.Module<"__torch__.Submodule">) -> ()
|
||||
return
|
||||
}
|
||||
torch.class_type @__torch__.TestModule {
|
||||
torch.attr private "s1" : !torch.nn.Module<"__torch__.Submodule">
|
||||
torch.attr private "s2" : !torch.nn.Module<"__torch__.Submodule">
|
||||
torch.method "forward", @__torch__.TestModule.forward
|
||||
}
|
||||
%bool_true = basicpy.bool_constant true
|
||||
%0 = basicpy.singleton : !basicpy.NoneType
|
||||
torch.class_type @__torch__.Submodule {
|
||||
torch.attr private "n" : i64
|
||||
// expected-error @+1 {{public function with multiple monomorphizations}}
|
||||
torch.method "forward", @__torch__.Submodule.forward
|
||||
}
|
||||
%num1_i64 = basicpy.numeric_constant 1 : i64
|
||||
%1 = torch.nn_module {
|
||||
torch.slot "n", %num1_i64 : i64
|
||||
} : !torch.nn.Module<"__torch__.Submodule">
|
||||
%num2_i64 = basicpy.numeric_constant 2 : i64
|
||||
%2 = torch.nn_module {
|
||||
torch.slot "n", %num2_i64 : i64
|
||||
} : !torch.nn.Module<"__torch__.Submodule">
|
||||
%3 = torch.nn_module {
|
||||
torch.slot "s1", %1 : !torch.nn.Module<"__torch__.Submodule">
|
||||
torch.slot "s2", %2 : !torch.nn.Module<"__torch__.Submodule">
|
||||
} : !torch.nn.Module<"__torch__.TestModule">
|
|
@ -0,0 +1,74 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
// Tests monomorphization of same function with different instance argument types.
|
||||
|
||||
torch.class_type @__torch__.TestModule {
|
||||
torch.attr private "s1" : !torch.nn.Module<"__torch__.Submodule">
|
||||
torch.attr private "s2" : !torch.nn.Module<"__torch__.Submodule">
|
||||
torch.method "forward", @__torch__.TestModule.forward
|
||||
}
|
||||
torch.class_type @__torch__.Submodule {
|
||||
torch.attr private "n" : i64
|
||||
torch.method private "forward", @__torch__.Submodule.forward
|
||||
}
|
||||
|
||||
%num1_i64 = basicpy.numeric_constant 1 : i64
|
||||
%s1 = torch.nn_module {
|
||||
// CHECK-LABEL: torch.global_slot "private" @s1.n : i64 {
|
||||
// CHECK: %[[C1:.*]] = basicpy.numeric_constant 1 : i64
|
||||
// CHECK: torch.global_slot.init %[[C1]] : i64
|
||||
// CHECK: }
|
||||
torch.slot "n", %num1_i64 : i64
|
||||
} : !torch.nn.Module<"__torch__.Submodule">
|
||||
%num2_i64 = basicpy.numeric_constant 2 : i64
|
||||
%s2 = torch.nn_module {
|
||||
// CHECK-LABEL: torch.global_slot "private" @s2.n : i64 {
|
||||
// CHECK: %[[C2:.*]] = basicpy.numeric_constant 2 : i64
|
||||
// CHECK: torch.global_slot.init %[[C2]] : i64
|
||||
// CHECK: }
|
||||
torch.slot "n", %num2_i64 : i64
|
||||
} : !torch.nn.Module<"__torch__.Submodule">
|
||||
%3 = torch.nn_module {
|
||||
torch.slot "s1", %s1 : !torch.nn.Module<"__torch__.Submodule">
|
||||
torch.slot "s2", %s2 : !torch.nn.Module<"__torch__.Submodule">
|
||||
} : !torch.nn.Module<"__torch__.TestModule">
|
||||
|
||||
|
||||
// CHECK-LABEL: func @forward() {
|
||||
// CHECK: call @__torch__.free_function$[[$MONOMORPHIZE_TAG0:.*]]() : () -> ()
|
||||
// CHECK: call @__torch__.free_function$[[$MONOMORPHIZE_TAG1:.*]]() : () -> ()
|
||||
func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) {
|
||||
%4 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
|
||||
%5 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
|
||||
call @__torch__.free_function(%4, %5) : (!torch.nn.Module<"__torch__.Submodule">, !torch.nn.Module<"__torch__.Submodule">) -> ()
|
||||
%7 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
|
||||
%8 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
|
||||
call @__torch__.free_function(%7, %8) : (!torch.nn.Module<"__torch__.Submodule">, !torch.nn.Module<"__torch__.Submodule">) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// s1 called first, then s2
|
||||
// CHECK-LABEL: func private
|
||||
// CHECK-SAME @__torch__.free_function$[[$MONOMORPHIZE_TAG0]]() {
|
||||
// CHECK: call @s1.forward() : () -> ()
|
||||
// CHECK: call @s2.forward() : () -> ()
|
||||
|
||||
// s2 called first, then s1
|
||||
// CHECK-LABEL: func private
|
||||
// CHECK-SAME: @__torch__.free_function$[[$MONOMORPHIZE_TAG1]]() {
|
||||
// CHECK: call @s2.forward() : () -> ()
|
||||
// CHECK: call @s1.forward() : () -> ()
|
||||
func private @__torch__.free_function(%arg0: !torch.nn.Module<"__torch__.Submodule">, %arg1: !torch.nn.Module<"__torch__.Submodule">) {
|
||||
call @__torch__.Submodule.forward(%arg0) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
|
||||
call @__torch__.Submodule.forward(%arg1) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func private @s2.forward() {
|
||||
// CHECK: return
|
||||
|
||||
// CHECK-LABEL: func private @s1.forward() {
|
||||
// CHECK: return
|
||||
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) {
|
||||
return
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
torch.class_type @__torch__.TestModule {
|
||||
torch.attr private "s1" : !torch.nn.Module<"__torch__.Submodule">
|
||||
torch.attr private "s2" : !torch.nn.Module<"__torch__.Submodule">
|
||||
torch.method "forward", @__torch__.TestModule.forward
|
||||
}
|
||||
torch.class_type @__torch__.Submodule {
|
||||
torch.attr private "n" : i64
|
||||
torch.method private "forward", @__torch__.Submodule.forward
|
||||
}
|
||||
|
||||
%num1_i64 = basicpy.numeric_constant 1 : i64
|
||||
%s1 = torch.nn_module {
|
||||
// CHECK-LABEL: torch.global_slot "private" @s1.n : i64 {
|
||||
// CHECK: %[[C1:.*]] = basicpy.numeric_constant 1 : i64
|
||||
// CHECK: torch.global_slot.init %[[C1]] : i64
|
||||
// CHECK: }
|
||||
torch.slot "n", %num1_i64 : i64
|
||||
} : !torch.nn.Module<"__torch__.Submodule">
|
||||
%num2_i64 = basicpy.numeric_constant 2 : i64
|
||||
%s2 = torch.nn_module {
|
||||
// CHECK-LABEL: torch.global_slot "private" @s2.n : i64 {
|
||||
// CHECK: %[[C2:.*]] = basicpy.numeric_constant 2 : i64
|
||||
// CHECK: torch.global_slot.init %[[C2]] : i64
|
||||
// CHECK: }
|
||||
torch.slot "n", %num2_i64 : i64
|
||||
} : !torch.nn.Module<"__torch__.Submodule">
|
||||
%3 = torch.nn_module {
|
||||
torch.slot "s1", %s1 : !torch.nn.Module<"__torch__.Submodule">
|
||||
torch.slot "s2", %s2 : !torch.nn.Module<"__torch__.Submodule">
|
||||
} : !torch.nn.Module<"__torch__.TestModule">
|
||||
|
||||
|
||||
// CHECK-LABEL: func @forward() {
|
||||
// CHECK: call @s1.forward() : () -> ()
|
||||
// CHECK: call @s2.forward() : () -> ()
|
||||
// CHECK: return
|
||||
func private @__torch__.TestModule.forward(%arg0: !torch.nn.Module<"__torch__.TestModule">) {
|
||||
%4 = torch.prim.GetAttr %arg0["s1"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
|
||||
%5 = torch.prim.GetAttr %arg0["s2"] : !torch.nn.Module<"__torch__.TestModule"> -> !torch.nn.Module<"__torch__.Submodule">
|
||||
call @__torch__.Submodule.forward(%4) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
|
||||
call @__torch__.Submodule.forward(%5) : (!torch.nn.Module<"__torch__.Submodule">) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func private @s1.forward() {
|
||||
// CHECK: %[[C1:.*]] = constant 1 : i64
|
||||
// CHECK: %[[N:.*]] = torch.global_slot.get @s1.n : i64
|
||||
// CHECK: %[[NEWVAL:.*]] = torch.kernel_call "aten::add" %[[N]], %[[C1]] : (i64, i64) -> i64 {sigArgTypes = ["int", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["int"]}
|
||||
// CHECK: torch.global_slot.set @s1.n = %[[NEWVAL]] : i64
|
||||
// CHECK: return
|
||||
|
||||
// CHECK-LABEL: func private @s2.forward() {
|
||||
// CHECK: %[[C1:.*]] = constant 1 : i64
|
||||
// CHECK: %[[N:.*]] = torch.global_slot.get @s2.n : i64
|
||||
// CHECK: %[[NEWVAL:.*]] = torch.kernel_call "aten::add" %[[N]], %[[C1]] : (i64, i64) -> i64 {sigArgTypes = ["int", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["int"]}
|
||||
// CHECK: torch.global_slot.set @s2.n = %[[NEWVAL]] : i64
|
||||
// CHECK: return
|
||||
func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) {
|
||||
%c1_i64 = constant 1 : i64
|
||||
%5 = torch.prim.GetAttr %arg0["n"] : !torch.nn.Module<"__torch__.Submodule"> -> i64
|
||||
%6 = torch.kernel_call "aten::add" %5, %c1_i64 : (i64, i64) -> i64 {sigArgTypes = ["int", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["int"]}
|
||||
torch.prim.SetAttr %arg0["n"] = %6 : !torch.nn.Module<"__torch__.Submodule">, i64
|
||||
return
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
// RUN: npcomp-opt -torch-prepare-for-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||
|
||||
torch.class_type @c {
|
||||
torch.method "test_call_method", @test_call_method
|
||||
torch.method "test_call_indirect", @test_call_indirect
|
||||
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func private @test_call_method(
|
||||
// CHECK-SAME: %[[RECEIVER:.*]]: !torch.nn.Module<"c">,
|
||||
// CHECK-SAME: %[[F:.*]]: f64) -> f64 {
|
||||
// CHECK: %[[RET:.*]] = call @test_call_method(%[[RECEIVER]], %[[F]]) : (!torch.nn.Module<"c">, f64) -> f64
|
||||
// CHECK: return %[[RET]] : f64
|
||||
func private @test_call_method(%arg0: !torch.nn.Module<"c">, %arg1: f64) -> f64 {
|
||||
%0 = torch.prim.CallMethod %arg0["test_call_method"] (%arg1) : !torch.nn.Module<"c">, (f64) -> f64
|
||||
return %0 : f64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func private @test_call_indirect(
|
||||
// CHECK-SAME: %[[RECEIVER:.*]]: !torch.nn.Module<"c">,
|
||||
// CHECK-SAME: %[[F:.*]]: f64) -> f64 {
|
||||
// Ensure no std.constant.
|
||||
// CHECK-NEXT: %[[VAL_2:.*]] = call @test_call_method(%[[RECEIVER]], %[[F]]) : (!torch.nn.Module<"c">, f64) -> f64
|
||||
// CHECK-NEXT: return %[[VAL_2]] : f64
|
||||
func private @test_call_indirect(%arg0: !torch.nn.Module<"c">, %arg1: f64) -> f64 {
|
||||
%0 = constant @test_call_method : (!torch.nn.Module<"c">, f64) -> f64
|
||||
%1 = call_indirect %0(%arg0, %arg1) : (!torch.nn.Module<"c">, f64) -> f64
|
||||
return %1 : f64
|
||||
}
|
||||
|
||||
torch.nn_module {
|
||||
} : !torch.nn.Module<"c">
|
Loading…
Reference in New Issue