mirror of https://github.com/llvm/torch-mlir
94 lines
3.3 KiB
C++
94 lines
3.3 KiB
C++
//===- InlineGlobalSlots.cpp -------------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "PassDetail.h"
|
|
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::torch;
|
|
using namespace mlir::torch::Torch;
|
|
|
|
namespace {
|
|
class InlineGlobalSlotsPass
|
|
: public InlineGlobalSlotsBase<InlineGlobalSlotsPass> {
|
|
void runOnOperation() override {
|
|
ModuleOp module = getOperation();
|
|
SymbolTable symbolTable(module);
|
|
auto uses = SymbolTable::getSymbolUses(&module.getBodyRegion());
|
|
if (!uses) {
|
|
module.emitError() << "cannot analyze symbol uses";
|
|
return signalPassFailure();
|
|
}
|
|
// Find all the global slots potentially written from within the module.
|
|
// (we handle the case of non-private symbols later).
|
|
DenseSet<Torch::GlobalSlotOp> potentiallyWrittenGlobalSlots;
|
|
for (const SymbolTable::SymbolUse &use : *uses) {
|
|
auto flatSymbolRef = use.getSymbolRef().dyn_cast<FlatSymbolRefAttr>();
|
|
if (!flatSymbolRef) {
|
|
use.getUser()->emitError() << "unimplemented: nested SymbolRef's";
|
|
return signalPassFailure();
|
|
}
|
|
auto globalSlot =
|
|
symbolTable.lookup<Torch::GlobalSlotOp>(flatSymbolRef.getValue());
|
|
|
|
if (!globalSlot)
|
|
continue;
|
|
if (isa<Torch::GlobalSlotGetOp>(use.getUser()))
|
|
continue;
|
|
|
|
potentiallyWrittenGlobalSlots.insert(globalSlot);
|
|
}
|
|
|
|
DenseSet<Operation *> toErase;
|
|
// Inline all the global slots that are not potentially written.
|
|
for (const SymbolTable::SymbolUse &use : *uses) {
|
|
auto flatSymbolRef = use.getSymbolRef().cast<FlatSymbolRefAttr>();
|
|
auto globalSlot =
|
|
symbolTable.lookup<Torch::GlobalSlotOp>(flatSymbolRef.getValue());
|
|
if (!globalSlot)
|
|
continue;
|
|
// And external user might write to the global slot.
|
|
if (!globalSlot.isPrivate())
|
|
continue;
|
|
// An internal user exists which might write to the global slot.
|
|
if (potentiallyWrittenGlobalSlots.contains(globalSlot))
|
|
continue;
|
|
auto globalSlotGet = cast<Torch::GlobalSlotGetOp>(use.getUser());
|
|
OpBuilder builder(globalSlotGet);
|
|
BlockAndValueMapping mapper;
|
|
for (Operation &op : globalSlot.getBody()->without_terminator())
|
|
builder.clone(op, mapper);
|
|
Value cloned = mapper.lookup(
|
|
cast<GlobalSlotInitOp>(globalSlot.getBody()->getTerminator())
|
|
.getOperand());
|
|
globalSlotGet.replaceAllUsesWith(cloned);
|
|
toErase.insert(globalSlotGet);
|
|
toErase.insert(globalSlot);
|
|
}
|
|
|
|
for (Operation *op : toErase)
|
|
op->erase();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Factory for the global-slot inlining pass: returns a freshly constructed
/// `InlineGlobalSlotsPass` as an owning pointer to a module-level pass.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createInlineGlobalSlotsPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<InlineGlobalSlotsPass>();
  return pass;
}
|