pull/3792/head
Marius Brehler 2024-10-14 15:00:45 +02:00 committed by GitHub
parent b176939808
commit edd1bbec46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 8 deletions

@ -1 +1 @@
Subproject commit e813750354bbc08551cf23ff559a54b4a9ea1f29
Subproject commit c13f806f17ac61961015e38b69c8b39ba7d454ac

View File

@ -132,7 +132,7 @@ class InlineGlobalSlotsAnalysis : public DataFlowAnalysis {
public:
InlineGlobalSlotsAnalysis(DataFlowSolver &solver);
LogicalResult initialize(Operation *top) override;
LogicalResult visit(ProgramPoint point) override;
LogicalResult visit(ProgramPoint *point) override;
private:
/// The local transfer function determining the safety of `value`.
@ -170,7 +170,7 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
initializeGlobalSlotsOp = initialize;
}
if (failed(visit(op)))
if (failed(visit(getProgramPointAfter(op))))
return WalkResult::interrupt();
return WalkResult::advance();
@ -180,8 +180,11 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
return success();
}
LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
if (auto op = dyn_cast<Operation *>(point)) {
LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint *point) {
if (point->isBlockStart())
return success();
if (auto op = point->getPrevOp()) {
for (auto value : op->getResults()) {
bool isSafe = isValueSafeTransferFunction(value);
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
@ -196,7 +199,7 @@ LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
auto *flatSymbolRefPoint =
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
globalSlot, globalSlotGet.getResult());
getProgramPointAfter(globalSlot), globalSlotGet.getResult());
auto *globalState =
getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
propagateIfChanged(globalState,
@ -223,7 +226,7 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
if ((op->hasTrait<Torch::OpTrait::ReadOnly>() || isMemoryEffectFree(op)) &&
llvm::all_of(op->getResults(), [&](Value result) {
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
value.getDefiningOp(), result);
getProgramPointAfter(value.getDefiningOp()), result);
return state->isSafe;
}))
continue;
@ -234,7 +237,7 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(op, symName);
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
value.getDefiningOp(),
getProgramPointAfter(value.getDefiningOp()),
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
if (state->isSafe)
continue;