From e5e11e214b65dfaabfb480c7b2961b1dec26ed37 Mon Sep 17 00:00:00 2001
From: Sean Silva
Date: Tue, 12 Jul 2022 01:07:24 +0000
Subject: [PATCH] GlobalizeObjectGraph: Clean up handling of unused slots

The way we did it previously still created the slot and copied the
initializer even if unused.
---
 .../Torch/Transforms/GlobalizeObjectGraph.cpp | 30 ++++++++++++-------
 .../Torch/GlobalizeObjectGraph/basic.mlir     | 25 ++------------
 .../Torch/GlobalizeObjectGraph/error.mlir     |  7 +++++
 .../GlobalizeObjectGraph/initializers.mlir    |  5 ++++
 ...ltiple-instances-multiple-module-args.mlir |  1 +
 .../GlobalizeObjectGraph/submodules.mlir      |  5 ++++
 .../GlobalizeObjectGraph/visibility.mlir      |  5 ++++
 7 files changed, 48 insertions(+), 30 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp
index 09f7d7d24..8ca07604d 100644
--- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp
+++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp
@@ -104,12 +104,20 @@ public:
 private:
   LogicalResult collectUsedSlots() {
     // Collect all the slots in each module.
-    llvm::StringMap<llvm::StringMap<SlotOp>> moduleClassNameToSlots;
+    // moduleClassNameToSlots tracks, for each class, for each attribute, the
+    // set of slot instances that belong to that attribute. E.g. if there are
+    // two instances of a class "Foo" with an attribute "a", then there will be
+    // two SlotOps in the inner vector of moduleClassNameToSlots["Foo"]["a"].
+    // This is not precise -- in the code below it effectively results in the
+    // conservative assumption that all instances of a class might reach all
+    // GetAttr ops on that type.
+    llvm::StringMap<llvm::StringMap<std::vector<SlotOp>>>
+        moduleClassNameToSlots;
     symbolTable.getOp()->walk([&](NnModuleOp moduleOp) {
       llvm::StringMap<SlotOp> nameToSlot;
-      for (auto attrOp : moduleOp.getOps<SlotOp>())
-        nameToSlot[attrOp.name()] = attrOp;
-      moduleClassNameToSlots[moduleOp.getClassName()] = nameToSlot;
+      auto &slotNameToSlots = moduleClassNameToSlots[moduleOp.getClassName()];
+      for (auto slotOp : moduleOp.getOps<SlotOp>())
+        slotNameToSlots[slotOp.name()].push_back(slotOp);
     });
 
     // Find all the module slots that are accessed through `PrimGetAttrOp` or
@@ -136,13 +144,14 @@
         op->emitError() << "Reference to non-existing module type "
                         << moduleType.getClassName();
 
-      llvm::StringMap<SlotOp> nameToSlot = slots->getValue();
-      auto slotIt = nameToSlot.find(slotName);
+      auto &slotNameToSlots = slots->getValue();
+      auto slotIt = slotNameToSlots.find(slotName);
       // TODO: Improve verifier so that this can never happen
-      if (slotIt == nameToSlot.end())
+      if (slotIt == slotNameToSlots.end())
         op->emitError() << "Reference to non-existing module slot " << slotName
                         << "in " << moduleType.getClassName();
-      usedSlots.insert(slotIt->getValue());
+      for (SlotOp slotOp : slotIt->getValue())
+        usedSlots.insert(slotOp);
     });
     return success();
   }
@@ -167,7 +176,8 @@ private:
       if (failed(
               recursivelyTraverse(slot.value().getDefiningOp<NnModuleOp>())))
        return failure();
-    } else {
+    } else if (usedSlots.find(slot) != usedSlots.end()) {
+      // Only create the GlobalSlotOp if the slot is used at all.
       std::string linkageName = llvm::join(nameStack, ".");
       auto globalSlot = globalSlotBuilder.create<GlobalSlotOp>(
           slot.getLoc(), linkageName,
@@ -218,8 +228,6 @@ private:
     for (Value result : op->getResults()) {
       if (!hasMeaningfulObjectIdentity(result.getType()))
         continue;
-      if (usedSlots.find(slot) == usedSlots.end())
-        continue;
       if (!objectsWithIdentityAlreadyCopiedIntoInitializers.insert(result)
               .second) {
         return op->emitError() << "potentially-aliased value used to "
diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/basic.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/basic.mlir
index 3f9e6986a..33fd2ee85 100644
--- a/test/Dialect/Torch/GlobalizeObjectGraph/basic.mlir
+++ b/test/Dialect/Torch/GlobalizeObjectGraph/basic.mlir
@@ -40,23 +40,10 @@
 torch.nn_module {
   torch.slot "t", %t : !torch.tensor
 } : !torch.nn.Module<"c">
-
-// -----
-
-// CHECK-LABEL: torch.global_slot @t1 : !torch.tensor {
-// CHECK:         %[[T:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
-// CHECK:         torch.global_slot.init %[[T]] : !torch.tensor
-
-// CHECK-LABEL: torch.global_slot @t2 : !torch.tensor {
-// CHECK:         %[[T:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
-// CHECK:         torch.global_slot.init %[[T]] : !torch.tensor
-
-%t = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
-torch.class_type @c {
-  torch.attr "t1" : !torch.tensor
-  torch.attr "t2" : !torch.tensor
+func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"c">) {
+  %0 = torch.prim.GetAttr %arg0["b"] : !torch.nn.Module<"c"> -> !torch.bool
+  %1 = torch.prim.GetAttr %arg0["i"] : !torch.nn.Module<"c"> -> !torch.int
+  %2 = torch.prim.GetAttr %arg0["f"] : !torch.nn.Module<"c"> -> !torch.float
+  %3 = torch.prim.GetAttr %arg0["t"] : !torch.nn.Module<"c"> -> !torch.tensor
+  return
 }
-torch.nn_module {
-  torch.slot "t1", %t : !torch.tensor
-  torch.slot "t2", %t : !torch.tensor
-} : !torch.nn.Module<"c">
diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir
index 43734b6e8..8bf281c18 100644
--- a/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir
+++ b/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir
@@ -29,6 +29,13 @@ torch.class_type @parent {
   torch.slot "m2", %child : !torch.nn.Module<"child">
 } : !torch.nn.Module<"parent">
 
+func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"parent">, %arg1: !torch.nn.Module<"child">) {
+  %0 = torch.prim.GetAttr %arg0["m"] : !torch.nn.Module<"parent"> -> !torch.nn.Module<"child">
+  %1 = torch.prim.GetAttr %arg0["m2"] : !torch.nn.Module<"parent"> -> !torch.nn.Module<"child">
+  %2 = torch.prim.GetAttr %arg1["float"] : !torch.nn.Module<"child"> -> !torch.float
+  return
+}
+
 // -----
 
 torch.class_type @c {
diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/initializers.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/initializers.mlir
index 2ef351784..c48097316 100644
--- a/test/Dialect/Torch/GlobalizeObjectGraph/initializers.mlir
+++ b/test/Dialect/Torch/GlobalizeObjectGraph/initializers.mlir
@@ -19,3 +19,8 @@ torch.class_type @c {
 torch.nn_module {
   torch.slot "l", %l2 : !torch.list<list<list<tensor>>>
 } : !torch.nn.Module<"c">
+
+func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"c">) {
+  %0 = torch.prim.GetAttr %arg0["l"] : !torch.nn.Module<"c"> -> !torch.list<list<list<tensor>>>
+  return
+}
diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/multiple-instances-multiple-module-args.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/multiple-instances-multiple-module-args.mlir
index f4f84730d..6d5e94cdf 100644
--- a/test/Dialect/Torch/GlobalizeObjectGraph/multiple-instances-multiple-module-args.mlir
+++ b/test/Dialect/Torch/GlobalizeObjectGraph/multiple-instances-multiple-module-args.mlir
@@ -70,5 +70,6 @@ func.func private @__torch__.free_function(%arg0: !torch.nn.Module<"__torch__.Su
 // CHECK-LABEL: func.func private @s1.forward() {
 // CHECK:         return
 func.func private @__torch__.Submodule.forward(%arg0: !torch.nn.Module<"__torch__.Submodule">) {
+  %0 = torch.prim.GetAttr %arg0["n"] : !torch.nn.Module<"__torch__.Submodule"> -> !torch.int
   return
 }
diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir
index 92eeec0f9..0b35e3cb2 100644
--- a/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir
+++ b/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir
@@ -22,3 +22,8 @@ torch.class_type @parent {
 %parent = torch.nn_module {
   torch.slot "m", %child : !torch.nn.Module<"child">
 } : !torch.nn.Module<"parent">
+
+func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"child">) {
+  %0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"child"> -> !torch.float
+  return
+}
diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/visibility.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/visibility.mlir
index 8ad6e5f48..25a2395ef 100644
--- a/test/Dialect/Torch/GlobalizeObjectGraph/visibility.mlir
+++ b/test/Dialect/Torch/GlobalizeObjectGraph/visibility.mlir
@@ -15,3 +15,8 @@ func.func private @method(%arg0: !torch.nn.Module<"c">) {
 torch.nn_module {
   torch.slot "float", %c42 : !torch.float
 } : !torch.nn.Module<"c">
+
+func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"c">) {
+  %0 = torch.prim.GetAttr %arg0["float"] : !torch.nn.Module<"c"> -> !torch.float
+  return
+}
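
A minimal sketch of the effect of this patch. The input below is hypothetical
(the class @c, the attribute names "used"/"unused", @method, and the constant
value are illustrative, not one of the test cases above); the op syntax mirrors
the test files. Given a class with one attribute that is read through
`torch.prim.GetAttr` and one that is never read:

    torch.class_type @c {
      torch.attr "used" : !torch.float
      torch.attr "unused" : !torch.float
    }
    %c42 = torch.constant.float 42.0
    torch.nn_module {
      torch.slot "used", %c42 : !torch.float
      torch.slot "unused", %c42 : !torch.float
    } : !torch.nn.Module<"c">

    func.func private @method(%arg0: !torch.nn.Module<"c">) {
      %0 = torch.prim.GetAttr %arg0["used"] : !torch.nn.Module<"c"> -> !torch.float
      return
    }

Previously, -torch-globalize-object-graph created a torch.global_slot and
copied the initializer for both attributes; with this patch only the slot that
collectUsedSlots() finds reachable is globalized, roughly (exact visibility and
SSA names depend on the pass):

    torch.global_slot @used : !torch.float {
      %0 = torch.constant.float 42.0
      torch.global_slot.init %0 : !torch.float
    }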
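Note the conservatism spelled out in the new comment in collectUsedSlots(): the
used-slot set is keyed by class, not by instance, so a single GetAttr on a class
type keeps the corresponding slot of every instance of that class alive. For
the two Submodule instances in multiple-instances-multiple-module-args.mlir,
the one added read of "n" preserves both globalized slots, approximately:

    torch.global_slot @s1.n : !torch.int { ... }
    torch.global_slot @s2.n : !torch.int { ... }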