Rework how global slot initializers work.

Rather than a per-global-slot initializer region, we now have one for
the whole module. For example, it might look like this:

```
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list : !torch.list<tensor>
torch.global_slot.module_initializer {
  %0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
  %1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
  torch.initialize.global_slots [
    @tensor(%0 : !torch.tensor)
    @list(%1 : !torch.list<tensor>)
  ]
}
```

This new structure allows GlobalizeObjectGraph to create the initializer in a
much simpler way, avoiding the need to reason about whether different slots
alias each other. Reasoning about whether slots alias each other now is the
responsibility of InlineGlobalSlots, which has to do a much more complicated
analysis, implemented using MLIR's dataflow analysis framework.

Recommended review order:
- Check out the new IR constructs in the .mlir files of various passes
- Op definitions (*.td)
- Changes to GlobalizeObjectGraph pass.
- InlineGlobalSlots pass (~total rewrite)
- Misc changes:
  - Moving torchMlirAdjustStaticInformation for sharing with C++ code.
  - EraseModuleInitializer pass

To make this a bit nicer, it would be good to have a `torch.module` op
with an initializer region attached. That would be more invasive though.

This change has highlighted certain aspects of our project layering
which are worth calling out. None of our backends can handle global
slots, so we enforce that there are no global slots before backend
lowering. At an earlier stage in the project, we had aspirations of
transparently handling mutable global state and such, but for reasons
described below, that is no longer a goal. So really global slots should
be seen as a progressive lowering step as part of inlining all the
IValue's in the original program (GlobalizeObjectGraph is also one such
step).

Over time, with insights from work like IREE-JAX, it has become clear
that there isn't a reliable programming model we can compile for users
where we just transparently handle mutable global state (and some other
things, like lists and dictionaries). There is a need for an "outer
program" that orchestrates more restricted subroutines of the kind we
can handle in our compile flow here. The benefit of that is that it
decouples considerations like shapes, dtypes, etc. from the program
constructs used in the outer program. As long as the outer program can
efficiently invoke (pipelining/async/etc.) high-performance
data-parallel numerical subroutines of the kind we compile in our flow
here, then there is a complete programming model. This is also
consistent with the direction of upstream PyTorch which is becoming more
tracing-based (which inherently loses a lot of program structure, which
then has to be applied back with an "outer program" orchestrating the
traced subroutines).
pull/1187/head
Sean Silva 2022-07-13 18:45:56 +00:00
parent 34e207eeb5
commit 504de5e701
24 changed files with 1137 additions and 267 deletions

View File

@ -187,6 +187,15 @@ m_TorchTensorSizeInt(Value tensor, int64_t *dim) {
Value copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType,
Value tensor);
/// Adjusts the static information in the type of `value` to `desiredType`.
///
/// Returns null if such an adjustment is not possible.
///
/// If `userAllowsRefinement` is true, then the original value will be returned
/// if it is a subtype of `desiredType`.
Value adjustStaticInformation(OpBuilder &builder, Location loc, Value value,
Type desiredType, bool userAllowsRefinement);
/// Returns true if `list` is potentially mutated.
bool isListPotentiallyMutated(Value list);

View File

@ -228,8 +228,6 @@ def Torch_AttrOp : Torch_Op<"attr", [
def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
Symbol,
IsolatedFromAbove,
SingleBlockImplicitTerminator<"::mlir::torch::Torch::GlobalSlotInitOp">
]> {
let summary = "A slot with global storage";
let description = [{
@ -245,17 +243,66 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
TypeAttr:$typeBound
);
let results = (outs);
let assemblyFormat = [{
($sym_visibility^)? $sym_name attr-dict `:` $typeBound
}];
}
def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializer", [
IsolatedFromAbove,
SingleBlockImplicitTerminator<"::mlir::torch::Torch::InitializeGlobalSlotsOp">
]> {
let summary = "Module initializer for all `torch.global_slot` ops";
let description = [{
Initializer function that runs once at program startup to initialize
all `torch.global_slot` ops in the module.
The only ops that should be in the module initializer should be ops
generated by the IValue importer. This set avoids the need to define
the behavior in case of certain kinds of side effects in the initializer
(except for the side effect of updating the torch.global_slot ops with the
`torch.initialize.global_slots` op).
}];
let arguments = (ins);
let results = (outs);
let regions = (region SizedRegion<1>:$initializer);
let assemblyFormat = [{
($sym_visibility^)? $sym_name attr-dict `:` $typeBound ($initializer^)?
$initializer attr-dict
}];
let hasVerifier = 1;
}
def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [
Terminator,
HasParent<"::mlir::torch::Torch::GlobalSlotModuleInitializerOp">]> {
let summary = "Terminator for torch.global_slot.module_initializer region";
let description = [{
Atomically updates the value of all the global slots named in `slotSymNames`
with the corresponding values provided in `initialValues`.
}];
let arguments = (ins
SymbolRefArrayAttr:$slotSymNames,
Variadic<AnyTorchType>:$initialValues
);
let results = (outs);
// This builder creates an illegal op, but is needed to appease
// ensureTerminator in the default builders for SingleBlockImplicitTerminator
// on the parent op.
// TODO: Have a SingleBlockExplicitTerminator trait.
let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
Terminator,
HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> {
let summary = "yield-like terminator for torch.global_slot initializer region";
let summary = "yield-like terminator for torch.initialize.global_slotsr region";
let description = [{
The operand to this op becomes the initial value of the parent
torch.global_slot.

View File

@ -80,6 +80,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyConversionToValueSemanticsPass();
std::unique_ptr<OperationPass<ModuleOp>>
createEraseModuleInitializerPass();
StringRef getShapeLibrary();
} // namespace Torch

View File

@ -143,7 +143,7 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
Note: This pass inlines everything that is safe to inline. That is, it
doesn't have a cost model. This is likely to pessimize programs with
significant amounts of computation inside torch.global_slot initializer
significant amounts of computation inside torch.initialize.global_slotsr
regions (but this currently doesn't happen due to how TorchScript modules
are imported -- the contents are just constants).
}];
@ -265,4 +265,17 @@ def VerifyConversionToValueSemantics
}];
}
def EraseModuleInitializer
: Pass<"torch-erase-module-initializer", "ModuleOp"> {
let summary = "Erase the `torch.global_slot.module_initializer` op.";
let constructor =
"mlir::torch::Torch::createEraseModuleInitializerPass()";
let description = [{
Backends cannot currently handle module initializers, so we omit them from
our backend contract. This pass removes the
`torch.global_slot.module_initializer` op from the module if legal, or
raises an error.
}];
}
#endif // TORCHMLIR_TORCH_PASSES

View File

@ -31,45 +31,8 @@ MlirValue torchMlirAdjustStaticInformation(MlirBlock block_,
OpBuilder builder(unwrap(mlirTypeGetContext(desiredType_)));
builder.setInsertionPoint(block, insertBefore ? insertBefore->getIterator()
: block->end());
Value value = unwrap(value_);
Type type = value.getType();
Type desiredType = unwrap(desiredType_);
// If the value is already of the desired type, we're done.
if (type == desiredType)
return wrap(value);
// If the type is a tensor, then adjust the static information.
if ((type.isa<Torch::ValueTensorType>() &&
desiredType.isa<Torch::ValueTensorType>()) ||
(type.isa<Torch::NonValueTensorType>() &&
desiredType.isa<Torch::NonValueTensorType>())) {
Value adjusted = builder.create<Torch::TensorStaticInfoCastOp>(
value.getLoc(), desiredType, value);
return wrap(adjusted);
}
// If the type is a subtype of desiredType, then we need to derefine it to
// desiredType, unless the user allows refinement.
if (Torch::isValidSubtype(type, desiredType)) {
if (!userAllowsRefinement) {
Value adjusted =
builder.create<Torch::DerefineOp>(value.getLoc(), desiredType, value);
return wrap(adjusted);
} else {
return wrap(value);
}
}
// If the desiredType is subtype of type, then we assume that the desiredType
// is dynamically valid, so we do an unchecked cast.
if (Torch::isValidSubtype(desiredType, type)) {
Value adjusted = builder.create<Torch::PrimUncheckedCastOp>(
value.getLoc(), desiredType, value);
return wrap(adjusted);
}
// No known adjustment.
return {};
return wrap(Torch::adjustStaticInformation(
builder, value.getLoc(), value, desiredType, userAllowsRefinement));
}

View File

@ -28,6 +28,49 @@ using namespace mlir::torch::Torch;
// Utilities
//===----------------------------------------------------------------------===//
Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
Location loc, Value value,
Type desiredType,
bool userAllowsRefinement) {
Type type = value.getType();
// If the value is already of the desired type, we're done.
if (type == desiredType)
return value;
// If the type is a tensor, then adjust the static information.
if ((type.isa<ValueTensorType>() && desiredType.isa<ValueTensorType>()) ||
(type.isa<NonValueTensorType>() &&
desiredType.isa<NonValueTensorType>())) {
Value adjusted = builder.create<TensorStaticInfoCastOp>(value.getLoc(),
desiredType, value);
return adjusted;
}
// If the type is a subtype of desiredType, then we need to derefine it to
// desiredType, unless the user allows refinement.
if (isValidSubtype(type, desiredType)) {
if (!userAllowsRefinement) {
Value adjusted =
builder.create<DerefineOp>(value.getLoc(), desiredType, value);
return adjusted;
} else {
return value;
}
}
// If the desiredType is subtype of type, then we assume that the desiredType
// is dynamically valid, so we do an unchecked cast.
if (isValidSubtype(desiredType, type)) {
Value adjusted =
builder.create<PrimUncheckedCastOp>(value.getLoc(), desiredType, value);
return adjusted;
}
// No known adjustment.
return Value();
}
Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
BaseTensorType newType,
Value tensor) {
@ -1936,3 +1979,154 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
return emitOpError("expected number of shapes to match number of results");
return success();
}
//===----------------------------------------------------------------------===//
// GlobalSlotModuleInitializerOp
//===----------------------------------------------------------------------===//
LogicalResult GlobalSlotModuleInitializerOp::verify() {
// We centralize all verification of the global slots and the
// InitializeGlobalSlotsOp into here, since it requires processing the whole
// module.
// TODO: We should really have a `torch.module` and have this initializer be
// a region attached to it.
ModuleOp module = cast<ModuleOp>(getOperation()->getParentOp());
for (auto op : module.getOps<GlobalSlotModuleInitializerOp>()) {
if (op.getOperation() != getOperation())
return op.emitError("there must be only one global slot initializer");
}
// Collect the relevant symbol names we will verify.
DenseSet</*StringAttr*/ Attribute> knownGlobalSlots;
for (auto op : module.getOps<GlobalSlotOp>())
knownGlobalSlots.insert(op.sym_nameAttr());
DenseSet</*StringAttr*/ Attribute> initializedGlobalSlots;
auto initialize = cast<InitializeGlobalSlotsOp>(getBody()->getTerminator());
for (Attribute symName : initialize.slotSymNames()) {
auto wasInserted = initializedGlobalSlots
.insert(symName.cast<FlatSymbolRefAttr>().getAttr())
.second;
if (!wasInserted)
return initialize.emitError("duplicate initialization of global slot: ")
<< symName;
}
auto lessThanByStringValue = [](Attribute lhs, Attribute rhs) {
return lhs.cast<StringAttr>().getValue() <
rhs.cast<StringAttr>().getValue();
};
auto known = llvm::to_vector(knownGlobalSlots);
llvm::sort(known, lessThanByStringValue);
auto initialized = llvm::to_vector(initializedGlobalSlots);
llvm::sort(initialized, lessThanByStringValue);
// Check that the global slots in the module are all initialized.
SymbolTable symbolTable(module);
if (initializedGlobalSlots != knownGlobalSlots) {
InFlightDiagnostic diag = initialize.emitOpError(
"must have one initializer for each global slot in the module");
for (auto knownGlobalSlot : known) {
auto symName = FlatSymbolRefAttr::get(knownGlobalSlot.cast<StringAttr>());
if (!initializedGlobalSlots.count(knownGlobalSlot)) {
diag.attachNote(
symbolTable.lookup<GlobalSlotOp>(symName.getAttr()).getLoc())
.append("missing global slot initializer for ", symName);
}
}
for (auto initializedGlobalSlot : initialized) {
if (!knownGlobalSlots.count(initializedGlobalSlot)) {
diag.attachNote().append(
"unexpected global slot initializer for non-existent global slot ",
FlatSymbolRefAttr::get(initializedGlobalSlot.cast<StringAttr>()));
}
}
return diag;
}
// Check that initial values satisfy type bounds.
for (int i = 0, e = initialize.getNumOperands(); i < e; ++i) {
auto symName = initialize.slotSymNames()[i].cast<FlatSymbolRefAttr>();
auto initialValue = initialize.getOperand(i);
auto globalSlotOp = symbolTable.lookup<GlobalSlotOp>(symName.getValue());
if (!isValidSubtype(initialValue.getType(), globalSlotOp.typeBound())) {
return initialize.emitOpError().append(
"initial value for global slot ", symName, " has type ",
initialValue.getType(), " which is not within the bound ",
globalSlotOp.typeBound());
}
}
auto walkResult = getOperation()->walk([](Operation *op) {
// We only permit a small set of ops in the module initializer.
// These ops are essentially those which can be produced by the IValue
// importer.
if (isa<GlobalSlotModuleInitializerOp, InitializeGlobalSlotsOp,
PrimListConstructOp, PrimDictConstructOp, PrimTupleConstructOp,
ConstantBoolOp, ConstantStrOp, ConstantIntOp, ConstantFloatOp,
ConstantNoneOp, NonValueTensorLiteralOp, PerTensorAffineCreateOp,
LinearParamsCreateOp>(op))
return WalkResult::advance();
op->emitOpError() << "is not allowed in a module initializer";
return WalkResult::interrupt();
});
if (walkResult.wasInterrupted())
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// InitializeGlobalSlotsOp
//===----------------------------------------------------------------------===//
ParseResult InitializeGlobalSlotsOp::parse(OpAsmParser &parser,
OperationState &result) {
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
if (parser.parseLSquare())
return failure();
SmallVector<Attribute> slotSymNames;
while (!succeeded(parser.parseOptionalRSquare())) {
NamedAttrList dummy;
StringAttr slotSymName;
if (parser.parseSymbolName(slotSymName, "dummy", dummy))
return failure();
slotSymNames.push_back(FlatSymbolRefAttr::get(slotSymName));
if (parser.parseLParen())
return failure();
OpAsmParser::UnresolvedOperand initialValue;
if (parser.parseOperand(initialValue))
return failure();
Type initialValueType;
if (parser.parseColonType(initialValueType))
return failure();
if (parser.parseRParen())
return failure();
if (parser.resolveOperand(initialValue, initialValueType, result.operands))
return failure();
}
result.addAttribute("slotSymNames",
ArrayAttr::get(parser.getContext(), slotSymNames));
return success();
}
void InitializeGlobalSlotsOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict(getOperation()->getAttrs(),
/*elidedAttrs=*/{"slotSymNames"});
p << " [";
p.printNewline();
for (int i = 0, e = getNumOperands(); i < e; ++i) {
p << " " << slotSymNames()[i] << "(" << initialValues()[i] << " : "
<< initialValues()[i].getType() << ")";
p.printNewline();
}
p << "]";
}
LogicalResult InitializeGlobalSlotsOp::verify() {
if (initialValues().size() != slotSymNames().size())
return emitOpError("expected number of operands to match number of slots");
return success();
}

View File

@ -2,6 +2,7 @@ add_mlir_library(TorchMLIRTorchPasses
AdjustCallingConventions.cpp
DecomposeComplexOps.cpp
DropShapeCalculations.cpp
EraseModuleInitializer.cpp
Passes.cpp
GlobalizeObjectGraph.cpp
InlineGlobalSlots.cpp

View File

@ -0,0 +1,50 @@
//===- EraseModuleInitializer.cpp --------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
class EraseModuleInitializerPass
: public EraseModuleInitializerBase<EraseModuleInitializerPass> {
void runOnOperation() override {
auto walkResult = getOperation().walk([](GlobalSlotModuleInitializerOp op) {
auto intialize =
cast<InitializeGlobalSlotsOp>(op.getBody()->getTerminator());
if (intialize.getNumOperands() != 0) {
op.emitError("could not erase non-empty module initializer");
return WalkResult::interrupt();
}
op.erase();
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createEraseModuleInitializerPass() {
return std::make_unique<EraseModuleInitializerPass>();
}

View File

@ -48,12 +48,6 @@ static FailureOr<NnModuleOp> findRootNnModule(ModuleOp module) {
return rootNnModule;
}
static bool hasMeaningfulObjectIdentity(Type type) {
return !type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
Torch::StringType, Torch::NoneType,
Torch::ValueTensorType>();
}
//===----------------------------------------------------------------------===//
// Object graph recursive traversal.
//===----------------------------------------------------------------------===//
@ -100,6 +94,9 @@ public:
assert(it != slotToGlobalSlot.end() && "didn't create global slot");
return it->second;
}
llvm::MapVector<StringAttr, Value> &getGlobalSlotInitialValues() {
return globalSlotInitialValues;
}
private:
LogicalResult collectUsedSlots() {
@ -187,8 +184,7 @@ private:
assert(slotToGlobalSlot.find(slot) == slotToGlobalSlot.end());
slotToGlobalSlot[slot] = globalSlot;
slotLinkageInfo[slot] = LinkageInfo{linkageName, attr.isPrivate()};
if (failed(populateGlobalSlotInitializer(globalSlot, slot)))
return failure();
globalSlotInitialValues[globalSlot.sym_nameAttr()] = slot.value();
}
nameStack.pop_back();
}
@ -201,44 +197,6 @@ private:
}
return success();
}
LogicalResult populateGlobalSlotInitializer(GlobalSlotOp globalSlot,
SlotOp slot) {
OpBuilder builder(globalSlot.getContext());
builder.createBlock(&globalSlot.getRegion());
SmallPtrSet<Operation *, 6> needToClone;
Value initialValue = slot.value();
SmallVector<Operation *> worklist = {initialValue.getDefiningOp()};
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
if (!needToClone.insert(op).second)
continue;
for (Value operand : op->getOperands()) {
if (auto def = operand.getDefiningOp())
worklist.push_back(def);
}
}
worklist.assign(needToClone.begin(), needToClone.end());
llvm::sort(worklist, [](Operation *lhs, Operation *rhs) {
return lhs->isBeforeInBlock(rhs);
});
BlockAndValueMapping mapping;
for (Operation *op : worklist) {
builder.clone(*op, mapping);
for (Value result : op->getResults()) {
if (!hasMeaningfulObjectIdentity(result.getType()))
continue;
if (!objectsWithIdentityAlreadyCopiedIntoInitializers.insert(result)
.second) {
return op->emitError() << "potentially-aliased value used to "
"initialize multiple slots";
}
}
}
builder.create<GlobalSlotInitOp>(globalSlot->getLoc(),
mapping.lookup(initialValue));
return success();
}
// Builder for creating GlobalSlotOp's in the module.
OpBuilder globalSlotBuilder;
// Symbol table for the module.
@ -262,16 +220,50 @@ private:
DenseMap<std::pair<NnModuleOp, func::FuncOp>, LinkageInfo> funcLinkageInfo;
// The corresponding GlobalSlotOp for each SlotOp in the program.
DenseMap<SlotOp, GlobalSlotOp> slotToGlobalSlot;
// A set of values that we have copied into torch.global_slot initializers,
// which cannot be used in multiple initializers because their object
// identity is important.
DenseSet<Value> objectsWithIdentityAlreadyCopiedIntoInitializers;
// The initializing value for each GlobalSlotOp.
// This is a MapVector to keep the order deterministic.
llvm::MapVector<StringAttr, Value> globalSlotInitialValues;
// Used to keep track of all the used torch slots so that the restrictions can
// be applied to those slots only.
DenseSet<SlotOp> usedSlots;
};
} // namespace
LogicalResult
createGlobalSlotModuleInitializer(ModuleOp module, SymbolTable &symbolTable,
ObjectGraphInfo &objectGraphInfo) {
auto builder = OpBuilder::atBlockBegin(module.getBody());
auto moduleInitializer =
builder.create<GlobalSlotModuleInitializerOp>(module.getLoc());
Block *body = builder.createBlock(&moduleInitializer.initializer());
builder.setInsertionPointToEnd(body);
SmallVector<Operation *> opsToMove;
for (Operation &op : *module.getBody()) {
if (isa<ClassTypeOp, NnModuleOp, GlobalSlotOp, func::FuncOp,
GlobalSlotModuleInitializerOp>(op))
continue;
opsToMove.push_back(&op);
}
BlockAndValueMapping mapping;
for (Operation *op : opsToMove) {
// The ops are used by `torch.slot` ops in the enclosing module.
// Cloning avoids needing to handle those uses specially.
builder.clone(*op, mapping);
}
SmallVector<Attribute> slotSymNames;
SmallVector<Value> initialValues;
for (auto &kv : objectGraphInfo.getGlobalSlotInitialValues()) {
StringAttr symName = kv.first;
Value initializer = kv.second;
slotSymNames.push_back(FlatSymbolRefAttr::get(symName));
initialValues.push_back(mapping.lookup(initializer));
}
builder.create<InitializeGlobalSlotsOp>(
moduleInitializer.getLoc(),
ArrayAttr::get(module.getContext(), slotSymNames), initialValues);
return success();
}
//===----------------------------------------------------------------------===//
// Monomorphization.
//===----------------------------------------------------------------------===//
@ -596,7 +588,13 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
instances[classType].push_back(nnModule);
}
// Step 2: Verify all functions are suitable to be analyzed by our later code.
// Step 2: Create the torch.global_slot.module_initializer op.
if (failed(createGlobalSlotModuleInitializer(module, symbolTable,
objectGraphInfo)))
return failure();
// Step 3: Verify all functions are suitable to be analyzed by our later code.
// This eliminates special handling / error code later.
//
// This is important, because in principle, we can perform arbitrarily complex
@ -608,7 +606,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
return failure();
}
// Step 3: Calculate the set of monomorphized functions that need to be
// Step 4: Calculate the set of monomorphized functions that need to be
// created. For each call that passes !torch.nn.Module to a function, we need
// to create a specialized version of that function just for that instance (or
// combination of instances in the case of multiple arguments).
@ -633,7 +631,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
return failure();
}
// Step 4: Clone/rewrite functions to implement the necessary
// Step 5: Clone/rewrite functions to implement the necessary
// monomorphizations.
DenseMap<Monomorphization, func::FuncOp> newFuncs;
int uniquifier = 0;
@ -672,13 +670,13 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
return failure();
}
// Step 5: Clean up object graph.
// Step 6: Clean up object graph.
DenseSet<func::FuncOp> liveFuncs;
for (auto &kv : newFuncs) {
liveFuncs.insert(kv.second);
}
for (auto &op : llvm::make_early_inc_range(module.getOps())) {
if (isa<GlobalSlotOp>(&op))
if (isa<GlobalSlotOp, GlobalSlotModuleInitializerOp>(&op))
continue;
if (auto func = dyn_cast<func::FuncOp>(op)) {
if (liveFuncs.contains(func))

View File

@ -6,83 +6,429 @@
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
//
// This file implements an optimistic dataflow analysis that proves that values
// used in global slot initializers are "safe" (see definition below). This
// analysis allows us to inline global slot initializers.
//
// One thing to note is that this inlining (as with all inlining) can create
// duplicate ops. That is usually not a problem, except for certain large
// tensor literals. We rely on later CSE passes to deduplicate those literals.
//
// For debugging this pass an effort has been made for
// `-debug-only=dataflow` and `-debug-only=torch-inline-global-slots` to give a
// good experience. When debugging this pass, it is recommended to start with
// `-debug-only=torch-inline-global-slots` to find values that are marked
// unsafe unexpectedly and then `-debug-only=dataflow` to find why.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "torch-inline-global-slots"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
/// A program point representing a symbol.
///
/// In principle we could use the `Operation *` program point of the Symbol op,
/// but that just adds a layer of indirection through a symbol table for the
/// purpose of this analysis.
///
/// This is easier because we only support FlatSymbolRefAttr's in Torch-MLIR in
/// a single module. If we had to support complex nested symbol references, we
/// would probably want to go through the effort to indirect through the symbol
/// tables to make things clearer.
class FlatSymbolRefProgramPoint
: public GenericProgramPointBase<FlatSymbolRefProgramPoint,
FlatSymbolRefAttr> {
public:
using Base::Base;
void print(raw_ostream &os) const override {
os << "FlatSymbolRefProgramPoint(" << getValue() << ")";
}
Location getLoc() const override {
return UnknownLoc::get(getValue().getContext());
}
};
static bool isTypeTriviallySafe(Type type) {
return type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
Torch::StringType, Torch::NoneType, Torch::ValueTensorType>();
}
static bool isUseTreatedWithValueSemantics(OpOperand &use) {
Operation *op = use.getOwner();
// If the op unconditionally has value semantics, then the use has value
// semantics.
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>())
return true;
// The condition of the torch.prim.if op is treated with value semantics.
if (isa<PrimIfOp>(op) && use.getOperandNumber() == 0)
return true;
// TODO: Generalize the HasValueSemantics trait to support
// operand/result-granularity.
return false;
}
/// State tracking if an IR construct is "safe".
///
/// This state is tracked on Value's and also on global slots (via a
/// FlatSymbolRefProgramPoint).
///
/// In this context, "safe" means that the object is safe to inline.
/// This covers a few concepts
/// - the value cannot be mutated by the program
/// - the value cannot be potentially aliased, with that alias itself being
/// unsafe
class InlineGlobalSlotsAnalysisState : public AnalysisState {
public:
InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {}
bool isUninitialized() const override {
// We are an optimistic analysis, so we are always default initialized to
// the optimistic "assumed safe" state.
return false;
}
ChangeResult defaultInitialize() override {
// We are an optimistic analysis, so the default state is always "safe".
return setSafe();
}
void print(raw_ostream &os) const override {
os << "InlineGlobalSlotsAnalysisState(" << (isSafe ? "safe" : "unsafe")
<< ")";
}
/// Helper for setting the state with the correct ChangeResult.
ChangeResult setSafe(bool newIsSafe = true) {
// As an optimistic analysis, once we prove that a value is unsafe, nothing
// can prove that it is safe again. This is the monotonicity property of
// the dataflow analysis that guarantees that we reach a fixed-point.
// If that property doesn't hold, then there is a bug in the analysis.
assert(!(isSafe == false && newIsSafe == true) && "non-monotonic update");
if (isSafe == newIsSafe)
return ChangeResult::NoChange;
isSafe = newIsSafe;
return ChangeResult::Change;
}
/// Helper for updatating the state with the correct ChangeResult based on the
/// safety of a use.
ChangeResult
incorporateSafetyOfUse(const InlineGlobalSlotsAnalysisState *useState) {
// The use is safe, so no need to change anything.
if (useState->isSafe)
return ChangeResult::NoChange;
return setSafe(false);
}
/// This is an optimistic analysis. We start assuming everything is safe.
bool isSafe = true;
};
class InlineGlobalSlotsAnalysis : public DataFlowAnalysis {
public:
InlineGlobalSlotsAnalysis(DataFlowSolver &solver);
LogicalResult initialize(Operation *top) override;
LogicalResult visit(ProgramPoint point) override;
private:
/// The local transfer function determining the safety of `value`.
bool isValueSafeTransferFunction(Value value);
/// The InitializeGlobalSlotsOp of the current module we are analyzing.
///
/// This is used to propagate the analysis from globals into to the module
/// initializer.
InitializeGlobalSlotsOp initializeGlobalSlotsOp;
};
InlineGlobalSlotsAnalysis::InlineGlobalSlotsAnalysis(DataFlowSolver &solver)
: DataFlowAnalysis(solver) {
registerPointKind<FlatSymbolRefProgramPoint>();
}
LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
auto walkResult = top->walk([this](Operation *op) {
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
getProgramPoint<FlatSymbolRefProgramPoint>(
FlatSymbolRefAttr::get(globalSlot.sym_nameAttr())));
propagateIfChanged(state,
state->setSafe(globalSlot.getVisibility() !=
SymbolTable::Visibility::Public));
}
if (auto globalSlotSet = dyn_cast<Torch::GlobalSlotSetOp>(op)) {
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
getProgramPoint<FlatSymbolRefProgramPoint>(globalSlotSet.slotAttr()));
propagateIfChanged(state, state->setSafe(false));
}
// Save the InitializeGlobalSlotsOp for later referencee
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
initializeGlobalSlotsOp = initialize;
}
for (Value result : op->getResults()) {
if (failed(visit(result)))
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
return failure();
return success();
}
LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
if (Value value = point.dyn_cast<Value>()) {
bool isSafe = isValueSafeTransferFunction(value);
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
propagateIfChanged(state, state->setSafe(isSafe));
// Handle GlobalSlotGetOp's.
if (auto opResult = value.dyn_cast<OpResult>()) {
if (auto globalSlotGet =
dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
auto *flatSymbolRefPoint = getProgramPoint<FlatSymbolRefProgramPoint>(
globalSlotGet.slotAttr());
auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
flatSymbolRefPoint, globalSlotGet.result());
auto *globalState =
getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
propagateIfChanged(globalState,
globalState->incorporateSafetyOfUse(valueState));
}
}
return success();
}
if (auto *genericProgramPoint = point.dyn_cast<GenericProgramPoint *>()) {
if (auto *flatSymbolRefPoint =
dyn_cast<FlatSymbolRefProgramPoint>(genericProgramPoint)) {
if (initializeGlobalSlotsOp) {
auto it =
llvm::find(initializeGlobalSlotsOp.slotSymNames(),
static_cast<Attribute>(flatSymbolRefPoint->getValue()));
Value value = initializeGlobalSlotsOp->getOperand(
std::distance(initializeGlobalSlotsOp.slotSymNames().begin(), it));
auto *flatSymbolRefState =
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value,
flatSymbolRefPoint);
auto *valueState = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
propagateIfChanged(valueState,
valueState->setSafe(flatSymbolRefState->isSafe));
}
return success();
}
}
LLVM_DEBUG(
{ llvm::dbgs() << "visit failing because of: " << point << "\n"; });
return failure();
}
// This is only a member function to access protected get* functions.
bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
if (isTypeTriviallySafe(value.getType()))
return true;
for (OpOperand &use : value.getUses()) {
Operation *op = use.getOwner();
if (isUseTreatedWithValueSemantics(use))
continue;
// If the op is read-only and all results are safe, then this value is
// safe. This covers, for example, view-like ops that create aliases.
if ((op->hasTrait<Torch::OpTrait::ReadOnly>() ||
MemoryEffectOpInterface::hasNoEffect(op)) &&
llvm::all_of(op->getResults(), [&](Value result) {
auto *state =
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value, result);
return state->isSafe;
}))
continue;
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
auto symName = initialize.slotSymNames()[use.getOperandNumber()]
.cast<FlatSymbolRefAttr>();
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
value, getProgramPoint<FlatSymbolRefProgramPoint>(symName));
if (state->isSafe)
continue;
}
// We may not create all the dependency edges, but that is ok since at
// this point we have already reached the fixed-point.
return false;
}
return true;
}
SmallVector<Operation *> getBackwardSliceIncludingRoot(Value initialValue) {
SetVector<Operation *> sliceSet;
getBackwardSlice(initialValue, &sliceSet);
SmallVector<Operation *> slice;
llvm::append_range(slice, sliceSet);
slice.push_back(initialValue.getDefiningOp());
return slice;
}
static bool isInitialValueTransitivelySafeToInline(Value initialValue,
DataFlowSolver &solver) {
SmallVector<Operation *> slice = getBackwardSliceIncludingRoot(initialValue);
for (Operation *op : slice) {
for (auto result : op->getResults()) {
auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>(result);
if (!state->isSafe) {
return false;
}
}
}
return true;
}
namespace {
class InlineGlobalSlotsPass
: public InlineGlobalSlotsBase<InlineGlobalSlotsPass> {
void runOnOperation() override {
ModuleOp module = getOperation();
SymbolTable symbolTable(module);
auto uses = SymbolTable::getSymbolUses(&module.getBodyRegion());
if (!uses) {
module.emitError() << "cannot analyze symbol uses";
DataFlowSolver solver;
solver.load<InlineGlobalSlotsAnalysis>();
if (failed(solver.initializeAndRun(module)))
return signalPassFailure();
LLVM_DEBUG({
module->walk([&](Operation *op) {
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>(
solver.getProgramPoint<FlatSymbolRefProgramPoint>(
FlatSymbolRefAttr::get(globalSlot.sym_nameAttr())));
state->print(llvm::dbgs());
llvm::dbgs() << ": "
<< FlatSymbolRefAttr::get(globalSlot.sym_nameAttr())
<< "\n";
return;
}
if (op->getNumResults() != 1)
return;
auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>(
op->getResult(0));
state->print(llvm::dbgs());
llvm::dbgs() << ": ";
op->dump();
});
});
Torch::InitializeGlobalSlotsOp initialize;
// TODO: Have a torch.module with an optional initializer region to make
// this tighter.
for (auto moduleInitializer :
module.getOps<Torch::GlobalSlotModuleInitializerOp>()) {
initialize = cast<Torch::InitializeGlobalSlotsOp>(
moduleInitializer.getBody()->getTerminator());
}
// Find all the global slots potentially written from within the module.
// (we handle the case of non-private symbols later).
DenseSet<Torch::GlobalSlotOp> potentiallyWrittenGlobalSlots;
for (const SymbolTable::SymbolUse &use : *uses) {
auto flatSymbolRef = use.getSymbolRef().dyn_cast<FlatSymbolRefAttr>();
if (!flatSymbolRef) {
use.getUser()->emitError() << "unimplemented: nested SymbolRef's";
return signalPassFailure();
if (!initialize) {
return;
}
DenseSet</*FlatSymbolRefAttr*/ Attribute> safeToInline;
for (int i = 0, e = initialize->getNumOperands(); i != e; i++) {
auto slotSymName = initialize.slotSymNames()[i].cast<FlatSymbolRefAttr>();
Value operand = initialize.getOperand(i);
auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>(
initialize.slotSymNames()[i].cast<FlatSymbolRefAttr>());
auto *state =
solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint);
// We roll the analysis of whether a slot is set or public into the
// main dataflow analysis, so we need to check the slot's
// FlatSymbolRefProgramPoint itself to see if it is safe to inline.
// For example, a public !torch.int is not safe to inline, even though
// it is a value-semantic type and so the actual initializer value
// itself is conceptually safe to inline.
if (!state->isSafe) {
continue;
}
// Check to see if the initializing value is safe to inline.
// This requires a transitive check of all subobjects.
// TODO: This would really be more logical to do as a forward dataflow
// analyis on the whole module initializer rather than doing the
// transitive check backward for each initial value. But it is just
// too much boilerplate to write that with the dataflow framework and we
// generally don't expect long transitive chains of values here -- most
// initial values are just single tensor literals.
if (isInitialValueTransitivelySafeToInline(operand, solver)) {
safeToInline.insert(slotSymName);
}
auto globalSlot =
symbolTable.lookup<Torch::GlobalSlotOp>(flatSymbolRef.getValue());
if (!globalSlot)
continue;
if (isa<Torch::GlobalSlotGetOp>(use.getUser()))
continue;
potentiallyWrittenGlobalSlots.insert(globalSlot);
}
SymbolTable symbolTable(module);
DenseSet<Operation *> toErase;
// Inline all the global slots that are not potentially written.
for (const SymbolTable::SymbolUse &use : *uses) {
auto flatSymbolRef = use.getSymbolRef().cast<FlatSymbolRefAttr>();
auto globalSlot =
symbolTable.lookup<Torch::GlobalSlotOp>(flatSymbolRef.getValue());
if (!globalSlot)
continue;
// And external user might write to the global slot.
if (!globalSlot.isPrivate())
continue;
// An internal user exists which might write to the global slot.
if (potentiallyWrittenGlobalSlots.contains(globalSlot))
continue;
auto globalSlotGet = cast<Torch::GlobalSlotGetOp>(use.getUser());
OpBuilder builder(globalSlotGet);
BlockAndValueMapping mapper;
for (Operation &op : globalSlot.getBody()->without_terminator())
builder.clone(op, mapper);
Value cloned = mapper.lookup(
cast<GlobalSlotInitOp>(globalSlot.getBody()->getTerminator())
.getOperand());
globalSlotGet.replaceAllUsesWith(cloned);
toErase.insert(globalSlotGet);
toErase.insert(globalSlot);
}
module.walk([&](Torch::GlobalSlotGetOp op) {
if (!safeToInline.count(op.slotAttr()))
return;
// TODO: Make this more ergonomic.
auto it = llvm::find(initialize.slotSymNames(), op.slotAttr());
Value initialValue = initialize.getOperand(
std::distance(initialize.slotSymNames().begin(), it));
// It seems inefficient to get a backward slice again here, but we are
// going to be cloning the whole slice anyway, so it doesn't seem like a
// big deal.
SmallVector<Operation *> slice =
getBackwardSliceIncludingRoot(initialValue);
BlockAndValueMapping mapping;
OpBuilder builder(op);
for (Operation *opInSlice : slice)
builder.clone(*opInSlice, mapping);
auto inlinedInitialValue = mapping.lookup(initialValue);
inlinedInitialValue = Torch::adjustStaticInformation(
builder, op.getLoc(), inlinedInitialValue, op.getType(),
/*userAllowsRefinement=*/false);
op.replaceAllUsesWith(inlinedInitialValue);
toErase.insert(op);
});
// Clean up after the transform.
// Erase any pending ops.
for (Operation *op : toErase)
op->erase();
// Erase any global slots that we inlined.
// This could be left to SymbolDCE but it's not hard to do here.
for (FlatSymbolRefAttr symName :
llvm::map_range(safeToInline, [](Attribute attr) {
return attr.cast<FlatSymbolRefAttr>();
})) {
auto globalSlot =
symbolTable.lookup<Torch::GlobalSlotOp>(symName.getValue());
globalSlot.erase();
}
// Update the initializer.
SmallVector<Attribute> newSlotSymNames;
SmallVector<Value> newInitialValues;
for (int i = 0, e = initialize.getNumOperands(); i != e; i++) {
auto slotSymName = initialize.slotSymNames()[i].cast<FlatSymbolRefAttr>();
if (!safeToInline.count(slotSymName)) {
newSlotSymNames.push_back(slotSymName);
newInitialValues.push_back(initialize.getOperand(i));
}
}
{
OpBuilder builder(initialize);
builder.create<Torch::InitializeGlobalSlotsOp>(
initialize.getLoc(),
ArrayAttr::get(module.getContext(), newSlotSymNames),
newInitialValues);
}
initialize.erase();
}
};
} // namespace

View File

@ -91,6 +91,9 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
// Incorporate user annotations and remove signature Python-isms.
pm.addPass(createAdjustCallingConventionsPass());
// TODO: Remove options.optimize and this OPT-ONLY stuff -- we are already way
// past the point of no return for it being necessary for functional
// correctness.
if (options.optimize) {
// Eliminate the PrimTupleIndexOp generated from the
// adjustCallingConventions
@ -102,6 +105,22 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
// Also don't rely on this pass to expose constants into the program to
// simplify handling of "optional".
pm.addPass(createInlineGlobalSlotsPass());
// After doing a first round of inlining global slots, canonicalize again to
// take advantage of optimization opportunities exposed by the inlined
// global slots. In particular, this is functionally necessary now because
// large amounts of control flow are guarded by an "is training" flag, so
// inlining removes certain mutating operations done on the slots enabling
// them to be deleted.
// TODO: In full generality, we need to do a fixed-point iteration of
// shape inference, maximizing value semantics, decomposition, inling global
// slots, and canonicalization.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Inline again, cleaning up any remaining global slots that might be dead
// now.
pm.addPass(createInlineGlobalSlotsPass());
// Erase the module initializers (or fail compilation), since they aren't
// permitted in our backend contract at the moment.
pm.addPass(Torch::createEraseModuleInitializerPass());
}
// Reduce variants of ops to a smaller set of primitives.

View File

@ -42,6 +42,14 @@ class VerifyInvariantsBeforeBackendLoweringPass
: public VerifyInvariantsBeforeBackendLoweringBase<
VerifyInvariantsBeforeBackendLoweringPass> {
void runOnOperation() override {
if (getOperation()
.walk([](Torch::GlobalSlotModuleInitializerOp op) {
op.emitError()
<< "unsupported by backend lowering: module initializers";
return WalkResult::interrupt();
})
.wasInterrupted())
return signalPassFailure();
auto walkResult = getOperation().walk([&](Block *block) {
// Check invariants on all the Value's in the program.
// That is, check all BlockArgument's and OpResult's.

View File

@ -2,26 +2,22 @@
// Basic case.
// CHECK-LABEL: torch.global_slot @b : !torch.bool {
// CHECK: %[[INIT:.*]] = torch.constant.bool true
// CHECK: torch.global_slot.init %[[INIT]] : !torch.bool
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[FLOAT4:.*]] = torch.constant.float 4.250000e+01
// CHECK: %[[TENSOR:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
// CHECK: torch.initialize.global_slots [
// CHECK: @b(%[[TRUE]] : !torch.bool)
// CHECK: @i(%[[INT3]] : !torch.int)
// CHECK: @f(%[[FLOAT4]] : !torch.float)
// CHECK: @t(%[[TENSOR]] : !torch.tensor)
// CHECK: ]
// CHECK: }
// CHECK-LABEL: torch.global_slot @i : !torch.int {
// CHECK: %[[INIT:.*]] = torch.constant.int 3
// CHECK: torch.global_slot.init %[[INIT]] : !torch.int
// CHECK: }
// CHECK-LABEL: torch.global_slot @f : !torch.float {
// CHECK: %[[INIT:.*]] = torch.constant.float 4.250000e+01
// CHECK: torch.global_slot.init %[[INIT]] : !torch.float
// CHECK: }
// CHECK-LABEL: torch.global_slot @t : !torch.tensor {
// CHECK: %[[T:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
// CHECK: torch.global_slot.init %[[T]] : !torch.tensor
// CHECK: }
// CHECK-LABEL: torch.global_slot @b : !torch.bool
// CHECK-LABEL: torch.global_slot @i : !torch.int
// CHECK-LABEL: torch.global_slot @f : !torch.float
// CHECK-LABEL: torch.global_slot @t : !torch.tensor
torch.class_type @c {
torch.attr "b" : !torch.bool
torch.attr "i" : !torch.int

View File

@ -35,43 +35,3 @@ func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"parent">,
%2 = torch.prim.GetAttr %arg1["float"] : !torch.nn.Module<"child"> -> !torch.float
return
}
// -----
torch.class_type @c {
torch.attr "t1" : !torch.tensor
torch.attr "t2" : !torch.tensor
}
// expected-error @+1 {{potentially-aliased value used to initialize multiple slots}}
%t = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
torch.nn_module {
torch.slot "t1", %t : !torch.tensor
torch.slot "t2", %t : !torch.tensor
} : !torch.nn.Module<"c">
func.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor {
%t1 = torch.prim.GetAttr %arg0["t1"] : !torch.nn.Module<"c"> -> !torch.tensor
%t2 = torch.prim.GetAttr %arg0["t2"] : !torch.nn.Module<"c"> -> !torch.tensor
%cst = torch.constant.int 1
%ret = torch.aten.add.Tensor %t1, %t2, %cst : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
// -----
torch.class_type @c {
torch.attr "t1" : !torch.tensor
torch.attr "t2" : !torch.tensor
}
// expected-error @+1 {{potentially-aliased value used to initialize multiple slots}}
%t = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
torch.nn_module {
torch.slot "t1", %t : !torch.tensor
torch.slot "t2", %t : !torch.tensor
} : !torch.nn.Module<"c">
func.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) {
torch.prim.SetAttr %arg0["t1"] = %arg1: !torch.nn.Module<"c">, !torch.tensor
torch.prim.SetAttr %arg0["t2"] = %arg1: !torch.nn.Module<"c">, !torch.tensor
return
}

View File

@ -1,13 +1,16 @@
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
// CHECK that multiple nested initialization ops are properly handled.
// Check that multiple nested initialization ops are properly handled.
// CHECK-LABEL: torch.global_slot @l : !torch.list<list<list<tensor>>> {
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: %[[L0:.*]] = torch.prim.ListConstruct : () -> !torch.list<tensor>
// CHECK: %[[L1:.*]] = torch.prim.ListConstruct %[[L0]], %[[L0]] : (!torch.list<tensor>, !torch.list<tensor>) -> !torch.list<list<tensor>>
// CHECK: %[[L2:.*]] = torch.prim.ListConstruct %[[L1]], %[[L1]] : (!torch.list<list<tensor>>, !torch.list<list<tensor>>) -> !torch.list<list<list<tensor>>>
// CHECK: torch.global_slot.init %[[L2]] : !torch.list<list<list<tensor>>>
// CHECK: torch.initialize.global_slots [
// CHECK: @l(%[[L2]] : !torch.list<list<list<tensor>>>)
// CHECK: ]
// CHECK: }
// CHECK-LABEL: torch.global_slot @l : !torch.list<list<list<tensor>>>
torch.class_type @c {
torch.attr "l" : !torch.list<list<list<tensor>>>

View File

@ -12,20 +12,22 @@ torch.class_type @__torch__.Submodule {
torch.method private "forward", @__torch__.Submodule.forward
}
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: torch.initialize.global_slots [
// CHECK: @s1.n(%[[INT1]] : !torch.int)
// CHECK: @s2.n(%[[INT2]] : !torch.int)
// CHECK: ]
// CHECK: }
// CHECK-LABEL: torch.global_slot "private" @s1.n : !torch.int
// CHECK-LABEL: torch.global_slot "private" @s2.n : !torch.int
%int1 = torch.constant.int 1
%s1 = torch.nn_module {
// CHECK-LABEL: torch.global_slot "private" @s1.n : !torch.int {
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: torch.global_slot.init %[[C1]] : !torch.int
// CHECK: }
torch.slot "n", %int1 : !torch.int
} : !torch.nn.Module<"__torch__.Submodule">
%int2 = torch.constant.int 2
%s2 = torch.nn_module {
// CHECK-LABEL: torch.global_slot "private" @s2.n : !torch.int {
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: torch.global_slot.init %[[C2]] : !torch.int
// CHECK: }
torch.slot "n", %int2 : !torch.int
} : !torch.nn.Module<"__torch__.Submodule">
%3 = torch.nn_module {

View File

@ -10,20 +10,23 @@ torch.class_type @__torch__.Submodule {
torch.method private "forward", @__torch__.Submodule.forward
}
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: torch.initialize.global_slots [
// CHECK: @s1.n(%[[INT1]] : !torch.int)
// CHECK: @s2.n(%[[INT2]] : !torch.int)
// CHECK: ]
// CHECK: }
// CHECK-LABEL: torch.global_slot "private" @s1.n : !torch.int
// CHECK-LABEL: torch.global_slot "private" @s2.n : !torch.int
%int1 = torch.constant.int 1
%s1 = torch.nn_module {
// CHECK-LABEL: torch.global_slot "private" @s1.n : !torch.int {
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: torch.global_slot.init %[[C1]] : !torch.int
// CHECK: }
torch.slot "n", %int1 : !torch.int
} : !torch.nn.Module<"__torch__.Submodule">
%int2 = torch.constant.int 2
%s2 = torch.nn_module {
// CHECK-LABEL: torch.global_slot "private" @s2.n : !torch.int {
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: torch.global_slot.init %[[C2]] : !torch.int
// CHECK: }
torch.slot "n", %int2 : !torch.int
} : !torch.nn.Module<"__torch__.Submodule">
%3 = torch.nn_module {

View File

@ -2,10 +2,13 @@
// Check that linkage names consist of the dotted path from the root.
// CHECK-LABEL: torch.global_slot @m.float : !torch.float {
// CHECK: %[[INIT:.*]] = torch.constant.float 4.200000e+01
// CHECK: torch.global_slot.init %[[INIT]] : !torch.float
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: %[[FLOAT:.*]] = torch.constant.float 4.200000e+01
// CHECK: torch.initialize.global_slots [
// CHECK: @m.float(%[[FLOAT]] : !torch.float)
// CHECK: ]
// CHECK: }
// CHECK-LABEL: torch.global_slot @m.float : !torch.float
torch.class_type @child {

View File

@ -0,0 +1,20 @@
// RUN: torch-mlir-opt -torch-erase-module-initializer -split-input-file -verify-diagnostics %s | FileCheck %s
// CHECK: module {
// CHECK-NEXT: }
torch.global_slot.module_initializer {
torch.initialize.global_slots [
]
}
// -----
torch.global_slot @slot0 : !torch.int
// expected-error@+1 {{could not erase non-empty module initializer}}
torch.global_slot.module_initializer {
%0 = torch.constant.int 0
torch.initialize.global_slots [
@slot0(%0: !torch.int)
]
}

View File

@ -0,0 +1,94 @@
// RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
// Safety analysis aspect of the pass.
// -----
// Test case: Public slots cannot be inlined.
// Test case: Set slots cannot be inlined.
// CHECK: torch.global_slot @public : !torch.int
// CHECK: torch.global_slot "private" @set : !torch.int
torch.global_slot @public : !torch.int
torch.global_slot "private" @set : !torch.int
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: torch.initialize.global_slots [
// CHECK: @public(%[[C1]] : !torch.int)
// CHECK: @set(%[[C1]] : !torch.int)
// CHECK: ]
// CHECK: }
torch.global_slot.module_initializer {
%0 = torch.constant.int 1
torch.initialize.global_slots [
@public(%0 : !torch.int)
@set(%0 : !torch.int)
]
}
func.func @forward() {
%0 = torch.constant.int 2
torch.global_slot.set @set = %0 : !torch.int
return
}
// -----
// Test case: Propagate safety transitively through ops without HasValueSemantics.
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list : !torch.list<tensor>
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: torch.initialize.global_slots [
// CHECK-NEXT ]
torch.global_slot.module_initializer {
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
%1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
torch.initialize.global_slots [
@tensor(%0 : !torch.tensor)
@list(%1 : !torch.list<tensor>)
]
}
func.func @forward() {
%int0 = torch.constant.int 0
%0 = torch.global_slot.get @list : !torch.list<tensor>
%1 = torch.aten.__getitem__.t %0, %int0 : !torch.list<tensor>, !torch.int -> !torch.tensor
%2 = torch.aten.mul.Tensor %1, %1 : !torch.tensor, !torch.tensor -> !torch.tensor
return
}
// -----
// Test case: An unsafe subobject (@tensor) blocks inlining of the containing object (@list).
// Note that we can check just the initializer -- if we inlined the slot, then
// we would have eliminated the slot from the initializer.
// Also, the initializer is verified to match the list of global slots in the
// module. So it is a nice one-stop-shop.
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list : !torch.list<tensor>
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: torch.initialize.global_slots [
// CHECK-NEXT: @tensor(%{{.*}} : !torch.tensor)
// CHECK-NEXT: @list(%{{.*}} : !torch.list<tensor>)
// CHECK-NEXT: ]
torch.global_slot.module_initializer {
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
%1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
torch.initialize.global_slots [
@tensor(%0 : !torch.tensor)
@list(%1 : !torch.list<tensor>)
]
}
func.func @forward() {
%int0 = torch.constant.int 0
%0 = torch.global_slot.get @list : !torch.list<tensor>
%tensor = torch.global_slot.get @tensor : !torch.tensor
torch.aten.relu_ %tensor : !torch.tensor -> !torch.tensor
return
}

View File

@ -0,0 +1,81 @@
// RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
// Transform aspect of the pass.
// Test case: Most basic case that can be inlined.
// CHECK-NOT: @slot0
torch.global_slot "private" @slot0 : !torch.int
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: torch.initialize.global_slots [
// CHECK-NEXT ]
torch.global_slot.module_initializer {
%0 = torch.constant.int 1
torch.initialize.global_slots [
@slot0(%0 : !torch.int)
]
}
// CHECK-LABEL: func.func @forward() {
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: return
func.func @forward() {
%0 = torch.global_slot.get @slot0 : !torch.int
return
}
// -----
// Test case: Shared objects in object graph shared between two initial values.
torch.global_slot "private" @tensor : !torch.tensor
torch.global_slot "private" @list_of_tensor : !torch.list<tensor>
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: torch.initialize.global_slots [
// CHECK-NEXT ]
torch.global_slot.module_initializer {
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
%1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
torch.initialize.global_slots [
@tensor(%0 : !torch.tensor)
@list_of_tensor(%1 : !torch.list<tensor>)
]
}
// CHECK-LABEL: func.func @forward() {
// CHECK: %[[T0:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<f32>) : !torch.tensor
// CHECK: %[[T1:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<f32>) : !torch.tensor
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[T1]] : (!torch.tensor) -> !torch.list<tensor>
// CHECK: return
func.func @forward() {
%0 = torch.global_slot.get @tensor : !torch.tensor
%1 = torch.global_slot.get @list_of_tensor : !torch.tensor
return
}
// -----
// Test case: Adjusting static info.
// CHECK-NOT: @tensor
torch.global_slot "private" @tensor : !torch.tensor
// CHECK-LABEL: torch.global_slot.module_initializer {
// CHECK: torch.initialize.global_slots [
// CHECK-NEXT ]
torch.global_slot.module_initializer {
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor<[],f32>
torch.initialize.global_slots [
@tensor(%0 : !torch.tensor<[],f32>)
]
}
// CHECK-LABEL: func.func @forward() {
// CHECK: %[[T:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<f32>) : !torch.tensor<[],f32>
// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.tensor<[],f32> to !torch.tensor
func.func @forward() {
%0 = torch.global_slot.get @tensor : !torch.tensor
return
}

View File

@ -1,37 +0,0 @@
// RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
// CHECK-NOT: @readonly
torch.global_slot "private" @readonly : !torch.tensor {
%0 = torch.tensor.literal(dense<0.0> : tensor<1xf32>) : !torch.tensor
torch.global_slot.init %0 : !torch.tensor
}
// CHECK-LABEL: torch.global_slot @public
torch.global_slot @public : !torch.tensor {
%0 = torch.tensor.literal(dense<0.0> : tensor<2xf32>) : !torch.tensor
torch.global_slot.init %0 : !torch.tensor
}
// CHECK-LABEL: torch.global_slot "private" @mutated
torch.global_slot "private" @mutated : !torch.tensor {
%0 = torch.tensor.literal(dense<0.0> : tensor<3xf32>) : !torch.tensor
torch.global_slot.init %0 : !torch.tensor
}
// CHECK-LABEL: func.func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
func.func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
// Inlined.
// CHECK: %[[READONLY:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1xf32>) : !torch.tensor
%0 = torch.global_slot.get @readonly : !torch.tensor
// Not inlined: potentially mutated by externals.
// CHECK: %[[PUBLIC:.*]] = torch.global_slot.get @public : !torch.tensor
%1 = torch.global_slot.get @public : !torch.tensor
// Not inlined: potentially mutated internally.
// CHECK: torch.global_slot.set @mutated = %[[READONLY]] : !torch.tensor
// CHECK: %[[MUTATED:.*]] = torch.global_slot.get @mutated : !torch.tensor
torch.global_slot.set @mutated = %0 : !torch.tensor
%2 = torch.global_slot.get @mutated : !torch.tensor
// CHECK: return %[[READONLY]], %[[PUBLIC]], %[[MUTATED]] : !torch.tensor, !torch.tensor, !torch.tensor
return %0, %1, %2 : !torch.tensor, !torch.tensor, !torch.tensor
}

View File

@ -179,3 +179,89 @@ func.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1
%1 = torch.copy.to_vtensor %0 : !torch.vtensor<[1],f32>
return %1 : !torch.vtensor<[1],f32>
}
// -----
// There must be only one module initialize.
torch.global_slot.module_initializer {
torch.initialize.global_slots [
]
}
// expected-error @+1 {{there must be only one global slot initializer}}
torch.global_slot.module_initializer {
torch.initialize.global_slots [
]
}
// -----
// Initialized slot missing, or or non-existent slots initialized.
// expected-note @+1 {{missing global slot initializer for @slot0}}
torch.global_slot @slot0 : !torch.int
// expected-note @+1 {{missing global slot initializer for @slot1}}
torch.global_slot @slot1 : !torch.int
torch.global_slot.module_initializer {
%0 = torch.constant.int 1
%1 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
%2 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor<[],unk>
// expected-error @below {{must have one initializer for each global slot in the module}}
// expected-note @below {{unexpected global slot initializer for non-existent global slot @nonexistent_slot0}}
// expected-note @below {{unexpected global slot initializer for non-existent global slot @nonexistent_slot1}}
torch.initialize.global_slots [
@nonexistent_slot0(%0 : !torch.int)
@nonexistent_slot1(%0 : !torch.int)
]
}
// -----
// Duplicate initialization of global slot.
torch.global_slot @slot0 : !torch.int
torch.global_slot.module_initializer {
%0 = torch.constant.int 1
// expected-error @+1 {{duplicate initialization of global slot: @slot0}}
torch.initialize.global_slots [
@slot0(%0 : !torch.int)
@slot0(%0 : !torch.int)
]
}
// -----
// Subtyping checks.
torch.global_slot @tensor : !torch.tensor
torch.global_slot @initialized_with_refined : !torch.tensor
torch.global_slot @error_initialized_with_derefined : !torch.tensor<[],unk>
torch.global_slot.module_initializer {
%1 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
%2 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor<[],unk>
// expected-error @below {{initial value for global slot @error_initialized_with_derefined has type '!torch.tensor' which is not within the bound '!torch.tensor<[],unk>'}}
torch.initialize.global_slots [
@tensor(%1 : !torch.tensor)
@initialized_with_refined(%2 : !torch.tensor<[],unk>)
@error_initialized_with_derefined(%1 : !torch.tensor)
]
}
// -----
// Restricted set of ops in the module initializer.
torch.global_slot @tensor : !torch.tensor
torch.global_slot.module_initializer {
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
// expected-error @+1 {{'torch.aten.mul.Tensor' op is not allowed in a module initializer}}
%1 = torch.aten.mul.Tensor %0, %0 : !torch.tensor, !torch.tensor -> !torch.tensor
torch.initialize.global_slots [
@tensor(%1 : !torch.tensor)
]
}

View File

@ -26,3 +26,11 @@ func.func @unresolved_operator(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.int)
torch.operator "aten.mul.Scalar"(%arg0, %arg1) : (!torch.vtensor<[],f32>, !torch.int) -> !torch.vtensor<[],f32>
return
}
// -----
// expected-error@+1 {{unsupported by backend lowering: module initializers}}
torch.global_slot.module_initializer {
torch.initialize.global_slots [
]
}