mirror of https://github.com/llvm/torch-mlir
GlobalizeObjectGraph: Clean up handling of unused slots
The way we did it previously still created the slot and copied the initializer even if unused.pull/1030/head
parent
9017be9e9e
commit
e5e11e214b
|
@ -104,12 +104,20 @@ public:
|
|||
private:
|
||||
LogicalResult collectUsedSlots() {
|
||||
// Collect all the slots in each module.
|
||||
llvm::StringMap<llvm::StringMap<SlotOp>> moduleClassNameToSlots;
|
||||
// moduleClassNameToSlots tracks, for each class, for each attribute, the
|
||||
// set of slot instances that belong to that attribute. E.g. if there are
|
||||
// two instances of a class "Foo" with an attribute "a", then there will be
|
||||
// two SlotOps in the inner vector of moduleClassNameToSlots["Foo"]["a"].
|
||||
// This is not precise -- in the code below it effectively results in the
|
||||
// conservative assumption that all instances of a class might reach all
|
||||
// GetAttr ops on that type.
|
||||
llvm::StringMap<llvm::StringMap<std::vector<SlotOp>>>
|
||||
moduleClassNameToSlots;
|
||||
symbolTable.getOp()->walk([&](NnModuleOp moduleOp) {
|
||||
llvm::StringMap<SlotOp> nameToSlot;
|
||||
for (auto attrOp : moduleOp.getOps<SlotOp>())
|
||||
nameToSlot[attrOp.name()] = attrOp;
|
||||
moduleClassNameToSlots[moduleOp.getClassName()] = nameToSlot;
|
||||
auto &slotNameToSlots = moduleClassNameToSlots[moduleOp.getClassName()];
|
||||
for (auto slotOp : moduleOp.getOps<SlotOp>())
|
||||
slotNameToSlots[slotOp.name()].push_back(slotOp);
|
||||
});
|
||||
|
||||
// Find all the module slots that are accessed through `PrimGetAttrOp` or
|
||||
|
@ -136,13 +144,14 @@ private:
|
|||
op->emitError() << "Reference to non-existing module type "
|
||||
<< moduleType.getClassName();
|
||||
|
||||
llvm::StringMap<SlotOp> nameToSlot = slots->getValue();
|
||||
auto slotIt = nameToSlot.find(slotName);
|
||||
auto &slotNameToSlots = slots->getValue();
|
||||
auto slotIt = slotNameToSlots.find(slotName);
|
||||
// TODO: Improve verifier so that this can never happen
|
||||
if (slotIt == nameToSlot.end())
|
||||
if (slotIt == slotNameToSlots.end())
|
||||
op->emitError() << "Reference to non-existing module slot " << slotName
|
||||
<< "in " << moduleType.getClassName();
|
||||
usedSlots.insert(slotIt->getValue());
|
||||
for (SlotOp slotOp : slotIt->getValue())
|
||||
usedSlots.insert(slotOp);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
@ -167,7 +176,8 @@ private:
|
|||
if (failed(
|
||||
recursivelyTraverse(slot.value().getDefiningOp<NnModuleOp>())))
|
||||
return failure();
|
||||
} else {
|
||||
} else if (usedSlots.find(slot) != usedSlots.end()) {
|
||||
// Only create the GlobalSlotOp if the slot is used at all.
|
||||
std::string linkageName = llvm::join(nameStack, ".");
|
||||
auto globalSlot = globalSlotBuilder.create<GlobalSlotOp>(
|
||||
slot.getLoc(), linkageName,
|
||||
|
@ -218,8 +228,6 @@ private:
|
|||
for (Value result : op->getResults()) {
|
||||
if (!hasMeaningfulObjectIdentity(result.getType()))
|
||||
continue;
|
||||
if (usedSlots.find(slot) == usedSlots.end())
|
||||
continue;
|
||||
if (!objectsWithIdentityAlreadyCopiedIntoInitializers.insert(result)
|
||||
.second) {
|
||||
return op->emitError() << "potentially-aliased value used to "
|
||||
|
|
|
@ -40,23 +40,10 @@ torch.nn_module {
|
|||
torch.slot "t", %t : !torch.tensor
|
||||
} : !torch.nn.Module<"c">
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: torch.global_slot @t1 : !torch.tensor {
|
||||
// CHECK: %[[T:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
// CHECK: torch.global_slot.init %[[T]] : !torch.tensor
|
||||
|
||||
// CHECK-LABEL: torch.global_slot @t2 : !torch.tensor {
|
||||
// CHECK: %[[T:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
// CHECK: torch.global_slot.init %[[T]] : !torch.tensor
|
||||
|
||||
%t = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
torch.class_type @c {
|
||||
torch.attr "t1" : !torch.tensor
|
||||
torch.attr "t2" : !torch.tensor
|
||||
func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"c">) {
|
||||
%0 = torch.prim.GetAttr %arg0["b"] : !torch.nn.Module<"c"> -> !torch.bool
|
||||
%1 = torch.prim.GetAttr %arg0["i"] : !torch.nn.Module<"c"> -> !torch.int
|
||||
%2 = torch.prim.GetAttr %arg0["f"] : !torch.nn.Module<"c"> -> !torch.float
|
||||
%3 = torch.prim.GetAttr %arg0["t"] : !torch.nn.Module<"c"> -> !torch.tensor
|
||||
return
|
||||
}
|
||||
torch.nn_module {
|
||||
torch.slot "t1", %t : !torch.tensor
|
||||
torch.slot "t2", %t : !torch.tensor
|
||||
} : !torch.nn.Module<"c">
|
||||
|
|
|
@ -29,6 +29,13 @@ torch.class_type @parent {
|
|||
torch.slot "m2", %child : !torch.nn.Module<"child">
|
||||
} : !torch.nn.Module<"parent">
|
||||
|
||||
func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"parent">, %arg1: !torch.nn.Module<"child">) {
|
||||
%0 = torch.prim.GetAttr %arg0["m"] : !torch.nn.Module<"parent"> -> !torch.nn.Module<"child">
|
||||
%1 = torch.prim.GetAttr %arg0["m2"] : !torch.nn.Module<"parent"> -> !torch.nn.Module<"child">
|
||||
%2 = torch.prim.GetAttr %arg1["float"] : !torch.nn.Module<"child"> -> !torch.float
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
torch.class_type @c {
|
||||
|
|
|
@ -19,3 +19,8 @@ torch.class_type @c {
|
|||
torch.nn_module {
|
||||
torch.slot "l", %l2 : !torch.list<list<list<tensor>>>
|
||||
} : !torch.nn.Module<"c">
|
||||
|
||||
func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"c">) {
|
||||
%0 = torch.prim.GetAttr %arg0["l"] : !torch.nn.Module<"c"> -> !torch.list<list<list<tensor>>>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -70,5 +70,6 @@ func.func private @__torch__.free_function(%arg0: !torch.nn.Module<"__torch__.Su
|
|||
// CHECK-LABEL: func.func private @s1.forward() {
|
||||
// CHECK: return
|
||||
func.func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) {
|
||||
%0 = torch.prim.GetAttr %arg0["n"] : !torch.nn.Module<"__torch__.Submodule"> -> !torch.int
|
||||
return
|
||||
}
|
||||
|
|
|
@ -22,3 +22,8 @@ torch.class_type @parent {
|
|||
%parent = torch.nn_module {
|
||||
torch.slot "m", %child : !torch.nn.Module<"child">
|
||||
} : !torch.nn.Module<"parent">
|
||||
|
||||
func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"child">) {
|
||||
%0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"child"> -> !torch.float
|
||||
return
|
||||
}
|
||||
|
|
|
@ -15,3 +15,8 @@ func.func private @method(%arg0: !torch.nn.Module<"c">) {
|
|||
torch.nn_module {
|
||||
torch.slot "float", %c42 : !torch.float
|
||||
} : !torch.nn.Module<"c">
|
||||
|
||||
func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"c">) {
|
||||
%0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> !torch.float
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue