diff --git a/include/npcomp/Dialect/Torch/IR/TorchOps.td b/include/npcomp/Dialect/Torch/IR/TorchOps.td
index b59eae5f9..981fa04b8 100644
--- a/include/npcomp/Dialect/Torch/IR/TorchOps.td
+++ b/include/npcomp/Dialect/Torch/IR/TorchOps.td
@@ -250,6 +250,7 @@ def Torch_AttrOp : Torch_Op<"attr", [
 def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
     Symbol,
     IsolatedFromAbove,
+    SingleBlockImplicitTerminator<"::mlir::NPCOMP::Torch::GlobalSlotInitOp">
 ]> {
   let summary = "A slot with global storage";
   let description = [{
@@ -265,21 +266,32 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
     TypeAttr:$typeBound
   );
   let results = (outs);
+  let regions = (region SizedRegion<1>:$initializer);
 
   let assemblyFormat = [{
-    ($sym_visibility^)? $sym_name attr-dict `:` $typeBound
+    ($sym_visibility^)? $sym_name attr-dict `:` $typeBound ($initializer^)?
+  }];
+}
+
+def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
+    Terminator,
+    HasParent<"::mlir::NPCOMP::Torch::GlobalSlotOp">]> {
+  let summary = "yield-like terminator for torch.global_slot initializer region";
+  let description = [{
+    The operand to this op becomes the initial value of the parent
+    torch.global_slot.
   }];
-  let extraClassDeclaration = [{
-    // The name of the function, which, for semantic correctness, must be called
-    // exactly once and this call must be done before any other calls into
-    // the module.
-    // TODO: Avoid load-bearing names.
-    // We could replace this with an op that marks the function as initializer.
-    static constexpr StringRef getGlobalSlotInitializerFuncName() {
-      return "__torch_global_slot_initializer";
-    }
-  }];
+  let arguments = (ins AnyTorchType:$initialValue);
+  let results = (outs);
+
+  // This builder creates an illegal op, but is needed to appease
+  // ensureTerminator in the default builders for SingleBlockImplicitTerminator
+  // on the parent torch.global_slot op.
+  // TODO: Have a SingleBlockExplicitTerminator trait.
+  let builders = [OpBuilderDAG<(ins), [{ /* nothing to do */ }]>];
+
+  let assemblyFormat = "$initialValue attr-dict `:` type($initialValue)";
 }
 
 def Torch_GlobalSlotGetOp : Torch_Op<"global_slot.get", []> {
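For orientation, the new torch.global_slot form defined above parses as
follows (a sketch assembled from the op definitions and the tests later in
this patch; the @flag slot and its bool initial value are illustrative, not
taken from any file here):

  torch.global_slot @flag : !basicpy.BoolType {
    %0 = basicpy.bool_constant true
    torch.global_slot.init %0 : !basicpy.BoolType
  }

Because torch.global_slot.init carries the initial value as an operand, it is
always spelled out in the printed form, even though it is nominally implicit
via SingleBlockImplicitTerminator (hence the no-op builder above).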
diff --git a/include/npcomp/Dialect/Torch/Transforms/Passes.td b/include/npcomp/Dialect/Torch/Transforms/Passes.td
index ab427fd37..18a2ee603 100644
--- a/include/npcomp/Dialect/Torch/Transforms/Passes.td
+++ b/include/npcomp/Dialect/Torch/Transforms/Passes.td
@@ -59,6 +59,14 @@ def GlobalizeObjectGraph : Pass<"torch-globalize-object-graph", "ModuleOp"> {
       paths. Or worse, infinite number of paths when considering cyclic object
       graphs. Also as of Feb 2021, TorchScript won't import into this form (it
       has a bug related to the identity of submodules).
+    - Two slots cannot have initial values that alias each other.
+      - Rationale: This makes the representation of initial values simpler. Also
+        as of Feb 2021, TorchScript won't import into this form except
+        potentially for Tensors (it has a bug related to the identity of
+        objects). And for tensors, the npcomp IValue importer only supports a
+        very restricted form of aliasing anyway for other reasons. We are
+        waiting for signals that more general handling of object aliasing is
+        important before devoting effort to it.
   }];
 }
 
diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp
index 9c167541f..f42f6f755 100644
--- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp
+++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp
@@ -33,7 +33,8 @@ private:
   FailureOr<NnModuleOp> findRootNnModule();
   LogicalResult checkSingleInstanceOfEachClass();
   LogicalResult recursivelyTraverseClassType(ClassTypeOp classType);
-  void createInitializerFunc();
+  LogicalResult populateGlobalSlotInitializer(GlobalSlotOp op,
+                                              Value initialValue);
   LogicalResult rewriteMethods();
   void removeObjectGraph();
 
@@ -72,6 +73,11 @@ private:
   // Used for diagnostics.
   // The map value is the original path from the root that we found it at.
   DenseMap<ClassTypeOp, std::string> seenClassTypes;
+
+  // A set of values that we have copied into torch.global_slot initializers,
+  // which cannot be used in multiple initializers because their object
+  // identity is important.
+  DenseSet<Value> objectsWithIdentityAlreadyCopiedIntoInitializers;
 };
 } // namespace
 
@@ -110,9 +116,6 @@ LogicalResult ObjectGraphGlobalizer::globalizeObjectGraph() {
   if (failed(recursivelyTraverseClassType(rootClassType)))
     return failure();
 
-  // Move all slot initial values into an initializer func.
-  createInitializerFunc();
-
   // Rewrite torch.prim.GetAttr/torch.prim.SetAttr/torch.prim.CallMethod.
   if (failed(rewriteMethods()))
     return failure();
@@ -195,6 +198,9 @@ ObjectGraphGlobalizer::recursivelyTraverseClassType(ClassTypeOp classType) {
       AttrOfClass attrOfClass = {classType, attr.name()};
       assert(globalSlotForAttr.find(attrOfClass) == globalSlotForAttr.end());
       globalSlotForAttr[attrOfClass] = globalSlot;
+      if (failed(populateGlobalSlotInitializer(globalSlot,
+                                               slotInitialValues[attrOfClass])))
+        return failure();
     }
     nameStack.pop_back();
   }
@@ -210,33 +216,48 @@ ObjectGraphGlobalizer::recursivelyTraverseClassType(ClassTypeOp classType) {
   return success();
 }
 
-void ObjectGraphGlobalizer::createInitializerFunc() {
-  auto loc = module.getLoc();
-  auto func = globalBuilder.create<FuncOp>(
-      loc, GlobalSlotOp::getGlobalSlotInitializerFuncName(),
-      globalBuilder.getFunctionType({}, {}));
-  OpBuilder builder(func.getContext());
-  Block *body = builder.createBlock(&func.getBody());
+static bool hasMeaningfulObjectIdentity(Type type) {
+  return !type.isa<IntegerType, FloatType, Basicpy::BoolType,
+                   Basicpy::BytesType, TensorType>();
+}
 
-  SmallVector<Operation *> opsToMove;
-  for (Operation &op : llvm::make_early_inc_range(*module.getBody())) {
-    if (isa<ClassTypeOp, NnModuleOp, GlobalSlotOp, FuncOp>(&op))
+LogicalResult
+ObjectGraphGlobalizer::populateGlobalSlotInitializer(GlobalSlotOp globalSlot,
+                                                     Value initialValue) {
+  OpBuilder builder(globalSlot.getContext());
+  builder.createBlock(&globalSlot.getRegion());
+
+  SmallPtrSet<Operation *, 6> needToClone;
+  SmallVector<Operation *> worklist = {initialValue.getDefiningOp()};
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+    if (!needToClone.insert(op).second)
       continue;
-    op.moveBefore(body, body->end());
-    for (Value result : llvm::make_early_inc_range(op.getResults())) {
-      auto it = slotInitialValuesInverseMap.find(result);
-      if (it == slotInitialValuesInverseMap.end())
+    for (Value operand : op->getOperands()) {
+      if (auto def = operand.getDefiningOp())
+        worklist.push_back(def);
+    }
+  }
+  worklist.assign(needToClone.begin(), needToClone.end());
+  llvm::sort(worklist, [](Operation *lhs, Operation *rhs) {
+    return lhs->isBeforeInBlock(rhs);
+  });
+  BlockAndValueMapping mapping;
+  for (Operation *op : worklist) {
+    builder.clone(*op, mapping);
+    for (Value result : op->getResults()) {
+      if (!hasMeaningfulObjectIdentity(result.getType()))
         continue;
-      for (AttrOfClass attrOfClass : it->second) {
-        GlobalSlotOp globalSlot = globalSlotForAttr[attrOfClass];
-        OpBuilder::atBlockEnd(body).create<GlobalSlotSetOp>(
-            globalSlot.getLoc(), globalSlot.sym_name(), result);
+      if (!objectsWithIdentityAlreadyCopiedIntoInitializers.insert(result)
+               .second) {
+        return op->emitError()
+               << "potentially-aliased value used to initialize multiple slots";
       }
     }
   }
-
-  builder.create<ReturnOp>(loc);
+  builder.create<GlobalSlotInitOp>(globalSlot->getLoc(),
+                                   mapping.lookup(initialValue));
+  return success();
 }
 
 // Verify that a value conforms to the subset of allowed uses for
@@ -374,7 +395,7 @@ LogicalResult ObjectGraphGlobalizer::rewriteMethods() {
 
 void ObjectGraphGlobalizer::removeObjectGraph() {
   for (Operation &op : llvm::make_early_inc_range(*module.getBody())) {
-    if (isa<NnModuleOp, ClassTypeOp>(op)) {
+    if (!isa<GlobalSlotOp, FuncOp>(op)) {
       op.dropAllDefinedValueUses();
       op.erase();
     }
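The heart of populateGlobalSlotInitializer above is a clone-a-backward-slice
idiom. The following standalone sketch distills it, using the same MLIR APIs
as the patch; cloneBackwardSlice is a hypothetical helper name, and it assumes
(as the pass does) that the root value has a defining op and that the whole
slice lives in a single block:

  #include "mlir/IR/BlockAndValueMapping.h"
  #include "mlir/IR/Builders.h"
  #include "llvm/ADT/SmallPtrSet.h"
  #include "llvm/ADT/SmallVector.h"

  using namespace mlir;

  // Clone the transitive defining ops of `root` at the builder's insertion
  // point, preserving their original relative order, and return the clone
  // of `root` itself.
  static Value cloneBackwardSlice(OpBuilder &builder, Value root,
                                  BlockAndValueMapping &mapping) {
    // DFS over defining ops to collect everything `root` depends on.
    SmallPtrSet<Operation *, 8> slice;
    SmallVector<Operation *> worklist = {root.getDefiningOp()};
    while (!worklist.empty()) {
      Operation *op = worklist.pop_back_val();
      if (!slice.insert(op).second)
        continue; // Already visited.
      for (Value operand : op->getOperands())
        if (Operation *def = operand.getDefiningOp())
          worklist.push_back(def);
    }
    // Set iteration order is unspecified; re-establish the ops' original
    // order in their (shared) block so each clone dominates its uses.
    SmallVector<Operation *> ordered(slice.begin(), slice.end());
    llvm::sort(ordered, [](Operation *lhs, Operation *rhs) {
      return lhs->isBeforeInBlock(rhs);
    });
    // Clone in order; `mapping` redirects each clone's operands to the
    // already-cloned definitions rather than the originals.
    for (Operation *op : ordered)
      builder.clone(*op, mapping);
    return mapping.lookup(root);
  }

The same slice could likely be computed with MLIR's getBackwardSlice utility;
the hand-rolled loop makes it easy to interleave the per-result aliasing check
that the pass performs while cloning.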
diff --git a/test/Dialect/Torch/globalize-object-graph-error.mlir b/test/Dialect/Torch/globalize-object-graph-error.mlir
index 650c53ce4..3e73b78cc 100644
--- a/test/Dialect/Torch/globalize-object-graph-error.mlir
+++ b/test/Dialect/Torch/globalize-object-graph-error.mlir
@@ -47,3 +47,18 @@ torch.class_type @parent {
   torch.slot "m", %child : !torch.nn.Module<"child">
   torch.slot "m2", %child : !torch.nn.Module<"child">
 } : !torch.nn.Module<"parent">
+
+// -----
+
+torch.class_type @c {
+  torch.attr "a1" : !numpy.ndarray<*:!numpy.any_dtype>
+  torch.attr "a2" : !numpy.ndarray<*:!numpy.any_dtype>
+}
+
+%cst = constant dense<1.000000e+00> : tensor<1xf32>
+// expected-error @+1 {{potentially-aliased value used to initialize multiple slots}}
+%a = numpy.create_array_from_tensor %cst : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
+torch.nn_module {
+  torch.slot "a1", %a : !numpy.ndarray<*:!numpy.any_dtype>
+  torch.slot "a2", %a : !numpy.ndarray<*:!numpy.any_dtype>
+} : !torch.nn.Module<"c">
diff --git a/test/Dialect/Torch/globalize-object-graph-initializers.mlir b/test/Dialect/Torch/globalize-object-graph-initializers.mlir
new file mode 100644
index 000000000..5a97cd3c2
--- /dev/null
+++ b/test/Dialect/Torch/globalize-object-graph-initializers.mlir
@@ -0,0 +1,21 @@
+// RUN: npcomp-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
+
+// Check that multiple nested initialization ops are properly handled.
+
+// CHECK-LABEL: torch.global_slot @l : !basicpy.ListType {
+// CHECK:   %[[L0:.*]] = basicpy.build_list : () -> !basicpy.ListType
+// CHECK:   %[[L1:.*]] = basicpy.build_list %[[L0]], %[[L0]] : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
+// CHECK:   %[[L2:.*]] = basicpy.build_list %[[L1]], %[[L1]] : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
+// CHECK:   torch.global_slot.init %[[L2]] : !basicpy.ListType
+// CHECK: }
+
+torch.class_type @c {
+  torch.attr "l" : !basicpy.ListType
+}
+
+%l0 = basicpy.build_list : () -> !basicpy.ListType
+%l1 = basicpy.build_list %l0, %l0 : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
+%l2 = basicpy.build_list %l1, %l1 : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
+torch.nn_module {
+  torch.slot "l", %l2 : !basicpy.ListType
+} : !torch.nn.Module<"c">
diff --git a/test/Dialect/Torch/globalize-object-graph-submodules.mlir b/test/Dialect/Torch/globalize-object-graph-submodules.mlir
index 544640b4d..6b89c87e2 100644
--- a/test/Dialect/Torch/globalize-object-graph-submodules.mlir
+++ b/test/Dialect/Torch/globalize-object-graph-submodules.mlir
@@ -2,12 +2,11 @@
 
 // Check that linkage names consist of the dotted path from the root.
 
-// CHECK-LABEL: torch.global_slot @m.float : f64
+// CHECK-LABEL: torch.global_slot @m.float : f64 {
+// CHECK:   %[[INIT:.*]] = constant 4.200000e+01 : f64
+// CHECK:   torch.global_slot.init %[[INIT]] : f64
+// CHECK: }
 
-// CHECK-LABEL: func @__torch_global_slot_initializer() {
-// CHECK:   %[[C42:.*]] = constant 4.200000e+01 : f64
-// CHECK:   torch.global_slot.set @m.float = %[[C42]] : f64
-// CHECK:   return
 
 torch.class_type @child {
   torch.attr "float" : f64
diff --git a/test/Dialect/Torch/globalize-object-graph.mlir b/test/Dialect/Torch/globalize-object-graph.mlir
index 0e51d0141..c52f07091 100644
--- a/test/Dialect/Torch/globalize-object-graph.mlir
+++ b/test/Dialect/Torch/globalize-object-graph.mlir
@@ -2,22 +2,26 @@
 
 // Basic case.
 
-// CHECK-LABEL: torch.global_slot @b : !basicpy.BoolType
-// CHECK: torch.global_slot @i : i64
-// CHECK: torch.global_slot @f : f64
-// CHECK: torch.global_slot @a : !numpy.ndarray<*:!numpy.any_dtype>
+// CHECK-LABEL: torch.global_slot @b : !basicpy.BoolType {
+// CHECK:   %[[INIT:.*]] = basicpy.bool_constant true
+// CHECK:   torch.global_slot.init %[[INIT]] : !basicpy.BoolType
+// CHECK: }
 
-// CHECK-LABEL: func @__torch_global_slot_initializer() {
-// CHECK:   %[[CB:.*]] = basicpy.bool_constant true
-// CHECK:   torch.global_slot.set @b = %[[CB]] : !basicpy.BoolType
-// CHECK:   %[[CI:.*]] = basicpy.numeric_constant 3 : i64
-// CHECK:   torch.global_slot.set @i = %[[CI]] : i64
-// CHECK:   %[[CF:.*]] = basicpy.numeric_constant 4.250000e+01 : f64
-// CHECK:   torch.global_slot.set @f = %[[CF]] : f64
+// CHECK-LABEL: torch.global_slot @i : i64 {
+// CHECK:   %[[INIT:.*]] = basicpy.numeric_constant 3 : i64
+// CHECK:   torch.global_slot.init %[[INIT]] : i64
+// CHECK: }
+
+// CHECK-LABEL: torch.global_slot @f : f64 {
+// CHECK:   %[[INIT:.*]] = basicpy.numeric_constant 4.250000e+01 : f64
+// CHECK:   torch.global_slot.init %[[INIT]] : f64
+// CHECK: }
+
+// CHECK-LABEL: torch.global_slot @a : !numpy.ndarray<*:!numpy.any_dtype> {
 // CHECK:   %[[C:.*]] = constant dense<1.000000e+00> : tensor<1xf32>
-// CHECK:   %[[CA:.*]] = numpy.create_array_from_tensor %[[C]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
-// CHECK:   torch.global_slot.set @a = %[[CA]] : !numpy.ndarray<*:!numpy.any_dtype>
-// CHECK:   return
+// CHECK:   %[[A:.*]] = numpy.create_array_from_tensor %[[C]] : (tensor<1xf32>) -> !numpy.ndarray<*:!numpy.any_dtype>
+// CHECK:   torch.global_slot.init %[[A]] : !numpy.ndarray<*:!numpy.any_dtype>
+// CHECK: }
 
 torch.class_type @c {
   torch.attr "b" : !basicpy.BoolType
@@ -37,27 +41,3 @@ torch.nn_module {
   torch.slot "f", %f : f64
   torch.slot "a", %a : !numpy.ndarray<*:!numpy.any_dtype>
 } : !torch.nn.Module<"c">
-
-// -----
-
-// Same SSA value used as initializer for multiple slots.
-
-// CHECK-LABEL: torch.global_slot @b1 : !basicpy.BoolType
-// CHECK-LABEL: torch.global_slot @b2 : !basicpy.BoolType
-// CHECK-LABEL: func @__torch_global_slot_initializer() {
-// CHECK:   %[[TRUE:.*]] = basicpy.bool_constant true
-// CHECK:   torch.global_slot.set @b1 = %[[TRUE]] : !basicpy.BoolType
-// CHECK:   torch.global_slot.set @b2 = %[[TRUE]] : !basicpy.BoolType
-// CHECK:   return
-// CHECK: }
-
-torch.class_type @c {
-  torch.attr "b1" : !basicpy.BoolType
-  torch.attr "b2" : !basicpy.BoolType
-}
-
-%bool_true = basicpy.bool_constant true
-torch.nn_module {
-  torch.slot "b1", %bool_true : !basicpy.BoolType
-  torch.slot "b2", %bool_true : !basicpy.BoolType
-} : !torch.nn.Module<"c">