GlobalizeObjectGraph: Clean up handling of unused slots

Previously we still created the global slot and copied its initializer
even if the slot was unused.
pull/1030/head
Sean Silva 2022-07-12 01:07:24 +00:00
parent 9017be9e9e
commit e5e11e214b
7 changed files with 48 additions and 30 deletions
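To illustrate the change: in IR like the minimal sketch below (modeled on the updated test cases; the class @c and the attribute names "used"/"unused" are illustrative, not taken from the commit), the pass previously emitted a torch.global_slot for "unused" and copied its initializer, even though no torch.prim.GetAttr or torch.prim.SetAttr ever touches it. With this patch, a GlobalSlotOp is only created for slots that are actually used.

torch.class_type @c {
  torch.attr "used" : !torch.float
  torch.attr "unused" : !torch.float
}
%c42 = torch.constant.float 42.0
torch.nn_module {
  torch.slot "used", %c42 : !torch.float
  torch.slot "unused", %c42 : !torch.float
} : !torch.nn.Module<"c">
// Only "used" is read anywhere, so only it becomes a torch.global_slot.
func.func private @reader(%arg0: !torch.nn.Module<"c">) {
  %0 = torch.prim.GetAttr %arg0["used"] : !torch.nn.Module<"c"> -> !torch.float
  return
}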


@@ -104,12 +104,20 @@ public:
 private:
   LogicalResult collectUsedSlots() {
     // Collect all the slots in each module.
-    llvm::StringMap<llvm::StringMap<SlotOp>> moduleClassNameToSlots;
+    // moduleClassNameToSlots tracks, for each class, for each attribute, the
+    // set of slot instances that belong to that attribute. E.g. if there are
+    // two instances of a class "Foo" with an attribute "a", then there will be
+    // two SlotOps in the inner vector of moduleClassNameToSlots["Foo"]["a"].
+    // This is not precise -- in the code below it effectively results in the
+    // conservative assumption that all instances of a class might reach all
+    // GetAttr ops on that type.
+    llvm::StringMap<llvm::StringMap<std::vector<SlotOp>>>
+        moduleClassNameToSlots;
     symbolTable.getOp()->walk([&](NnModuleOp moduleOp) {
-      llvm::StringMap<SlotOp> nameToSlot;
-      for (auto attrOp : moduleOp.getOps<SlotOp>())
-        nameToSlot[attrOp.name()] = attrOp;
-      moduleClassNameToSlots[moduleOp.getClassName()] = nameToSlot;
+      auto &slotNameToSlots = moduleClassNameToSlots[moduleOp.getClassName()];
+      for (auto slotOp : moduleOp.getOps<SlotOp>())
+        slotNameToSlots[slotOp.name()].push_back(slotOp);
     });
     // Find all the module slots that are accessed through `PrimGetAttrOp` or
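The conservative assumption described in the new comment is easiest to see with two instances of one class. A hypothetical sketch (not one of this commit's test cases; class @c and attribute "a" are made-up names): a single torch.prim.GetAttr on the class type marks the "a" SlotOp of both instances as used, because both land in the same moduleClassNameToSlots["c"]["a"] vector.

torch.class_type @c {
  torch.attr "a" : !torch.float
}
%f = torch.constant.float 1.0
// Two instances of @c: both SlotOps end up in moduleClassNameToSlots["c"]["a"].
%m1 = torch.nn_module {
  torch.slot "a", %f : !torch.float
} : !torch.nn.Module<"c">
%m2 = torch.nn_module {
  torch.slot "a", %f : !torch.float
} : !torch.nn.Module<"c">
// One GetAttr on !torch.nn.Module<"c"> marks "a" as used in both instances.
func.func private @user(%arg0: !torch.nn.Module<"c">) {
  %0 = torch.prim.GetAttr %arg0["a"] : !torch.nn.Module<"c"> -> !torch.float
  return
}

This is also why the inner map now holds std::vector<SlotOp> rather than a single SlotOp: the previous nameToSlot map overwrote its entry for each instance walked, keeping only the last SlotOp.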
@@ -136,13 +144,14 @@ private:
         op->emitError() << "Reference to non-existing module type "
                         << moduleType.getClassName();
-      llvm::StringMap<SlotOp> nameToSlot = slots->getValue();
-      auto slotIt = nameToSlot.find(slotName);
+      auto &slotNameToSlots = slots->getValue();
+      auto slotIt = slotNameToSlots.find(slotName);
       // TODO: Improve verifier so that this can never happen
-      if (slotIt == nameToSlot.end())
+      if (slotIt == slotNameToSlots.end())
         op->emitError() << "Reference to non-existing module slot " << slotName
                         << "in " << moduleType.getClassName();
-      usedSlots.insert(slotIt->getValue());
+      for (SlotOp slotOp : slotIt->getValue())
+        usedSlots.insert(slotOp);
     });
     return success();
   }
@@ -167,7 +176,8 @@ private:
       if (failed(
               recursivelyTraverse(slot.value().getDefiningOp<NnModuleOp>())))
         return failure();
-    } else {
+    } else if (usedSlots.find(slot) != usedSlots.end()) {
+      // Only create the GlobalSlotOp if the slot is used at all.
       std::string linkageName = llvm::join(nameStack, ".");
       auto globalSlot = globalSlotBuilder.create<GlobalSlotOp>(
           slot.getLoc(), linkageName,
@@ -218,8 +228,6 @@ private:
     for (Value result : op->getResults()) {
       if (!hasMeaningfulObjectIdentity(result.getType()))
         continue;
-      if (usedSlots.find(slot) == usedSlots.end())
-        continue;
       if (!objectsWithIdentityAlreadyCopiedIntoInitializers.insert(result)
                .second) {
         return op->emitError() << "potentially-aliased value used to "


@@ -40,23 +40,10 @@ torch.nn_module {
   torch.slot "t", %t : !torch.tensor
 } : !torch.nn.Module<"c">
-// -----
-// CHECK-LABEL: torch.global_slot @t1 : !torch.tensor {
-// CHECK: %[[T:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
-// CHECK: torch.global_slot.init %[[T]] : !torch.tensor
-// CHECK-LABEL: torch.global_slot @t2 : !torch.tensor {
-// CHECK: %[[T:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
-// CHECK: torch.global_slot.init %[[T]] : !torch.tensor
-%t = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
-torch.class_type @c {
-  torch.attr "t1" : !torch.tensor
-  torch.attr "t2" : !torch.tensor
+func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"c">) {
+  %0 = torch.prim.GetAttr %arg0["b"] : !torch.nn.Module<"c"> -> !torch.bool
+  %1 = torch.prim.GetAttr %arg0["i"] : !torch.nn.Module<"c"> -> !torch.int
+  %2 = torch.prim.GetAttr %arg0["f"] : !torch.nn.Module<"c"> -> !torch.float
+  %3 = torch.prim.GetAttr %arg0["t"] : !torch.nn.Module<"c"> -> !torch.tensor
+  return
+}
-torch.nn_module {
-  torch.slot "t1", %t : !torch.tensor
-  torch.slot "t2", %t : !torch.tensor
-} : !torch.nn.Module<"c">


@@ -29,6 +29,13 @@ torch.class_type @parent {
   torch.slot "m2", %child : !torch.nn.Module<"child">
 } : !torch.nn.Module<"parent">
+func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"parent">, %arg1: !torch.nn.Module<"child">) {
+  %0 = torch.prim.GetAttr %arg0["m"] : !torch.nn.Module<"parent"> -> !torch.nn.Module<"child">
+  %1 = torch.prim.GetAttr %arg0["m2"] : !torch.nn.Module<"parent"> -> !torch.nn.Module<"child">
+  %2 = torch.prim.GetAttr %arg1["float"] : !torch.nn.Module<"child"> -> !torch.float
+  return
+}
 // -----
 torch.class_type @c {


@@ -19,3 +19,8 @@ torch.class_type @c {
 torch.nn_module {
   torch.slot "l", %l2 : !torch.list<list<list<tensor>>>
 } : !torch.nn.Module<"c">
+func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"c">) {
+  %0 = torch.prim.GetAttr %arg0["l"] : !torch.nn.Module<"c"> -> !torch.list<list<list<tensor>>>
+  return
+}


@@ -70,5 +70,6 @@ func.func private @__torch__.free_function(%arg0: !torch.nn.Module<"__torch__.Su
 // CHECK-LABEL: func.func private @s1.forward() {
 // CHECK: return
 func.func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) {
+  %0 = torch.prim.GetAttr %arg0["n"] : !torch.nn.Module<"__torch__.Submodule"> -> !torch.int
   return
 }


@@ -22,3 +22,8 @@ torch.class_type @parent {
 %parent = torch.nn_module {
   torch.slot "m", %child : !torch.nn.Module<"child">
 } : !torch.nn.Module<"parent">
+func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"child">) {
+  %0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"child"> -> !torch.float
+  return
+}


@@ -15,3 +15,8 @@ func.func private @method(%arg0: !torch.nn.Module<"c">) {
 torch.nn_module {
   torch.slot "float", %c42 : !torch.float
 } : !torch.nn.Module<"c">
+func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"c">) {
+  %0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> !torch.float
+  return
+}