mirror of https://github.com/llvm/torch-mlir
Rework how global slot initializers work.
Rather than a per-global-slot initializer region, we now have one for the whole module. For example, it might look like this: ``` torch.global_slot "private" @tensor : !torch.tensor torch.global_slot "private" @list : !torch.list<tensor> torch.global_slot.module_initializer { %0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor %1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor> torch.initialize.global_slots [ @tensor(%0 : !torch.tensor) @list(%1 : !torch.list<tensor>) ] } ``` This new structure allows GlobalizeObjectGraph to create the initializer in a much simpler way, avoiding the need to reason about whether different slots alias each other. Reasoning about whether slots alias each other now is the responsibility of InlineGlobalSlots, which has to do a much more complicated analysis, implemented using MLIR's dataflow analysis framework. Recommended review order: - Check out the new IR constructs in the .mlir files of various passes - Op definitions (*.td) - Changes to GlobalizeObjectGraph pass. - InlineGlobalSlots pass (~total rewrite) - Misc changes: - Moving torchMlirAdjustStaticInformation for sharing with C++ code. - EraseModuleInitializer pass To make this a bit nicer, it would be good to have a `torch.module` op with an initializer region attached. That would be more invasive though. This change has highlighted certain aspects of our project layering which are worth calling out. None of our backends can handle global slots, so we enforce that there are no global slots before backend lowering. At an earlier stage in the project, we had aspirations of transparently handling mutable global state and such, but for reasons described below, that is no longer a goal. So really global slots should be seen as a progressive lowering step as part of inlining all the IValue's in the original program (GlobalizeObjectGraph is also one such step). Over time, with insights from work like IREE-JAX, it has become clear that there isn't a reliable programming model we can compile for users where we just transparently handle mutable global state (and some other things, like lists and dictionaries). There is a need for an "outer program" that orchestrates more restricted subroutines of the kind we can handle in our compile flow here. The benefit of that is that it decouples considerations like shapes, dtypes, etc. from the program constructs used in the outer program. As long as the outer program can efficiently invoke (pipelining/async/etc.) high-performance data-parallel numerical subroutines of the kind we compile in our flow here, then there is a complete programming model. This is also consistent with the direction of upstream PyTorch which is becoming more tracing-based (which inherently loses a lot of program structure, which then has to be applied back with an "outer program" orchestrating the traced subroutines).pull/1187/head
parent
34e207eeb5
commit
504de5e701
|
@ -187,6 +187,15 @@ m_TorchTensorSizeInt(Value tensor, int64_t *dim) {
|
||||||
Value copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType,
|
Value copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType,
|
||||||
Value tensor);
|
Value tensor);
|
||||||
|
|
||||||
|
/// Adjusts the static information in the type of `value` to `desiredType`.
|
||||||
|
///
|
||||||
|
/// Returns null if such an adjustment is not possible.
|
||||||
|
///
|
||||||
|
/// If `userAllowsRefinement` is true, then the original value will be returned
|
||||||
|
/// if it is a subtype of `desiredType`.
|
||||||
|
Value adjustStaticInformation(OpBuilder &builder, Location loc, Value value,
|
||||||
|
Type desiredType, bool userAllowsRefinement);
|
||||||
|
|
||||||
/// Returns true if `list` is potentially mutated.
|
/// Returns true if `list` is potentially mutated.
|
||||||
bool isListPotentiallyMutated(Value list);
|
bool isListPotentiallyMutated(Value list);
|
||||||
|
|
||||||
|
|
|
@ -228,8 +228,6 @@ def Torch_AttrOp : Torch_Op<"attr", [
|
||||||
|
|
||||||
def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
|
def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
|
||||||
Symbol,
|
Symbol,
|
||||||
IsolatedFromAbove,
|
|
||||||
SingleBlockImplicitTerminator<"::mlir::torch::Torch::GlobalSlotInitOp">
|
|
||||||
]> {
|
]> {
|
||||||
let summary = "A slot with global storage";
|
let summary = "A slot with global storage";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -245,17 +243,66 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
|
||||||
TypeAttr:$typeBound
|
TypeAttr:$typeBound
|
||||||
);
|
);
|
||||||
let results = (outs);
|
let results = (outs);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
($sym_visibility^)? $sym_name attr-dict `:` $typeBound
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializer", [
|
||||||
|
IsolatedFromAbove,
|
||||||
|
SingleBlockImplicitTerminator<"::mlir::torch::Torch::InitializeGlobalSlotsOp">
|
||||||
|
]> {
|
||||||
|
let summary = "Module initializer for all `torch.global_slot` ops";
|
||||||
|
let description = [{
|
||||||
|
Initializer function that runs once at program startup to initialize
|
||||||
|
all `torch.global_slot` ops in the module.
|
||||||
|
|
||||||
|
The only ops that should be in the module initializer should be ops
|
||||||
|
generated by the IValue importer. This set avoids the need to define
|
||||||
|
the behavior in case of certain kinds of side effects in the initializer
|
||||||
|
(except for the side effect of updating the torch.global_slot ops with the
|
||||||
|
`torch.initialize.global_slots` op).
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins);
|
||||||
|
let results = (outs);
|
||||||
let regions = (region SizedRegion<1>:$initializer);
|
let regions = (region SizedRegion<1>:$initializer);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
($sym_visibility^)? $sym_name attr-dict `:` $typeBound ($initializer^)?
|
$initializer attr-dict
|
||||||
}];
|
}];
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [
|
||||||
|
Terminator,
|
||||||
|
HasParent<"::mlir::torch::Torch::GlobalSlotModuleInitializerOp">]> {
|
||||||
|
let summary = "Terminator for torch.global_slot.module_initializer region";
|
||||||
|
let description = [{
|
||||||
|
Atomically updates the value of all the global slots named in `slotSymNames`
|
||||||
|
with the corresponding values provided in `initialValues`.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SymbolRefArrayAttr:$slotSymNames,
|
||||||
|
Variadic<AnyTorchType>:$initialValues
|
||||||
|
);
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
// This builder creates an illegal op, but is needed to appease
|
||||||
|
// ensureTerminator in the default builders for SingleBlockImplicitTerminator
|
||||||
|
// on the parent op.
|
||||||
|
// TODO: Have a SingleBlockExplicitTerminator trait.
|
||||||
|
let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>];
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
|
def Torch_GlobalSlotInitOp : Torch_Op<"global_slot.init", [
|
||||||
Terminator,
|
Terminator,
|
||||||
HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> {
|
HasParent<"::mlir::torch::Torch::GlobalSlotOp">]> {
|
||||||
let summary = "yield-like terminator for torch.global_slot initializer region";
|
let summary = "yield-like terminator for torch.initialize.global_slotsr region";
|
||||||
let description = [{
|
let description = [{
|
||||||
The operand to this op becomes the initial value of the parent
|
The operand to this op becomes the initial value of the parent
|
||||||
torch.global_slot.
|
torch.global_slot.
|
||||||
|
|
|
@ -80,6 +80,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
createVerifyConversionToValueSemanticsPass();
|
createVerifyConversionToValueSemanticsPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
createEraseModuleInitializerPass();
|
||||||
|
|
||||||
StringRef getShapeLibrary();
|
StringRef getShapeLibrary();
|
||||||
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
|
|
|
@ -143,7 +143,7 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
|
||||||
|
|
||||||
Note: This pass inlines everything that is safe to inline. That is, it
|
Note: This pass inlines everything that is safe to inline. That is, it
|
||||||
doesn't have a cost model. This is likely to pessimize programs with
|
doesn't have a cost model. This is likely to pessimize programs with
|
||||||
significant amounts of computation inside torch.global_slot initializer
|
significant amounts of computation inside torch.initialize.global_slotsr
|
||||||
regions (but this currently doesn't happen due to how TorchScript modules
|
regions (but this currently doesn't happen due to how TorchScript modules
|
||||||
are imported -- the contents are just constants).
|
are imported -- the contents are just constants).
|
||||||
}];
|
}];
|
||||||
|
@ -265,4 +265,17 @@ def VerifyConversionToValueSemantics
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def EraseModuleInitializer
|
||||||
|
: Pass<"torch-erase-module-initializer", "ModuleOp"> {
|
||||||
|
let summary = "Erase the `torch.global_slot.module_initializer` op.";
|
||||||
|
let constructor =
|
||||||
|
"mlir::torch::Torch::createEraseModuleInitializerPass()";
|
||||||
|
let description = [{
|
||||||
|
Backends cannot currently handle module initializers, so we omit them from
|
||||||
|
our backend contract. This pass removes the
|
||||||
|
`torch.global_slot.module_initializer` op from the module if legal, or
|
||||||
|
raises an error.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // TORCHMLIR_TORCH_PASSES
|
#endif // TORCHMLIR_TORCH_PASSES
|
||||||
|
|
|
@ -31,45 +31,8 @@ MlirValue torchMlirAdjustStaticInformation(MlirBlock block_,
|
||||||
OpBuilder builder(unwrap(mlirTypeGetContext(desiredType_)));
|
OpBuilder builder(unwrap(mlirTypeGetContext(desiredType_)));
|
||||||
builder.setInsertionPoint(block, insertBefore ? insertBefore->getIterator()
|
builder.setInsertionPoint(block, insertBefore ? insertBefore->getIterator()
|
||||||
: block->end());
|
: block->end());
|
||||||
|
|
||||||
Value value = unwrap(value_);
|
Value value = unwrap(value_);
|
||||||
Type type = value.getType();
|
|
||||||
Type desiredType = unwrap(desiredType_);
|
Type desiredType = unwrap(desiredType_);
|
||||||
|
return wrap(Torch::adjustStaticInformation(
|
||||||
// If the value is already of the desired type, we're done.
|
builder, value.getLoc(), value, desiredType, userAllowsRefinement));
|
||||||
if (type == desiredType)
|
|
||||||
return wrap(value);
|
|
||||||
|
|
||||||
// If the type is a tensor, then adjust the static information.
|
|
||||||
if ((type.isa<Torch::ValueTensorType>() &&
|
|
||||||
desiredType.isa<Torch::ValueTensorType>()) ||
|
|
||||||
(type.isa<Torch::NonValueTensorType>() &&
|
|
||||||
desiredType.isa<Torch::NonValueTensorType>())) {
|
|
||||||
Value adjusted = builder.create<Torch::TensorStaticInfoCastOp>(
|
|
||||||
value.getLoc(), desiredType, value);
|
|
||||||
return wrap(adjusted);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the type is a subtype of desiredType, then we need to derefine it to
|
|
||||||
// desiredType, unless the user allows refinement.
|
|
||||||
if (Torch::isValidSubtype(type, desiredType)) {
|
|
||||||
if (!userAllowsRefinement) {
|
|
||||||
Value adjusted =
|
|
||||||
builder.create<Torch::DerefineOp>(value.getLoc(), desiredType, value);
|
|
||||||
return wrap(adjusted);
|
|
||||||
} else {
|
|
||||||
return wrap(value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the desiredType is subtype of type, then we assume that the desiredType
|
|
||||||
// is dynamically valid, so we do an unchecked cast.
|
|
||||||
if (Torch::isValidSubtype(desiredType, type)) {
|
|
||||||
Value adjusted = builder.create<Torch::PrimUncheckedCastOp>(
|
|
||||||
value.getLoc(), desiredType, value);
|
|
||||||
return wrap(adjusted);
|
|
||||||
}
|
|
||||||
|
|
||||||
// No known adjustment.
|
|
||||||
return {};
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,49 @@ using namespace mlir::torch::Torch;
|
||||||
// Utilities
|
// Utilities
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
|
||||||
|
Location loc, Value value,
|
||||||
|
Type desiredType,
|
||||||
|
bool userAllowsRefinement) {
|
||||||
|
Type type = value.getType();
|
||||||
|
|
||||||
|
// If the value is already of the desired type, we're done.
|
||||||
|
if (type == desiredType)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
// If the type is a tensor, then adjust the static information.
|
||||||
|
if ((type.isa<ValueTensorType>() && desiredType.isa<ValueTensorType>()) ||
|
||||||
|
(type.isa<NonValueTensorType>() &&
|
||||||
|
desiredType.isa<NonValueTensorType>())) {
|
||||||
|
Value adjusted = builder.create<TensorStaticInfoCastOp>(value.getLoc(),
|
||||||
|
desiredType, value);
|
||||||
|
return adjusted;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the type is a subtype of desiredType, then we need to derefine it to
|
||||||
|
// desiredType, unless the user allows refinement.
|
||||||
|
if (isValidSubtype(type, desiredType)) {
|
||||||
|
if (!userAllowsRefinement) {
|
||||||
|
Value adjusted =
|
||||||
|
builder.create<DerefineOp>(value.getLoc(), desiredType, value);
|
||||||
|
return adjusted;
|
||||||
|
} else {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the desiredType is subtype of type, then we assume that the desiredType
|
||||||
|
// is dynamically valid, so we do an unchecked cast.
|
||||||
|
if (isValidSubtype(desiredType, type)) {
|
||||||
|
Value adjusted =
|
||||||
|
builder.create<PrimUncheckedCastOp>(value.getLoc(), desiredType, value);
|
||||||
|
return adjusted;
|
||||||
|
}
|
||||||
|
|
||||||
|
// No known adjustment.
|
||||||
|
return Value();
|
||||||
|
}
|
||||||
|
|
||||||
Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
||||||
BaseTensorType newType,
|
BaseTensorType newType,
|
||||||
Value tensor) {
|
Value tensor) {
|
||||||
|
@ -1936,3 +1979,154 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
|
||||||
return emitOpError("expected number of shapes to match number of results");
|
return emitOpError("expected number of shapes to match number of results");
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// GlobalSlotModuleInitializerOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult GlobalSlotModuleInitializerOp::verify() {
|
||||||
|
// We centralize all verification of the global slots and the
|
||||||
|
// InitializeGlobalSlotsOp into here, since it requires processing the whole
|
||||||
|
// module.
|
||||||
|
|
||||||
|
// TODO: We should really have a `torch.module` and have this initializer be
|
||||||
|
// a region attached to it.
|
||||||
|
|
||||||
|
ModuleOp module = cast<ModuleOp>(getOperation()->getParentOp());
|
||||||
|
for (auto op : module.getOps<GlobalSlotModuleInitializerOp>()) {
|
||||||
|
if (op.getOperation() != getOperation())
|
||||||
|
return op.emitError("there must be only one global slot initializer");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect the relevant symbol names we will verify.
|
||||||
|
DenseSet</*StringAttr*/ Attribute> knownGlobalSlots;
|
||||||
|
for (auto op : module.getOps<GlobalSlotOp>())
|
||||||
|
knownGlobalSlots.insert(op.sym_nameAttr());
|
||||||
|
DenseSet</*StringAttr*/ Attribute> initializedGlobalSlots;
|
||||||
|
auto initialize = cast<InitializeGlobalSlotsOp>(getBody()->getTerminator());
|
||||||
|
for (Attribute symName : initialize.slotSymNames()) {
|
||||||
|
auto wasInserted = initializedGlobalSlots
|
||||||
|
.insert(symName.cast<FlatSymbolRefAttr>().getAttr())
|
||||||
|
.second;
|
||||||
|
if (!wasInserted)
|
||||||
|
return initialize.emitError("duplicate initialization of global slot: ")
|
||||||
|
<< symName;
|
||||||
|
}
|
||||||
|
auto lessThanByStringValue = [](Attribute lhs, Attribute rhs) {
|
||||||
|
return lhs.cast<StringAttr>().getValue() <
|
||||||
|
rhs.cast<StringAttr>().getValue();
|
||||||
|
};
|
||||||
|
auto known = llvm::to_vector(knownGlobalSlots);
|
||||||
|
llvm::sort(known, lessThanByStringValue);
|
||||||
|
auto initialized = llvm::to_vector(initializedGlobalSlots);
|
||||||
|
llvm::sort(initialized, lessThanByStringValue);
|
||||||
|
|
||||||
|
// Check that the global slots in the module are all initialized.
|
||||||
|
SymbolTable symbolTable(module);
|
||||||
|
if (initializedGlobalSlots != knownGlobalSlots) {
|
||||||
|
InFlightDiagnostic diag = initialize.emitOpError(
|
||||||
|
"must have one initializer for each global slot in the module");
|
||||||
|
for (auto knownGlobalSlot : known) {
|
||||||
|
auto symName = FlatSymbolRefAttr::get(knownGlobalSlot.cast<StringAttr>());
|
||||||
|
if (!initializedGlobalSlots.count(knownGlobalSlot)) {
|
||||||
|
diag.attachNote(
|
||||||
|
symbolTable.lookup<GlobalSlotOp>(symName.getAttr()).getLoc())
|
||||||
|
.append("missing global slot initializer for ", symName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto initializedGlobalSlot : initialized) {
|
||||||
|
if (!knownGlobalSlots.count(initializedGlobalSlot)) {
|
||||||
|
diag.attachNote().append(
|
||||||
|
"unexpected global slot initializer for non-existent global slot ",
|
||||||
|
FlatSymbolRefAttr::get(initializedGlobalSlot.cast<StringAttr>()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return diag;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that initial values satisfy type bounds.
|
||||||
|
for (int i = 0, e = initialize.getNumOperands(); i < e; ++i) {
|
||||||
|
auto symName = initialize.slotSymNames()[i].cast<FlatSymbolRefAttr>();
|
||||||
|
auto initialValue = initialize.getOperand(i);
|
||||||
|
auto globalSlotOp = symbolTable.lookup<GlobalSlotOp>(symName.getValue());
|
||||||
|
if (!isValidSubtype(initialValue.getType(), globalSlotOp.typeBound())) {
|
||||||
|
return initialize.emitOpError().append(
|
||||||
|
"initial value for global slot ", symName, " has type ",
|
||||||
|
initialValue.getType(), " which is not within the bound ",
|
||||||
|
globalSlotOp.typeBound());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto walkResult = getOperation()->walk([](Operation *op) {
|
||||||
|
// We only permit a small set of ops in the module initializer.
|
||||||
|
// These ops are essentially those which can be produced by the IValue
|
||||||
|
// importer.
|
||||||
|
if (isa<GlobalSlotModuleInitializerOp, InitializeGlobalSlotsOp,
|
||||||
|
PrimListConstructOp, PrimDictConstructOp, PrimTupleConstructOp,
|
||||||
|
ConstantBoolOp, ConstantStrOp, ConstantIntOp, ConstantFloatOp,
|
||||||
|
ConstantNoneOp, NonValueTensorLiteralOp, PerTensorAffineCreateOp,
|
||||||
|
LinearParamsCreateOp>(op))
|
||||||
|
return WalkResult::advance();
|
||||||
|
op->emitOpError() << "is not allowed in a module initializer";
|
||||||
|
return WalkResult::interrupt();
|
||||||
|
});
|
||||||
|
if (walkResult.wasInterrupted())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// InitializeGlobalSlotsOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
ParseResult InitializeGlobalSlotsOp::parse(OpAsmParser &parser,
|
||||||
|
OperationState &result) {
|
||||||
|
if (parser.parseOptionalAttrDict(result.attributes))
|
||||||
|
return failure();
|
||||||
|
if (parser.parseLSquare())
|
||||||
|
return failure();
|
||||||
|
SmallVector<Attribute> slotSymNames;
|
||||||
|
while (!succeeded(parser.parseOptionalRSquare())) {
|
||||||
|
NamedAttrList dummy;
|
||||||
|
StringAttr slotSymName;
|
||||||
|
if (parser.parseSymbolName(slotSymName, "dummy", dummy))
|
||||||
|
return failure();
|
||||||
|
slotSymNames.push_back(FlatSymbolRefAttr::get(slotSymName));
|
||||||
|
if (parser.parseLParen())
|
||||||
|
return failure();
|
||||||
|
OpAsmParser::UnresolvedOperand initialValue;
|
||||||
|
if (parser.parseOperand(initialValue))
|
||||||
|
return failure();
|
||||||
|
Type initialValueType;
|
||||||
|
if (parser.parseColonType(initialValueType))
|
||||||
|
return failure();
|
||||||
|
if (parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
if (parser.resolveOperand(initialValue, initialValueType, result.operands))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
result.addAttribute("slotSymNames",
|
||||||
|
ArrayAttr::get(parser.getContext(), slotSymNames));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void InitializeGlobalSlotsOp::print(OpAsmPrinter &p) {
|
||||||
|
p.printOptionalAttrDict(getOperation()->getAttrs(),
|
||||||
|
/*elidedAttrs=*/{"slotSymNames"});
|
||||||
|
p << " [";
|
||||||
|
p.printNewline();
|
||||||
|
for (int i = 0, e = getNumOperands(); i < e; ++i) {
|
||||||
|
p << " " << slotSymNames()[i] << "(" << initialValues()[i] << " : "
|
||||||
|
<< initialValues()[i].getType() << ")";
|
||||||
|
p.printNewline();
|
||||||
|
}
|
||||||
|
p << "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult InitializeGlobalSlotsOp::verify() {
|
||||||
|
if (initialValues().size() != slotSymNames().size())
|
||||||
|
return emitOpError("expected number of operands to match number of slots");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ add_mlir_library(TorchMLIRTorchPasses
|
||||||
AdjustCallingConventions.cpp
|
AdjustCallingConventions.cpp
|
||||||
DecomposeComplexOps.cpp
|
DecomposeComplexOps.cpp
|
||||||
DropShapeCalculations.cpp
|
DropShapeCalculations.cpp
|
||||||
|
EraseModuleInitializer.cpp
|
||||||
Passes.cpp
|
Passes.cpp
|
||||||
GlobalizeObjectGraph.cpp
|
GlobalizeObjectGraph.cpp
|
||||||
InlineGlobalSlots.cpp
|
InlineGlobalSlots.cpp
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
//===- EraseModuleInitializer.cpp --------------------------------*- C++-*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class EraseModuleInitializerPass
|
||||||
|
: public EraseModuleInitializerBase<EraseModuleInitializerPass> {
|
||||||
|
void runOnOperation() override {
|
||||||
|
auto walkResult = getOperation().walk([](GlobalSlotModuleInitializerOp op) {
|
||||||
|
auto intialize =
|
||||||
|
cast<InitializeGlobalSlotsOp>(op.getBody()->getTerminator());
|
||||||
|
if (intialize.getNumOperands() != 0) {
|
||||||
|
op.emitError("could not erase non-empty module initializer");
|
||||||
|
return WalkResult::interrupt();
|
||||||
|
}
|
||||||
|
op.erase();
|
||||||
|
return WalkResult::advance();
|
||||||
|
});
|
||||||
|
if (walkResult.wasInterrupted()) {
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
mlir::torch::Torch::createEraseModuleInitializerPass() {
|
||||||
|
return std::make_unique<EraseModuleInitializerPass>();
|
||||||
|
}
|
|
@ -48,12 +48,6 @@ static FailureOr<NnModuleOp> findRootNnModule(ModuleOp module) {
|
||||||
return rootNnModule;
|
return rootNnModule;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool hasMeaningfulObjectIdentity(Type type) {
|
|
||||||
return !type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
|
|
||||||
Torch::StringType, Torch::NoneType,
|
|
||||||
Torch::ValueTensorType>();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Object graph recursive traversal.
|
// Object graph recursive traversal.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -100,6 +94,9 @@ public:
|
||||||
assert(it != slotToGlobalSlot.end() && "didn't create global slot");
|
assert(it != slotToGlobalSlot.end() && "didn't create global slot");
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
llvm::MapVector<StringAttr, Value> &getGlobalSlotInitialValues() {
|
||||||
|
return globalSlotInitialValues;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
LogicalResult collectUsedSlots() {
|
LogicalResult collectUsedSlots() {
|
||||||
|
@ -187,8 +184,7 @@ private:
|
||||||
assert(slotToGlobalSlot.find(slot) == slotToGlobalSlot.end());
|
assert(slotToGlobalSlot.find(slot) == slotToGlobalSlot.end());
|
||||||
slotToGlobalSlot[slot] = globalSlot;
|
slotToGlobalSlot[slot] = globalSlot;
|
||||||
slotLinkageInfo[slot] = LinkageInfo{linkageName, attr.isPrivate()};
|
slotLinkageInfo[slot] = LinkageInfo{linkageName, attr.isPrivate()};
|
||||||
if (failed(populateGlobalSlotInitializer(globalSlot, slot)))
|
globalSlotInitialValues[globalSlot.sym_nameAttr()] = slot.value();
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
nameStack.pop_back();
|
nameStack.pop_back();
|
||||||
}
|
}
|
||||||
|
@ -201,44 +197,6 @@ private:
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
LogicalResult populateGlobalSlotInitializer(GlobalSlotOp globalSlot,
|
|
||||||
SlotOp slot) {
|
|
||||||
OpBuilder builder(globalSlot.getContext());
|
|
||||||
builder.createBlock(&globalSlot.getRegion());
|
|
||||||
|
|
||||||
SmallPtrSet<Operation *, 6> needToClone;
|
|
||||||
Value initialValue = slot.value();
|
|
||||||
SmallVector<Operation *> worklist = {initialValue.getDefiningOp()};
|
|
||||||
while (!worklist.empty()) {
|
|
||||||
Operation *op = worklist.pop_back_val();
|
|
||||||
if (!needToClone.insert(op).second)
|
|
||||||
continue;
|
|
||||||
for (Value operand : op->getOperands()) {
|
|
||||||
if (auto def = operand.getDefiningOp())
|
|
||||||
worklist.push_back(def);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
worklist.assign(needToClone.begin(), needToClone.end());
|
|
||||||
llvm::sort(worklist, [](Operation *lhs, Operation *rhs) {
|
|
||||||
return lhs->isBeforeInBlock(rhs);
|
|
||||||
});
|
|
||||||
BlockAndValueMapping mapping;
|
|
||||||
for (Operation *op : worklist) {
|
|
||||||
builder.clone(*op, mapping);
|
|
||||||
for (Value result : op->getResults()) {
|
|
||||||
if (!hasMeaningfulObjectIdentity(result.getType()))
|
|
||||||
continue;
|
|
||||||
if (!objectsWithIdentityAlreadyCopiedIntoInitializers.insert(result)
|
|
||||||
.second) {
|
|
||||||
return op->emitError() << "potentially-aliased value used to "
|
|
||||||
"initialize multiple slots";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
builder.create<GlobalSlotInitOp>(globalSlot->getLoc(),
|
|
||||||
mapping.lookup(initialValue));
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
// Builder for creating GlobalSlotOp's in the module.
|
// Builder for creating GlobalSlotOp's in the module.
|
||||||
OpBuilder globalSlotBuilder;
|
OpBuilder globalSlotBuilder;
|
||||||
// Symbol table for the module.
|
// Symbol table for the module.
|
||||||
|
@ -262,16 +220,50 @@ private:
|
||||||
DenseMap<std::pair<NnModuleOp, func::FuncOp>, LinkageInfo> funcLinkageInfo;
|
DenseMap<std::pair<NnModuleOp, func::FuncOp>, LinkageInfo> funcLinkageInfo;
|
||||||
// The corresponding GlobalSlotOp for each SlotOp in the program.
|
// The corresponding GlobalSlotOp for each SlotOp in the program.
|
||||||
DenseMap<SlotOp, GlobalSlotOp> slotToGlobalSlot;
|
DenseMap<SlotOp, GlobalSlotOp> slotToGlobalSlot;
|
||||||
// A set of values that we have copied into torch.global_slot initializers,
|
// The initializing value for each GlobalSlotOp.
|
||||||
// which cannot be used in multiple initializers because their object
|
// This is a MapVector to keep the order deterministic.
|
||||||
// identity is important.
|
llvm::MapVector<StringAttr, Value> globalSlotInitialValues;
|
||||||
DenseSet<Value> objectsWithIdentityAlreadyCopiedIntoInitializers;
|
|
||||||
// Used to keep track of all the used torch slots so that the restrictions can
|
// Used to keep track of all the used torch slots so that the restrictions can
|
||||||
// be applied to those slots only.
|
// be applied to those slots only.
|
||||||
DenseSet<SlotOp> usedSlots;
|
DenseSet<SlotOp> usedSlots;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
createGlobalSlotModuleInitializer(ModuleOp module, SymbolTable &symbolTable,
|
||||||
|
ObjectGraphInfo &objectGraphInfo) {
|
||||||
|
auto builder = OpBuilder::atBlockBegin(module.getBody());
|
||||||
|
auto moduleInitializer =
|
||||||
|
builder.create<GlobalSlotModuleInitializerOp>(module.getLoc());
|
||||||
|
Block *body = builder.createBlock(&moduleInitializer.initializer());
|
||||||
|
builder.setInsertionPointToEnd(body);
|
||||||
|
SmallVector<Operation *> opsToMove;
|
||||||
|
for (Operation &op : *module.getBody()) {
|
||||||
|
if (isa<ClassTypeOp, NnModuleOp, GlobalSlotOp, func::FuncOp,
|
||||||
|
GlobalSlotModuleInitializerOp>(op))
|
||||||
|
continue;
|
||||||
|
opsToMove.push_back(&op);
|
||||||
|
}
|
||||||
|
BlockAndValueMapping mapping;
|
||||||
|
for (Operation *op : opsToMove) {
|
||||||
|
// The ops are used by `torch.slot` ops in the enclosing module.
|
||||||
|
// Cloning avoids needing to handle those uses specially.
|
||||||
|
builder.clone(*op, mapping);
|
||||||
|
}
|
||||||
|
SmallVector<Attribute> slotSymNames;
|
||||||
|
SmallVector<Value> initialValues;
|
||||||
|
for (auto &kv : objectGraphInfo.getGlobalSlotInitialValues()) {
|
||||||
|
StringAttr symName = kv.first;
|
||||||
|
Value initializer = kv.second;
|
||||||
|
slotSymNames.push_back(FlatSymbolRefAttr::get(symName));
|
||||||
|
initialValues.push_back(mapping.lookup(initializer));
|
||||||
|
}
|
||||||
|
builder.create<InitializeGlobalSlotsOp>(
|
||||||
|
moduleInitializer.getLoc(),
|
||||||
|
ArrayAttr::get(module.getContext(), slotSymNames), initialValues);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Monomorphization.
|
// Monomorphization.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -596,7 +588,13 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
|
||||||
instances[classType].push_back(nnModule);
|
instances[classType].push_back(nnModule);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 2: Verify all functions are suitable to be analyzed by our later code.
|
// Step 2: Create the torch.global_slot.module_initializer op.
|
||||||
|
|
||||||
|
if (failed(createGlobalSlotModuleInitializer(module, symbolTable,
|
||||||
|
objectGraphInfo)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Step 3: Verify all functions are suitable to be analyzed by our later code.
|
||||||
// This eliminates special handling / error code later.
|
// This eliminates special handling / error code later.
|
||||||
//
|
//
|
||||||
// This is important, because in principle, we can perform arbitrarily complex
|
// This is important, because in principle, we can perform arbitrarily complex
|
||||||
|
@ -608,7 +606,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 3: Calculate the set of monomorphized functions that need to be
|
// Step 4: Calculate the set of monomorphized functions that need to be
|
||||||
// created. For each call that passes !torch.nn.Module to a function, we need
|
// created. For each call that passes !torch.nn.Module to a function, we need
|
||||||
// to create a specialized version of that function just for that instance (or
|
// to create a specialized version of that function just for that instance (or
|
||||||
// combination of instances in the case of multiple arguments).
|
// combination of instances in the case of multiple arguments).
|
||||||
|
@ -633,7 +631,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 4: Clone/rewrite functions to implement the necessary
|
// Step 5: Clone/rewrite functions to implement the necessary
|
||||||
// monomorphizations.
|
// monomorphizations.
|
||||||
DenseMap<Monomorphization, func::FuncOp> newFuncs;
|
DenseMap<Monomorphization, func::FuncOp> newFuncs;
|
||||||
int uniquifier = 0;
|
int uniquifier = 0;
|
||||||
|
@ -672,13 +670,13 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 5: Clean up object graph.
|
// Step 6: Clean up object graph.
|
||||||
DenseSet<func::FuncOp> liveFuncs;
|
DenseSet<func::FuncOp> liveFuncs;
|
||||||
for (auto &kv : newFuncs) {
|
for (auto &kv : newFuncs) {
|
||||||
liveFuncs.insert(kv.second);
|
liveFuncs.insert(kv.second);
|
||||||
}
|
}
|
||||||
for (auto &op : llvm::make_early_inc_range(module.getOps())) {
|
for (auto &op : llvm::make_early_inc_range(module.getOps())) {
|
||||||
if (isa<GlobalSlotOp>(&op))
|
if (isa<GlobalSlotOp, GlobalSlotModuleInitializerOp>(&op))
|
||||||
continue;
|
continue;
|
||||||
if (auto func = dyn_cast<func::FuncOp>(op)) {
|
if (auto func = dyn_cast<func::FuncOp>(op)) {
|
||||||
if (liveFuncs.contains(func))
|
if (liveFuncs.contains(func))
|
||||||
|
|
|
@ -6,83 +6,429 @@
|
||||||
// Also available under a BSD-style license. See LICENSE.
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file implements an optimistic dataflow analysis that proves that values
|
||||||
|
// used in global slot initializers are "safe" (see definition below). This
|
||||||
|
// analysis allows us to inline global slot initializers.
|
||||||
|
//
|
||||||
|
// One thing to note is that this inlining (as with all inlining) can create
|
||||||
|
// duplicate ops. That is usually not a problem, except for certain large
|
||||||
|
// tensor literals. We rely on later CSE passes to deduplicate those literals.
|
||||||
|
//
|
||||||
|
// For debugging this pass an effort has been made for
|
||||||
|
// `-debug-only=dataflow` and `-debug-only=torch-inline-global-slots` to give a
|
||||||
|
// good experience. When debugging this pass, it is recommended to start with
|
||||||
|
// `-debug-only=torch-inline-global-slots` to find values that are marked
|
||||||
|
// unsafe unexpectedly and then `-debug-only=dataflow` to find why.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "mlir/Analysis/DataFlowFramework.h"
|
||||||
|
#include "mlir/Analysis/SliceAnalysis.h"
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "mlir/IR/Builders.h"
|
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
|
||||||
#include "llvm/ADT/StringExtras.h"
|
#define DEBUG_TYPE "torch-inline-global-slots"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
/// A program point representing a symbol.
|
||||||
|
///
|
||||||
|
/// In principle we could use the `Operation *` program point of the Symbol op,
|
||||||
|
/// but that just adds a layer of indirection through a symbol table for the
|
||||||
|
/// purpose of this analysis.
|
||||||
|
///
|
||||||
|
/// This is easier because we only support FlatSymbolRefAttr's in Torch-MLIR in
|
||||||
|
/// a single module. If we had to support complex nested symbol references, we
|
||||||
|
/// would probably want to go through the effort to indirect through the symbol
|
||||||
|
/// tables to make things clearer.
|
||||||
|
class FlatSymbolRefProgramPoint
|
||||||
|
: public GenericProgramPointBase<FlatSymbolRefProgramPoint,
|
||||||
|
FlatSymbolRefAttr> {
|
||||||
|
public:
|
||||||
|
using Base::Base;
|
||||||
|
void print(raw_ostream &os) const override {
|
||||||
|
os << "FlatSymbolRefProgramPoint(" << getValue() << ")";
|
||||||
|
}
|
||||||
|
Location getLoc() const override {
|
||||||
|
return UnknownLoc::get(getValue().getContext());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static bool isTypeTriviallySafe(Type type) {
|
||||||
|
return type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
|
||||||
|
Torch::StringType, Torch::NoneType, Torch::ValueTensorType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isUseTreatedWithValueSemantics(OpOperand &use) {
|
||||||
|
Operation *op = use.getOwner();
|
||||||
|
// If the op unconditionally has value semantics, then the use has value
|
||||||
|
// semantics.
|
||||||
|
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>())
|
||||||
|
return true;
|
||||||
|
// The condition of the torch.prim.if op is treated with value semantics.
|
||||||
|
if (isa<PrimIfOp>(op) && use.getOperandNumber() == 0)
|
||||||
|
return true;
|
||||||
|
// TODO: Generalize the HasValueSemantics trait to support
|
||||||
|
// operand/result-granularity.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// State tracking if an IR construct is "safe".
|
||||||
|
///
|
||||||
|
/// This state is tracked on Value's and also on global slots (via a
|
||||||
|
/// FlatSymbolRefProgramPoint).
|
||||||
|
///
|
||||||
|
/// In this context, "safe" means that the object is safe to inline.
|
||||||
|
/// This covers a few concepts
|
||||||
|
/// - the value cannot be mutated by the program
|
||||||
|
/// - the value cannot be potentially aliased, with that alias itself being
|
||||||
|
/// unsafe
|
||||||
|
class InlineGlobalSlotsAnalysisState : public AnalysisState {
|
||||||
|
public:
|
||||||
|
InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {}
|
||||||
|
|
||||||
|
bool isUninitialized() const override {
|
||||||
|
// We are an optimistic analysis, so we are always default initialized to
|
||||||
|
// the optimistic "assumed safe" state.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
ChangeResult defaultInitialize() override {
|
||||||
|
// We are an optimistic analysis, so the default state is always "safe".
|
||||||
|
return setSafe();
|
||||||
|
}
|
||||||
|
|
||||||
|
void print(raw_ostream &os) const override {
|
||||||
|
os << "InlineGlobalSlotsAnalysisState(" << (isSafe ? "safe" : "unsafe")
|
||||||
|
<< ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper for setting the state with the correct ChangeResult.
|
||||||
|
ChangeResult setSafe(bool newIsSafe = true) {
|
||||||
|
// As an optimistic analysis, once we prove that a value is unsafe, nothing
|
||||||
|
// can prove that it is safe again. This is the monotonicity property of
|
||||||
|
// the dataflow analysis that guarantees that we reach a fixed-point.
|
||||||
|
// If that property doesn't hold, then there is a bug in the analysis.
|
||||||
|
assert(!(isSafe == false && newIsSafe == true) && "non-monotonic update");
|
||||||
|
if (isSafe == newIsSafe)
|
||||||
|
return ChangeResult::NoChange;
|
||||||
|
isSafe = newIsSafe;
|
||||||
|
return ChangeResult::Change;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper for updatating the state with the correct ChangeResult based on the
|
||||||
|
/// safety of a use.
|
||||||
|
ChangeResult
|
||||||
|
incorporateSafetyOfUse(const InlineGlobalSlotsAnalysisState *useState) {
|
||||||
|
// The use is safe, so no need to change anything.
|
||||||
|
if (useState->isSafe)
|
||||||
|
return ChangeResult::NoChange;
|
||||||
|
return setSafe(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This is an optimistic analysis. We start assuming everything is safe.
|
||||||
|
bool isSafe = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
class InlineGlobalSlotsAnalysis : public DataFlowAnalysis {
|
||||||
|
public:
|
||||||
|
InlineGlobalSlotsAnalysis(DataFlowSolver &solver);
|
||||||
|
LogicalResult initialize(Operation *top) override;
|
||||||
|
LogicalResult visit(ProgramPoint point) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// The local transfer function determining the safety of `value`.
|
||||||
|
bool isValueSafeTransferFunction(Value value);
|
||||||
|
/// The InitializeGlobalSlotsOp of the current module we are analyzing.
|
||||||
|
///
|
||||||
|
/// This is used to propagate the analysis from globals into to the module
|
||||||
|
/// initializer.
|
||||||
|
InitializeGlobalSlotsOp initializeGlobalSlotsOp;
|
||||||
|
};
|
||||||
|
|
||||||
|
InlineGlobalSlotsAnalysis::InlineGlobalSlotsAnalysis(DataFlowSolver &solver)
|
||||||
|
: DataFlowAnalysis(solver) {
|
||||||
|
registerPointKind<FlatSymbolRefProgramPoint>();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
|
||||||
|
auto walkResult = top->walk([this](Operation *op) {
|
||||||
|
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
|
||||||
|
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
|
||||||
|
getProgramPoint<FlatSymbolRefProgramPoint>(
|
||||||
|
FlatSymbolRefAttr::get(globalSlot.sym_nameAttr())));
|
||||||
|
propagateIfChanged(state,
|
||||||
|
state->setSafe(globalSlot.getVisibility() !=
|
||||||
|
SymbolTable::Visibility::Public));
|
||||||
|
}
|
||||||
|
if (auto globalSlotSet = dyn_cast<Torch::GlobalSlotSetOp>(op)) {
|
||||||
|
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
|
||||||
|
getProgramPoint<FlatSymbolRefProgramPoint>(globalSlotSet.slotAttr()));
|
||||||
|
propagateIfChanged(state, state->setSafe(false));
|
||||||
|
}
|
||||||
|
// Save the InitializeGlobalSlotsOp for later referencee
|
||||||
|
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
|
||||||
|
initializeGlobalSlotsOp = initialize;
|
||||||
|
}
|
||||||
|
for (Value result : op->getResults()) {
|
||||||
|
if (failed(visit(result)))
|
||||||
|
return WalkResult::interrupt();
|
||||||
|
}
|
||||||
|
return WalkResult::advance();
|
||||||
|
});
|
||||||
|
if (walkResult.wasInterrupted())
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
|
||||||
|
if (Value value = point.dyn_cast<Value>()) {
|
||||||
|
bool isSafe = isValueSafeTransferFunction(value);
|
||||||
|
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
|
||||||
|
propagateIfChanged(state, state->setSafe(isSafe));
|
||||||
|
|
||||||
|
// Handle GlobalSlotGetOp's.
|
||||||
|
if (auto opResult = value.dyn_cast<OpResult>()) {
|
||||||
|
if (auto globalSlotGet =
|
||||||
|
dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
|
||||||
|
auto *flatSymbolRefPoint = getProgramPoint<FlatSymbolRefProgramPoint>(
|
||||||
|
globalSlotGet.slotAttr());
|
||||||
|
auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
|
||||||
|
flatSymbolRefPoint, globalSlotGet.result());
|
||||||
|
auto *globalState =
|
||||||
|
getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
|
||||||
|
propagateIfChanged(globalState,
|
||||||
|
globalState->incorporateSafetyOfUse(valueState));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
if (auto *genericProgramPoint = point.dyn_cast<GenericProgramPoint *>()) {
|
||||||
|
if (auto *flatSymbolRefPoint =
|
||||||
|
dyn_cast<FlatSymbolRefProgramPoint>(genericProgramPoint)) {
|
||||||
|
if (initializeGlobalSlotsOp) {
|
||||||
|
auto it =
|
||||||
|
llvm::find(initializeGlobalSlotsOp.slotSymNames(),
|
||||||
|
static_cast<Attribute>(flatSymbolRefPoint->getValue()));
|
||||||
|
Value value = initializeGlobalSlotsOp->getOperand(
|
||||||
|
std::distance(initializeGlobalSlotsOp.slotSymNames().begin(), it));
|
||||||
|
auto *flatSymbolRefState =
|
||||||
|
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value,
|
||||||
|
flatSymbolRefPoint);
|
||||||
|
auto *valueState = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
|
||||||
|
propagateIfChanged(valueState,
|
||||||
|
valueState->setSafe(flatSymbolRefState->isSafe));
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LLVM_DEBUG(
|
||||||
|
{ llvm::dbgs() << "visit failing because of: " << point << "\n"; });
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is only a member function to access protected get* functions.
|
||||||
|
bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
|
||||||
|
if (isTypeTriviallySafe(value.getType()))
|
||||||
|
return true;
|
||||||
|
for (OpOperand &use : value.getUses()) {
|
||||||
|
Operation *op = use.getOwner();
|
||||||
|
|
||||||
|
if (isUseTreatedWithValueSemantics(use))
|
||||||
|
continue;
|
||||||
|
// If the op is read-only and all results are safe, then this value is
|
||||||
|
// safe. This covers, for example, view-like ops that create aliases.
|
||||||
|
if ((op->hasTrait<Torch::OpTrait::ReadOnly>() ||
|
||||||
|
MemoryEffectOpInterface::hasNoEffect(op)) &&
|
||||||
|
llvm::all_of(op->getResults(), [&](Value result) {
|
||||||
|
auto *state =
|
||||||
|
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value, result);
|
||||||
|
return state->isSafe;
|
||||||
|
}))
|
||||||
|
continue;
|
||||||
|
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
|
||||||
|
auto symName = initialize.slotSymNames()[use.getOperandNumber()]
|
||||||
|
.cast<FlatSymbolRefAttr>();
|
||||||
|
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
|
||||||
|
value, getProgramPoint<FlatSymbolRefProgramPoint>(symName));
|
||||||
|
if (state->isSafe)
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// We may not create all the dependency edges, but that is ok since at
|
||||||
|
// this point we have already reached the fixed-point.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Operation *> getBackwardSliceIncludingRoot(Value initialValue) {
|
||||||
|
SetVector<Operation *> sliceSet;
|
||||||
|
getBackwardSlice(initialValue, &sliceSet);
|
||||||
|
SmallVector<Operation *> slice;
|
||||||
|
llvm::append_range(slice, sliceSet);
|
||||||
|
slice.push_back(initialValue.getDefiningOp());
|
||||||
|
return slice;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isInitialValueTransitivelySafeToInline(Value initialValue,
|
||||||
|
DataFlowSolver &solver) {
|
||||||
|
SmallVector<Operation *> slice = getBackwardSliceIncludingRoot(initialValue);
|
||||||
|
for (Operation *op : slice) {
|
||||||
|
for (auto result : op->getResults()) {
|
||||||
|
auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>(result);
|
||||||
|
if (!state->isSafe) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class InlineGlobalSlotsPass
|
class InlineGlobalSlotsPass
|
||||||
: public InlineGlobalSlotsBase<InlineGlobalSlotsPass> {
|
: public InlineGlobalSlotsBase<InlineGlobalSlotsPass> {
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
|
|
||||||
ModuleOp module = getOperation();
|
ModuleOp module = getOperation();
|
||||||
SymbolTable symbolTable(module);
|
DataFlowSolver solver;
|
||||||
auto uses = SymbolTable::getSymbolUses(&module.getBodyRegion());
|
solver.load<InlineGlobalSlotsAnalysis>();
|
||||||
if (!uses) {
|
if (failed(solver.initializeAndRun(module)))
|
||||||
module.emitError() << "cannot analyze symbol uses";
|
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
|
||||||
|
LLVM_DEBUG({
|
||||||
|
module->walk([&](Operation *op) {
|
||||||
|
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
|
||||||
|
auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>(
|
||||||
|
solver.getProgramPoint<FlatSymbolRefProgramPoint>(
|
||||||
|
FlatSymbolRefAttr::get(globalSlot.sym_nameAttr())));
|
||||||
|
state->print(llvm::dbgs());
|
||||||
|
llvm::dbgs() << ": "
|
||||||
|
<< FlatSymbolRefAttr::get(globalSlot.sym_nameAttr())
|
||||||
|
<< "\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (op->getNumResults() != 1)
|
||||||
|
return;
|
||||||
|
auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>(
|
||||||
|
op->getResult(0));
|
||||||
|
state->print(llvm::dbgs());
|
||||||
|
llvm::dbgs() << ": ";
|
||||||
|
op->dump();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
Torch::InitializeGlobalSlotsOp initialize;
|
||||||
|
// TODO: Have a torch.module with an optional initializer region to make
|
||||||
|
// this tighter.
|
||||||
|
for (auto moduleInitializer :
|
||||||
|
module.getOps<Torch::GlobalSlotModuleInitializerOp>()) {
|
||||||
|
initialize = cast<Torch::InitializeGlobalSlotsOp>(
|
||||||
|
moduleInitializer.getBody()->getTerminator());
|
||||||
}
|
}
|
||||||
// Find all the global slots potentially written from within the module.
|
if (!initialize) {
|
||||||
// (we handle the case of non-private symbols later).
|
return;
|
||||||
DenseSet<Torch::GlobalSlotOp> potentiallyWrittenGlobalSlots;
|
}
|
||||||
for (const SymbolTable::SymbolUse &use : *uses) {
|
|
||||||
auto flatSymbolRef = use.getSymbolRef().dyn_cast<FlatSymbolRefAttr>();
|
DenseSet</*FlatSymbolRefAttr*/ Attribute> safeToInline;
|
||||||
if (!flatSymbolRef) {
|
for (int i = 0, e = initialize->getNumOperands(); i != e; i++) {
|
||||||
use.getUser()->emitError() << "unimplemented: nested SymbolRef's";
|
auto slotSymName = initialize.slotSymNames()[i].cast<FlatSymbolRefAttr>();
|
||||||
return signalPassFailure();
|
Value operand = initialize.getOperand(i);
|
||||||
|
auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>(
|
||||||
|
initialize.slotSymNames()[i].cast<FlatSymbolRefAttr>());
|
||||||
|
auto *state =
|
||||||
|
solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint);
|
||||||
|
// We roll the analysis of whether a slot is set or public into the
|
||||||
|
// main dataflow analysis, so we need to check the slot's
|
||||||
|
// FlatSymbolRefProgramPoint itself to see if it is safe to inline.
|
||||||
|
// For example, a public !torch.int is not safe to inline, even though
|
||||||
|
// it is a value-semantic type and so the actual initializer value
|
||||||
|
// itself is conceptually safe to inline.
|
||||||
|
if (!state->isSafe) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Check to see if the initializing value is safe to inline.
|
||||||
|
// This requires a transitive check of all subobjects.
|
||||||
|
// TODO: This would really be more logical to do as a forward dataflow
|
||||||
|
// analyis on the whole module initializer rather than doing the
|
||||||
|
// transitive check backward for each initial value. But it is just
|
||||||
|
// too much boilerplate to write that with the dataflow framework and we
|
||||||
|
// generally don't expect long transitive chains of values here -- most
|
||||||
|
// initial values are just single tensor literals.
|
||||||
|
if (isInitialValueTransitivelySafeToInline(operand, solver)) {
|
||||||
|
safeToInline.insert(slotSymName);
|
||||||
}
|
}
|
||||||
auto globalSlot =
|
|
||||||
symbolTable.lookup<Torch::GlobalSlotOp>(flatSymbolRef.getValue());
|
|
||||||
|
|
||||||
if (!globalSlot)
|
|
||||||
continue;
|
|
||||||
if (isa<Torch::GlobalSlotGetOp>(use.getUser()))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
potentiallyWrittenGlobalSlots.insert(globalSlot);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SymbolTable symbolTable(module);
|
||||||
DenseSet<Operation *> toErase;
|
DenseSet<Operation *> toErase;
|
||||||
// Inline all the global slots that are not potentially written.
|
module.walk([&](Torch::GlobalSlotGetOp op) {
|
||||||
for (const SymbolTable::SymbolUse &use : *uses) {
|
if (!safeToInline.count(op.slotAttr()))
|
||||||
auto flatSymbolRef = use.getSymbolRef().cast<FlatSymbolRefAttr>();
|
return;
|
||||||
auto globalSlot =
|
// TODO: Make this more ergonomic.
|
||||||
symbolTable.lookup<Torch::GlobalSlotOp>(flatSymbolRef.getValue());
|
auto it = llvm::find(initialize.slotSymNames(), op.slotAttr());
|
||||||
if (!globalSlot)
|
Value initialValue = initialize.getOperand(
|
||||||
continue;
|
std::distance(initialize.slotSymNames().begin(), it));
|
||||||
// And external user might write to the global slot.
|
// It seems inefficient to get a backward slice again here, but we are
|
||||||
if (!globalSlot.isPrivate())
|
// going to be cloning the whole slice anyway, so it doesn't seem like a
|
||||||
continue;
|
// big deal.
|
||||||
// An internal user exists which might write to the global slot.
|
SmallVector<Operation *> slice =
|
||||||
if (potentiallyWrittenGlobalSlots.contains(globalSlot))
|
getBackwardSliceIncludingRoot(initialValue);
|
||||||
continue;
|
BlockAndValueMapping mapping;
|
||||||
auto globalSlotGet = cast<Torch::GlobalSlotGetOp>(use.getUser());
|
OpBuilder builder(op);
|
||||||
OpBuilder builder(globalSlotGet);
|
for (Operation *opInSlice : slice)
|
||||||
BlockAndValueMapping mapper;
|
builder.clone(*opInSlice, mapping);
|
||||||
for (Operation &op : globalSlot.getBody()->without_terminator())
|
auto inlinedInitialValue = mapping.lookup(initialValue);
|
||||||
builder.clone(op, mapper);
|
inlinedInitialValue = Torch::adjustStaticInformation(
|
||||||
Value cloned = mapper.lookup(
|
builder, op.getLoc(), inlinedInitialValue, op.getType(),
|
||||||
cast<GlobalSlotInitOp>(globalSlot.getBody()->getTerminator())
|
/*userAllowsRefinement=*/false);
|
||||||
.getOperand());
|
op.replaceAllUsesWith(inlinedInitialValue);
|
||||||
globalSlotGet.replaceAllUsesWith(cloned);
|
toErase.insert(op);
|
||||||
toErase.insert(globalSlotGet);
|
});
|
||||||
toErase.insert(globalSlot);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Clean up after the transform.
|
||||||
|
|
||||||
|
// Erase any pending ops.
|
||||||
for (Operation *op : toErase)
|
for (Operation *op : toErase)
|
||||||
op->erase();
|
op->erase();
|
||||||
|
// Erase any global slots that we inlined.
|
||||||
|
// This could be left to SymbolDCE but it's not hard to do here.
|
||||||
|
for (FlatSymbolRefAttr symName :
|
||||||
|
llvm::map_range(safeToInline, [](Attribute attr) {
|
||||||
|
return attr.cast<FlatSymbolRefAttr>();
|
||||||
|
})) {
|
||||||
|
auto globalSlot =
|
||||||
|
symbolTable.lookup<Torch::GlobalSlotOp>(symName.getValue());
|
||||||
|
globalSlot.erase();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the initializer.
|
||||||
|
SmallVector<Attribute> newSlotSymNames;
|
||||||
|
SmallVector<Value> newInitialValues;
|
||||||
|
for (int i = 0, e = initialize.getNumOperands(); i != e; i++) {
|
||||||
|
auto slotSymName = initialize.slotSymNames()[i].cast<FlatSymbolRefAttr>();
|
||||||
|
if (!safeToInline.count(slotSymName)) {
|
||||||
|
newSlotSymNames.push_back(slotSymName);
|
||||||
|
newInitialValues.push_back(initialize.getOperand(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
OpBuilder builder(initialize);
|
||||||
|
builder.create<Torch::InitializeGlobalSlotsOp>(
|
||||||
|
initialize.getLoc(),
|
||||||
|
ArrayAttr::get(module.getContext(), newSlotSymNames),
|
||||||
|
newInitialValues);
|
||||||
|
}
|
||||||
|
initialize.erase();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -91,6 +91,9 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
||||||
// Incorporate user annotations and remove signature Python-isms.
|
// Incorporate user annotations and remove signature Python-isms.
|
||||||
pm.addPass(createAdjustCallingConventionsPass());
|
pm.addPass(createAdjustCallingConventionsPass());
|
||||||
|
|
||||||
|
// TODO: Remove options.optimize and this OPT-ONLY stuff -- we are already way
|
||||||
|
// past the point of no return for it being necessary for functional
|
||||||
|
// correctness.
|
||||||
if (options.optimize) {
|
if (options.optimize) {
|
||||||
// Eliminate the PrimTupleIndexOp generated from the
|
// Eliminate the PrimTupleIndexOp generated from the
|
||||||
// adjustCallingConventions
|
// adjustCallingConventions
|
||||||
|
@ -102,6 +105,22 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
||||||
// Also don't rely on this pass to expose constants into the program to
|
// Also don't rely on this pass to expose constants into the program to
|
||||||
// simplify handling of "optional".
|
// simplify handling of "optional".
|
||||||
pm.addPass(createInlineGlobalSlotsPass());
|
pm.addPass(createInlineGlobalSlotsPass());
|
||||||
|
// After doing a first round of inlining global slots, canonicalize again to
|
||||||
|
// take advantage of optimization opportunities exposed by the inlined
|
||||||
|
// global slots. In particular, this is functionally necessary now because
|
||||||
|
// large amounts of control flow are guarded by an "is training" flag, so
|
||||||
|
// inlining removes certain mutating operations done on the slots enabling
|
||||||
|
// them to be deleted.
|
||||||
|
// TODO: In full generality, we need to do a fixed-point iteration of
|
||||||
|
// shape inference, maximizing value semantics, decomposition, inling global
|
||||||
|
// slots, and canonicalization.
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
|
// Inline again, cleaning up any remaining global slots that might be dead
|
||||||
|
// now.
|
||||||
|
pm.addPass(createInlineGlobalSlotsPass());
|
||||||
|
// Erase the module initializers (or fail compilation), since they aren't
|
||||||
|
// permitted in our backend contract at the moment.
|
||||||
|
pm.addPass(Torch::createEraseModuleInitializerPass());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reduce variants of ops to a smaller set of primitives.
|
// Reduce variants of ops to a smaller set of primitives.
|
||||||
|
|
|
@ -42,6 +42,14 @@ class VerifyInvariantsBeforeBackendLoweringPass
|
||||||
: public VerifyInvariantsBeforeBackendLoweringBase<
|
: public VerifyInvariantsBeforeBackendLoweringBase<
|
||||||
VerifyInvariantsBeforeBackendLoweringPass> {
|
VerifyInvariantsBeforeBackendLoweringPass> {
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
|
if (getOperation()
|
||||||
|
.walk([](Torch::GlobalSlotModuleInitializerOp op) {
|
||||||
|
op.emitError()
|
||||||
|
<< "unsupported by backend lowering: module initializers";
|
||||||
|
return WalkResult::interrupt();
|
||||||
|
})
|
||||||
|
.wasInterrupted())
|
||||||
|
return signalPassFailure();
|
||||||
auto walkResult = getOperation().walk([&](Block *block) {
|
auto walkResult = getOperation().walk([&](Block *block) {
|
||||||
// Check invariants on all the Value's in the program.
|
// Check invariants on all the Value's in the program.
|
||||||
// That is, check all BlockArgument's and OpResult's.
|
// That is, check all BlockArgument's and OpResult's.
|
||||||
|
|
|
@ -2,26 +2,22 @@
|
||||||
|
|
||||||
// Basic case.
|
// Basic case.
|
||||||
|
|
||||||
// CHECK-LABEL: torch.global_slot @b : !torch.bool {
|
// CHECK-LABEL: torch.global_slot.module_initializer {
|
||||||
// CHECK: %[[INIT:.*]] = torch.constant.bool true
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||||
// CHECK: torch.global_slot.init %[[INIT]] : !torch.bool
|
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[FLOAT4:.*]] = torch.constant.float 4.250000e+01
|
||||||
|
// CHECK: %[[TENSOR:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||||
|
// CHECK: torch.initialize.global_slots [
|
||||||
|
// CHECK: @b(%[[TRUE]] : !torch.bool)
|
||||||
|
// CHECK: @i(%[[INT3]] : !torch.int)
|
||||||
|
// CHECK: @f(%[[FLOAT4]] : !torch.float)
|
||||||
|
// CHECK: @t(%[[TENSOR]] : !torch.tensor)
|
||||||
|
// CHECK: ]
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
// CHECK-LABEL: torch.global_slot @b : !torch.bool
|
||||||
// CHECK-LABEL: torch.global_slot @i : !torch.int {
|
// CHECK-LABEL: torch.global_slot @i : !torch.int
|
||||||
// CHECK: %[[INIT:.*]] = torch.constant.int 3
|
// CHECK-LABEL: torch.global_slot @f : !torch.float
|
||||||
// CHECK: torch.global_slot.init %[[INIT]] : !torch.int
|
// CHECK-LABEL: torch.global_slot @t : !torch.tensor
|
||||||
// CHECK: }
|
|
||||||
|
|
||||||
// CHECK-LABEL: torch.global_slot @f : !torch.float {
|
|
||||||
// CHECK: %[[INIT:.*]] = torch.constant.float 4.250000e+01
|
|
||||||
// CHECK: torch.global_slot.init %[[INIT]] : !torch.float
|
|
||||||
// CHECK: }
|
|
||||||
|
|
||||||
// CHECK-LABEL: torch.global_slot @t : !torch.tensor {
|
|
||||||
// CHECK: %[[T:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
|
||||||
// CHECK: torch.global_slot.init %[[T]] : !torch.tensor
|
|
||||||
// CHECK: }
|
|
||||||
|
|
||||||
torch.class_type @c {
|
torch.class_type @c {
|
||||||
torch.attr "b" : !torch.bool
|
torch.attr "b" : !torch.bool
|
||||||
torch.attr "i" : !torch.int
|
torch.attr "i" : !torch.int
|
||||||
|
|
|
@ -35,43 +35,3 @@ func.func private @ensure_all_slots_are_used(%arg0: !torch.nn.Module<"parent">,
|
||||||
%2 = torch.prim.GetAttr %arg1["float"] : !torch.nn.Module<"child"> -> !torch.float
|
%2 = torch.prim.GetAttr %arg1["float"] : !torch.nn.Module<"child"> -> !torch.float
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
torch.class_type @c {
|
|
||||||
torch.attr "t1" : !torch.tensor
|
|
||||||
torch.attr "t2" : !torch.tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
// expected-error @+1 {{potentially-aliased value used to initialize multiple slots}}
|
|
||||||
%t = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
|
||||||
torch.nn_module {
|
|
||||||
torch.slot "t1", %t : !torch.tensor
|
|
||||||
torch.slot "t2", %t : !torch.tensor
|
|
||||||
} : !torch.nn.Module<"c">
|
|
||||||
func.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor {
|
|
||||||
%t1 = torch.prim.GetAttr %arg0["t1"] : !torch.nn.Module<"c"> -> !torch.tensor
|
|
||||||
%t2 = torch.prim.GetAttr %arg0["t2"] : !torch.nn.Module<"c"> -> !torch.tensor
|
|
||||||
%cst = torch.constant.int 1
|
|
||||||
%ret = torch.aten.add.Tensor %t1, %t2, %cst : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor
|
|
||||||
return %ret : !torch.tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
torch.class_type @c {
|
|
||||||
torch.attr "t1" : !torch.tensor
|
|
||||||
torch.attr "t2" : !torch.tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
// expected-error @+1 {{potentially-aliased value used to initialize multiple slots}}
|
|
||||||
%t = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
|
||||||
torch.nn_module {
|
|
||||||
torch.slot "t1", %t : !torch.tensor
|
|
||||||
torch.slot "t2", %t : !torch.tensor
|
|
||||||
} : !torch.nn.Module<"c">
|
|
||||||
func.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) {
|
|
||||||
torch.prim.SetAttr %arg0["t1"] = %arg1: !torch.nn.Module<"c">, !torch.tensor
|
|
||||||
torch.prim.SetAttr %arg0["t2"] = %arg1: !torch.nn.Module<"c">, !torch.tensor
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK that multiple nested initialization ops are properly handled.
|
// Check that multiple nested initialization ops are properly handled.
|
||||||
|
|
||||||
// CHECK-LABEL: torch.global_slot @l : !torch.list<list<list<tensor>>> {
|
// CHECK-LABEL: torch.global_slot.module_initializer {
|
||||||
// CHECK: %[[L0:.*]] = torch.prim.ListConstruct : () -> !torch.list<tensor>
|
// CHECK: %[[L0:.*]] = torch.prim.ListConstruct : () -> !torch.list<tensor>
|
||||||
// CHECK: %[[L1:.*]] = torch.prim.ListConstruct %[[L0]], %[[L0]] : (!torch.list<tensor>, !torch.list<tensor>) -> !torch.list<list<tensor>>
|
// CHECK: %[[L1:.*]] = torch.prim.ListConstruct %[[L0]], %[[L0]] : (!torch.list<tensor>, !torch.list<tensor>) -> !torch.list<list<tensor>>
|
||||||
// CHECK: %[[L2:.*]] = torch.prim.ListConstruct %[[L1]], %[[L1]] : (!torch.list<list<tensor>>, !torch.list<list<tensor>>) -> !torch.list<list<list<tensor>>>
|
// CHECK: %[[L2:.*]] = torch.prim.ListConstruct %[[L1]], %[[L1]] : (!torch.list<list<tensor>>, !torch.list<list<tensor>>) -> !torch.list<list<list<tensor>>>
|
||||||
// CHECK: torch.global_slot.init %[[L2]] : !torch.list<list<list<tensor>>>
|
// CHECK: torch.initialize.global_slots [
|
||||||
|
// CHECK: @l(%[[L2]] : !torch.list<list<list<tensor>>>)
|
||||||
|
// CHECK: ]
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
// CHECK-LABEL: torch.global_slot @l : !torch.list<list<list<tensor>>>
|
||||||
|
|
||||||
torch.class_type @c {
|
torch.class_type @c {
|
||||||
torch.attr "l" : !torch.list<list<list<tensor>>>
|
torch.attr "l" : !torch.list<list<list<tensor>>>
|
||||||
|
|
|
@ -12,20 +12,22 @@ torch.class_type @__torch__.Submodule {
|
||||||
torch.method private "forward", @__torch__.Submodule.forward
|
torch.method private "forward", @__torch__.Submodule.forward
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.global_slot.module_initializer {
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: torch.initialize.global_slots [
|
||||||
|
// CHECK: @s1.n(%[[INT1]] : !torch.int)
|
||||||
|
// CHECK: @s2.n(%[[INT2]] : !torch.int)
|
||||||
|
// CHECK: ]
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK-LABEL: torch.global_slot "private" @s1.n : !torch.int
|
||||||
|
// CHECK-LABEL: torch.global_slot "private" @s2.n : !torch.int
|
||||||
%int1 = torch.constant.int 1
|
%int1 = torch.constant.int 1
|
||||||
%s1 = torch.nn_module {
|
%s1 = torch.nn_module {
|
||||||
// CHECK-LABEL: torch.global_slot "private" @s1.n : !torch.int {
|
|
||||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
|
||||||
// CHECK: torch.global_slot.init %[[C1]] : !torch.int
|
|
||||||
// CHECK: }
|
|
||||||
torch.slot "n", %int1 : !torch.int
|
torch.slot "n", %int1 : !torch.int
|
||||||
} : !torch.nn.Module<"__torch__.Submodule">
|
} : !torch.nn.Module<"__torch__.Submodule">
|
||||||
%int2 = torch.constant.int 2
|
%int2 = torch.constant.int 2
|
||||||
%s2 = torch.nn_module {
|
%s2 = torch.nn_module {
|
||||||
// CHECK-LABEL: torch.global_slot "private" @s2.n : !torch.int {
|
|
||||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
|
||||||
// CHECK: torch.global_slot.init %[[C2]] : !torch.int
|
|
||||||
// CHECK: }
|
|
||||||
torch.slot "n", %int2 : !torch.int
|
torch.slot "n", %int2 : !torch.int
|
||||||
} : !torch.nn.Module<"__torch__.Submodule">
|
} : !torch.nn.Module<"__torch__.Submodule">
|
||||||
%3 = torch.nn_module {
|
%3 = torch.nn_module {
|
||||||
|
|
|
@ -10,20 +10,23 @@ torch.class_type @__torch__.Submodule {
|
||||||
torch.method private "forward", @__torch__.Submodule.forward
|
torch.method private "forward", @__torch__.Submodule.forward
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.global_slot.module_initializer {
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: torch.initialize.global_slots [
|
||||||
|
// CHECK: @s1.n(%[[INT1]] : !torch.int)
|
||||||
|
// CHECK: @s2.n(%[[INT2]] : !torch.int)
|
||||||
|
// CHECK: ]
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK-LABEL: torch.global_slot "private" @s1.n : !torch.int
|
||||||
|
// CHECK-LABEL: torch.global_slot "private" @s2.n : !torch.int
|
||||||
|
|
||||||
%int1 = torch.constant.int 1
|
%int1 = torch.constant.int 1
|
||||||
%s1 = torch.nn_module {
|
%s1 = torch.nn_module {
|
||||||
// CHECK-LABEL: torch.global_slot "private" @s1.n : !torch.int {
|
|
||||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
|
||||||
// CHECK: torch.global_slot.init %[[C1]] : !torch.int
|
|
||||||
// CHECK: }
|
|
||||||
torch.slot "n", %int1 : !torch.int
|
torch.slot "n", %int1 : !torch.int
|
||||||
} : !torch.nn.Module<"__torch__.Submodule">
|
} : !torch.nn.Module<"__torch__.Submodule">
|
||||||
%int2 = torch.constant.int 2
|
%int2 = torch.constant.int 2
|
||||||
%s2 = torch.nn_module {
|
%s2 = torch.nn_module {
|
||||||
// CHECK-LABEL: torch.global_slot "private" @s2.n : !torch.int {
|
|
||||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
|
||||||
// CHECK: torch.global_slot.init %[[C2]] : !torch.int
|
|
||||||
// CHECK: }
|
|
||||||
torch.slot "n", %int2 : !torch.int
|
torch.slot "n", %int2 : !torch.int
|
||||||
} : !torch.nn.Module<"__torch__.Submodule">
|
} : !torch.nn.Module<"__torch__.Submodule">
|
||||||
%3 = torch.nn_module {
|
%3 = torch.nn_module {
|
||||||
|
|
|
@ -2,10 +2,13 @@
|
||||||
|
|
||||||
// Check that linkage names consist of the dotted path from the root.
|
// Check that linkage names consist of the dotted path from the root.
|
||||||
|
|
||||||
// CHECK-LABEL: torch.global_slot @m.float : !torch.float {
|
// CHECK-LABEL: torch.global_slot.module_initializer {
|
||||||
// CHECK: %[[INIT:.*]] = torch.constant.float 4.200000e+01
|
// CHECK: %[[FLOAT:.*]] = torch.constant.float 4.200000e+01
|
||||||
// CHECK: torch.global_slot.init %[[INIT]] : !torch.float
|
// CHECK: torch.initialize.global_slots [
|
||||||
|
// CHECK: @m.float(%[[FLOAT]] : !torch.float)
|
||||||
|
// CHECK: ]
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
// CHECK-LABEL: torch.global_slot @m.float : !torch.float
|
||||||
|
|
||||||
|
|
||||||
torch.class_type @child {
|
torch.class_type @child {
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
// RUN: torch-mlir-opt -torch-erase-module-initializer -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK: module {
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
torch.global_slot @slot0 : !torch.int
|
||||||
|
|
||||||
|
// expected-error@+1 {{could not erase non-empty module initializer}}
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
%0 = torch.constant.int 0
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
@slot0(%0: !torch.int)
|
||||||
|
]
|
||||||
|
}
|
|
@ -0,0 +1,94 @@
|
||||||
|
// RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
|
// Safety analysis aspect of the pass.
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test case: Public slots cannot be inlined.
|
||||||
|
// Test case: Set slots cannot be inlined.
|
||||||
|
|
||||||
|
// CHECK: torch.global_slot @public : !torch.int
|
||||||
|
// CHECK: torch.global_slot "private" @set : !torch.int
|
||||||
|
torch.global_slot @public : !torch.int
|
||||||
|
torch.global_slot "private" @set : !torch.int
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.global_slot.module_initializer {
|
||||||
|
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: torch.initialize.global_slots [
|
||||||
|
// CHECK: @public(%[[C1]] : !torch.int)
|
||||||
|
// CHECK: @set(%[[C1]] : !torch.int)
|
||||||
|
// CHECK: ]
|
||||||
|
// CHECK: }
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
%0 = torch.constant.int 1
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
@public(%0 : !torch.int)
|
||||||
|
@set(%0 : !torch.int)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
func.func @forward() {
|
||||||
|
%0 = torch.constant.int 2
|
||||||
|
torch.global_slot.set @set = %0 : !torch.int
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test case: Propagate safety transitively through ops without HasValueSemantics.
|
||||||
|
torch.global_slot "private" @tensor : !torch.tensor
|
||||||
|
torch.global_slot "private" @list : !torch.list<tensor>
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.global_slot.module_initializer {
|
||||||
|
// CHECK: torch.initialize.global_slots [
|
||||||
|
// CHECK-NEXT ]
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
|
||||||
|
%1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
@tensor(%0 : !torch.tensor)
|
||||||
|
@list(%1 : !torch.list<tensor>)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
func.func @forward() {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%0 = torch.global_slot.get @list : !torch.list<tensor>
|
||||||
|
%1 = torch.aten.__getitem__.t %0, %int0 : !torch.list<tensor>, !torch.int -> !torch.tensor
|
||||||
|
%2 = torch.aten.mul.Tensor %1, %1 : !torch.tensor, !torch.tensor -> !torch.tensor
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
|
||||||
|
// Test case: An unsafe subobject (@tensor) blocks inlining of the containing object (@list).
|
||||||
|
// Note that we can check just the initializer -- if we inlined the slot, then
|
||||||
|
// we would have eliminated the slot from the initializer.
|
||||||
|
// Also, the initializer is verified to match the list of global slots in the
|
||||||
|
// module. So it is a nice one-stop-shop.
|
||||||
|
|
||||||
|
torch.global_slot "private" @tensor : !torch.tensor
|
||||||
|
torch.global_slot "private" @list : !torch.list<tensor>
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.global_slot.module_initializer {
|
||||||
|
// CHECK: torch.initialize.global_slots [
|
||||||
|
// CHECK-NEXT: @tensor(%{{.*}} : !torch.tensor)
|
||||||
|
// CHECK-NEXT: @list(%{{.*}} : !torch.list<tensor>)
|
||||||
|
// CHECK-NEXT: ]
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
|
||||||
|
%1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
@tensor(%0 : !torch.tensor)
|
||||||
|
@list(%1 : !torch.list<tensor>)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
func.func @forward() {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%0 = torch.global_slot.get @list : !torch.list<tensor>
|
||||||
|
%tensor = torch.global_slot.get @tensor : !torch.tensor
|
||||||
|
torch.aten.relu_ %tensor : !torch.tensor -> !torch.tensor
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,81 @@
|
||||||
|
// RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
|
// Transform aspect of the pass.
|
||||||
|
|
||||||
|
// Test case: Most basic case that can be inlined.
|
||||||
|
|
||||||
|
// CHECK-NOT: @slot0
|
||||||
|
torch.global_slot "private" @slot0 : !torch.int
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.global_slot.module_initializer {
|
||||||
|
// CHECK: torch.initialize.global_slots [
|
||||||
|
// CHECK-NEXT ]
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
%0 = torch.constant.int 1
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
@slot0(%0 : !torch.int)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @forward() {
|
||||||
|
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: return
|
||||||
|
func.func @forward() {
|
||||||
|
%0 = torch.global_slot.get @slot0 : !torch.int
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test case: Shared objects in object graph shared between two initial values.
|
||||||
|
|
||||||
|
torch.global_slot "private" @tensor : !torch.tensor
|
||||||
|
torch.global_slot "private" @list_of_tensor : !torch.list<tensor>
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.global_slot.module_initializer {
|
||||||
|
// CHECK: torch.initialize.global_slots [
|
||||||
|
// CHECK-NEXT ]
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
|
||||||
|
%1 = torch.prim.ListConstruct %0 : (!torch.tensor) -> !torch.list<tensor>
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
@tensor(%0 : !torch.tensor)
|
||||||
|
@list_of_tensor(%1 : !torch.list<tensor>)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @forward() {
|
||||||
|
// CHECK: %[[T0:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<f32>) : !torch.tensor
|
||||||
|
// CHECK: %[[T1:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<f32>) : !torch.tensor
|
||||||
|
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[T1]] : (!torch.tensor) -> !torch.list<tensor>
|
||||||
|
// CHECK: return
|
||||||
|
func.func @forward() {
|
||||||
|
%0 = torch.global_slot.get @tensor : !torch.tensor
|
||||||
|
%1 = torch.global_slot.get @list_of_tensor : !torch.tensor
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test case: Adjusting static info.
|
||||||
|
|
||||||
|
// CHECK-NOT: @tensor
|
||||||
|
torch.global_slot "private" @tensor : !torch.tensor
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.global_slot.module_initializer {
|
||||||
|
// CHECK: torch.initialize.global_slots [
|
||||||
|
// CHECK-NEXT ]
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor<[],f32>
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
@tensor(%0 : !torch.tensor<[],f32>)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @forward() {
|
||||||
|
// CHECK: %[[T:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<f32>) : !torch.tensor<[],f32>
|
||||||
|
// CHECK: %[[CASTED:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.tensor<[],f32> to !torch.tensor
|
||||||
|
func.func @forward() {
|
||||||
|
%0 = torch.global_slot.get @tensor : !torch.tensor
|
||||||
|
return
|
||||||
|
}
|
|
@ -1,37 +0,0 @@
|
||||||
// RUN: torch-mlir-opt -torch-inline-global-slots -split-input-file %s | FileCheck %s
|
|
||||||
|
|
||||||
// CHECK-NOT: @readonly
|
|
||||||
torch.global_slot "private" @readonly : !torch.tensor {
|
|
||||||
%0 = torch.tensor.literal(dense<0.0> : tensor<1xf32>) : !torch.tensor
|
|
||||||
torch.global_slot.init %0 : !torch.tensor
|
|
||||||
}
|
|
||||||
// CHECK-LABEL: torch.global_slot @public
|
|
||||||
torch.global_slot @public : !torch.tensor {
|
|
||||||
%0 = torch.tensor.literal(dense<0.0> : tensor<2xf32>) : !torch.tensor
|
|
||||||
torch.global_slot.init %0 : !torch.tensor
|
|
||||||
}
|
|
||||||
// CHECK-LABEL: torch.global_slot "private" @mutated
|
|
||||||
torch.global_slot "private" @mutated : !torch.tensor {
|
|
||||||
%0 = torch.tensor.literal(dense<0.0> : tensor<3xf32>) : !torch.tensor
|
|
||||||
torch.global_slot.init %0 : !torch.tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
|
|
||||||
func.func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
|
|
||||||
// Inlined.
|
|
||||||
// CHECK: %[[READONLY:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1xf32>) : !torch.tensor
|
|
||||||
%0 = torch.global_slot.get @readonly : !torch.tensor
|
|
||||||
|
|
||||||
// Not inlined: potentially mutated by externals.
|
|
||||||
// CHECK: %[[PUBLIC:.*]] = torch.global_slot.get @public : !torch.tensor
|
|
||||||
%1 = torch.global_slot.get @public : !torch.tensor
|
|
||||||
|
|
||||||
// Not inlined: potentially mutated internally.
|
|
||||||
// CHECK: torch.global_slot.set @mutated = %[[READONLY]] : !torch.tensor
|
|
||||||
// CHECK: %[[MUTATED:.*]] = torch.global_slot.get @mutated : !torch.tensor
|
|
||||||
torch.global_slot.set @mutated = %0 : !torch.tensor
|
|
||||||
%2 = torch.global_slot.get @mutated : !torch.tensor
|
|
||||||
|
|
||||||
// CHECK: return %[[READONLY]], %[[PUBLIC]], %[[MUTATED]] : !torch.tensor, !torch.tensor, !torch.tensor
|
|
||||||
return %0, %1, %2 : !torch.tensor, !torch.tensor, !torch.tensor
|
|
||||||
}
|
|
|
@ -179,3 +179,89 @@ func.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1
|
||||||
%1 = torch.copy.to_vtensor %0 : !torch.vtensor<[1],f32>
|
%1 = torch.copy.to_vtensor %0 : !torch.vtensor<[1],f32>
|
||||||
return %1 : !torch.vtensor<[1],f32>
|
return %1 : !torch.vtensor<[1],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// There must be only one module initialize.
|
||||||
|
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
// expected-error @+1 {{there must be only one global slot initializer}}
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Initialized slot missing, or or non-existent slots initialized.
|
||||||
|
|
||||||
|
// expected-note @+1 {{missing global slot initializer for @slot0}}
|
||||||
|
torch.global_slot @slot0 : !torch.int
|
||||||
|
// expected-note @+1 {{missing global slot initializer for @slot1}}
|
||||||
|
torch.global_slot @slot1 : !torch.int
|
||||||
|
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
%0 = torch.constant.int 1
|
||||||
|
%1 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
|
||||||
|
%2 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor<[],unk>
|
||||||
|
// expected-error @below {{must have one initializer for each global slot in the module}}
|
||||||
|
// expected-note @below {{unexpected global slot initializer for non-existent global slot @nonexistent_slot0}}
|
||||||
|
// expected-note @below {{unexpected global slot initializer for non-existent global slot @nonexistent_slot1}}
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
@nonexistent_slot0(%0 : !torch.int)
|
||||||
|
@nonexistent_slot1(%0 : !torch.int)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Duplicate initialization of global slot.
|
||||||
|
|
||||||
|
torch.global_slot @slot0 : !torch.int
|
||||||
|
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
%0 = torch.constant.int 1
|
||||||
|
// expected-error @+1 {{duplicate initialization of global slot: @slot0}}
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
@slot0(%0 : !torch.int)
|
||||||
|
@slot0(%0 : !torch.int)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Subtyping checks.
|
||||||
|
|
||||||
|
torch.global_slot @tensor : !torch.tensor
|
||||||
|
torch.global_slot @initialized_with_refined : !torch.tensor
|
||||||
|
torch.global_slot @error_initialized_with_derefined : !torch.tensor<[],unk>
|
||||||
|
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
%1 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
|
||||||
|
%2 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor<[],unk>
|
||||||
|
// expected-error @below {{initial value for global slot @error_initialized_with_derefined has type '!torch.tensor' which is not within the bound '!torch.tensor<[],unk>'}}
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
@tensor(%1 : !torch.tensor)
|
||||||
|
@initialized_with_refined(%2 : !torch.tensor<[],unk>)
|
||||||
|
@error_initialized_with_derefined(%1 : !torch.tensor)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Restricted set of ops in the module initializer.
|
||||||
|
|
||||||
|
torch.global_slot @tensor : !torch.tensor
|
||||||
|
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
%0 = torch.tensor.literal(dense<0.0> : tensor<f32>) : !torch.tensor
|
||||||
|
// expected-error @+1 {{'torch.aten.mul.Tensor' op is not allowed in a module initializer}}
|
||||||
|
%1 = torch.aten.mul.Tensor %0, %0 : !torch.tensor, !torch.tensor -> !torch.tensor
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
@tensor(%1 : !torch.tensor)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
|
@ -26,3 +26,11 @@ func.func @unresolved_operator(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.int)
|
||||||
torch.operator "aten.mul.Scalar"(%arg0, %arg1) : (!torch.vtensor<[],f32>, !torch.int) -> !torch.vtensor<[],f32>
|
torch.operator "aten.mul.Scalar"(%arg0, %arg1) : (!torch.vtensor<[],f32>, !torch.int) -> !torch.vtensor<[],f32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// expected-error@+1 {{unsupported by backend lowering: module initializers}}
|
||||||
|
torch.global_slot.module_initializer {
|
||||||
|
torch.initialize.global_slots [
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue