Bump forward and refactor inline global slots to no longer track via
symlinks. This appears to make the tests past until we manage to remove
torchscript work.
pull/3700/head
Rob Suderman 2024-09-10 08:57:15 -07:00 committed by GitHub
parent b35675a78e
commit 6934ab81b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 52 additions and 74 deletions

@ -1 +1 @@
Subproject commit f9031f00f2c90bc0af274b45ec3e169b5250a688 Subproject commit b6603e1bf11dee4761e49af6581c8b8f074b705d

View File

@ -49,16 +49,15 @@ using namespace mlir::torch::Torch;
/// a single module. If we had to support complex nested symbol references, we /// a single module. If we had to support complex nested symbol references, we
/// would probably want to go through the effort to indirect through the symbol /// would probably want to go through the effort to indirect through the symbol
/// tables to make things clearer. /// tables to make things clearer.
class FlatSymbolRefProgramPoint class FlatSymbolRefLatticeAnchor
: public GenericProgramPointBase<FlatSymbolRefProgramPoint, : public GenericLatticeAnchorBase<FlatSymbolRefLatticeAnchor, Operation *> {
FlatSymbolRefAttr> {
public: public:
using Base::Base; using Base::Base;
void print(raw_ostream &os) const override { void print(raw_ostream &os) const override {
os << "FlatSymbolRefProgramPoint(" << getValue() << ")"; os << "FlatSymbolRefLatticeAnchor(" << getValue() << ")";
} }
Location getLoc() const override { Location getLoc() const override {
return UnknownLoc::get(getValue().getContext()); return UnknownLoc::get(getValue()->getContext());
} }
}; };
@ -84,7 +83,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
/// State tracking if an IR construct is "safe". /// State tracking if an IR construct is "safe".
/// ///
/// This state is tracked on Value's and also on global slots (via a /// This state is tracked on Value's and also on global slots (via a
/// FlatSymbolRefProgramPoint). /// FlatSymbolRefLatticeAnchor).
/// ///
/// In this context, "safe" means that the object is safe to inline. /// In this context, "safe" means that the object is safe to inline.
/// This covers a few concepts /// This covers a few concepts
@ -93,7 +92,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
/// unsafe /// unsafe
class InlineGlobalSlotsAnalysisState : public AnalysisState { class InlineGlobalSlotsAnalysisState : public AnalysisState {
public: public:
InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) { InlineGlobalSlotsAnalysisState(LatticeAnchor point) : AnalysisState(point) {
(void)setSafe(); (void)setSafe();
} }
@ -147,33 +146,33 @@ private:
InlineGlobalSlotsAnalysis::InlineGlobalSlotsAnalysis(DataFlowSolver &solver) InlineGlobalSlotsAnalysis::InlineGlobalSlotsAnalysis(DataFlowSolver &solver)
: DataFlowAnalysis(solver) { : DataFlowAnalysis(solver) {
registerPointKind<FlatSymbolRefProgramPoint>(); registerAnchorKind<FlatSymbolRefLatticeAnchor>();
} }
LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) { LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
auto walkResult = top->walk([this](Operation *op) { auto walkResult = top->walk([this](Operation *op) {
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) { if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>( auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
getProgramPoint<FlatSymbolRefProgramPoint>( getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())));
propagateIfChanged(state, propagateIfChanged(state,
state->setSafe(globalSlot.getVisibility() != state->setSafe(globalSlot.getVisibility() !=
SymbolTable::Visibility::Public)); SymbolTable::Visibility::Public));
} }
if (auto globalSlotSet = dyn_cast<Torch::GlobalSlotSetOp>(op)) { if (auto globalSlotSet = dyn_cast<Torch::GlobalSlotSetOp>(op)) {
auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
globalSlotSet, globalSlotSet.getSlotAttr());
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>( auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
getProgramPoint<FlatSymbolRefProgramPoint>( getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
globalSlotSet.getSlotAttr()));
propagateIfChanged(state, state->setSafe(false)); propagateIfChanged(state, state->setSafe(false));
} }
// Save the InitializeGlobalSlotsOp for later referencee // Save the InitializeGlobalSlotsOp for later referencee
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) { if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
initializeGlobalSlotsOp = initialize; initializeGlobalSlotsOp = initialize;
} }
for (Value result : op->getResults()) { if (failed(visit(op)))
if (failed(visit(result))) return WalkResult::interrupt();
return WalkResult::interrupt();
}
return WalkResult::advance(); return WalkResult::advance();
}); });
if (walkResult.wasInterrupted()) if (walkResult.wasInterrupted())
@ -182,50 +181,32 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
} }
LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) { LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
if (Value value = dyn_cast<Value>(point)) { if (auto op = dyn_cast<Operation *>(point)) {
bool isSafe = isValueSafeTransferFunction(value); for (auto value : op->getResults()) {
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value); bool isSafe = isValueSafeTransferFunction(value);
propagateIfChanged(state, state->setSafe(isSafe)); auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
propagateIfChanged(state, state->setSafe(isSafe));
// Handle GlobalSlotGetOp's. // Handle GlobalSlotGetOp's.
if (auto opResult = dyn_cast<OpResult>(value)) { if (auto opResult = dyn_cast<OpResult>(value)) {
if (auto globalSlotGet = if (auto globalSlotGet =
dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) { dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
auto *flatSymbolRefPoint = getProgramPoint<FlatSymbolRefProgramPoint>( auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
globalSlotGet.getSlotAttr()); globalSlotGet, globalSlotGet.getSlotAttr());
auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>( auto *flatSymbolRefPoint =
flatSymbolRefPoint, globalSlotGet.getResult()); getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
auto *globalState = auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint); globalSlot, globalSlotGet.getResult());
propagateIfChanged(globalState, auto *globalState =
globalState->incorporateSafetyOfUse(valueState)); getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
propagateIfChanged(globalState,
globalState->incorporateSafetyOfUse(valueState));
}
} }
} }
}
return success(); return success();
}
if (auto *genericProgramPoint = dyn_cast<GenericProgramPoint *>(point)) {
if (auto *flatSymbolRefPoint =
dyn_cast<FlatSymbolRefProgramPoint>(genericProgramPoint)) {
if (initializeGlobalSlotsOp) {
auto it =
llvm::find(initializeGlobalSlotsOp.getSlotSymNames(),
static_cast<Attribute>(flatSymbolRefPoint->getValue()));
Value value = initializeGlobalSlotsOp->getOperand(std::distance(
initializeGlobalSlotsOp.getSlotSymNames().begin(), it));
auto *flatSymbolRefState =
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value,
flatSymbolRefPoint);
auto *valueState = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
propagateIfChanged(valueState,
valueState->setSafe(flatSymbolRefState->isSafe));
}
return success();
}
}
LLVM_DEBUG(
{ llvm::dbgs() << "visit failing because of: " << point << "\n"; });
return failure();
} }
// This is only a member function to access protected get* functions. // This is only a member function to access protected get* functions.
@ -241,16 +222,20 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
// safe. This covers, for example, view-like ops that create aliases. // safe. This covers, for example, view-like ops that create aliases.
if ((op->hasTrait<Torch::OpTrait::ReadOnly>() || isMemoryEffectFree(op)) && if ((op->hasTrait<Torch::OpTrait::ReadOnly>() || isMemoryEffectFree(op)) &&
llvm::all_of(op->getResults(), [&](Value result) { llvm::all_of(op->getResults(), [&](Value result) {
auto *state = auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value, result); value.getDefiningOp(), result);
return state->isSafe; return state->isSafe;
})) }))
continue; continue;
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) { if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
auto symName = cast<FlatSymbolRefAttr>( auto symName = cast<FlatSymbolRefAttr>(
initialize.getSlotSymNames()[use.getOperandNumber()]); initialize.getSlotSymNames()[use.getOperandNumber()]);
auto globalSlot =
SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(op, symName);
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>( auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
value, getProgramPoint<FlatSymbolRefProgramPoint>(symName)); value.getDefiningOp(),
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
if (state->isSafe) if (state->isSafe)
continue; continue;
} }
@ -299,8 +284,7 @@ class InlineGlobalSlotsPass
module->walk([&](Operation *op) { module->walk([&](Operation *op) {
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) { if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>( auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>(
solver.getProgramPoint<FlatSymbolRefProgramPoint>( solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())));
state->print(llvm::dbgs()); state->print(llvm::dbgs());
llvm::dbgs() << ": " llvm::dbgs() << ": "
<< FlatSymbolRefAttr::get(globalSlot.getSymNameAttr()) << FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())
@ -334,13 +318,16 @@ class InlineGlobalSlotsPass
auto slotSymName = auto slotSymName =
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]); cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
Value operand = initialize.getOperand(i); Value operand = initialize.getOperand(i);
auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>( auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i])); initialize, slotSymName);
auto symbolRefPoint =
solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
auto *state = auto *state =
solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint); solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint);
// We roll the analysis of whether a slot is set or public into the // We roll the analysis of whether a slot is set or public into the
// main dataflow analysis, so we need to check the slot's // main dataflow analysis, so we need to check the slot's
// FlatSymbolRefProgramPoint itself to see if it is safe to inline. // FlatSymbolRefLatticeAnchor itself to see if it is safe to inline.
// For example, a public !torch.int is not safe to inline, even though // For example, a public !torch.int is not safe to inline, even though
// it is a value-semantic type and so the actual initializer value // it is a value-semantic type and so the actual initializer value
// itself is conceptually safe to inline. // itself is conceptually safe to inline.

View File

@ -259,7 +259,6 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
// CHECK: %[[T_5:.*]] = torch.constant.int 1 // CHECK: %[[T_5:.*]] = torch.constant.int 1
// CHECK: %[[T_6:.*]] = torch.constant.int 4 // CHECK: %[[T_6:.*]] = torch.constant.int 4
// CHECK: %[[T_7:.*]] = torch.constant.int 3 // CHECK: %[[T_7:.*]] = torch.constant.int 3
// CHECK: %[[T_8:.*]] = arith.constant 3 : i64
// CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
@ -295,7 +294,6 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
// CHECK: %int2 = torch.constant.int 2 // CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1 // CHECK: %int1 = torch.constant.int 1
// CHECK: %int4 = torch.constant.int 4 // CHECK: %int4 = torch.constant.int 4
// CHECK: %[[T_3:.*]] = arith.constant 3 : i64
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@ -336,7 +334,6 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar
// CHECK: %none = torch.constant.none // CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0 // CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1 // CHECK: %int1 = torch.constant.int 1
// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32> // CHECK: %[[T_5:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32>
@ -367,7 +364,6 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,
// CHECK: %none = torch.constant.none // CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0 // CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1 // CHECK: %int1 = torch.constant.int 1
// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
// CHECK: %int2 = torch.constant.int 2 // CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@ -402,7 +398,6 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7
// CHECK: %none = torch.constant.none // CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0 // CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1 // CHECK: %int1 = torch.constant.int 1
// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
// CHECK: %int2 = torch.constant.int 2 // CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@ -438,10 +433,6 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
// CHECK: %int0 = torch.constant.int 0 // CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1 // CHECK: %int1 = torch.constant.int 1
// CHECK: %int2 = torch.constant.int 2 // CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_2:.*]] = arith.constant 2 : i64
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x2x3x3xf32>) -> tensor<3x3x2x2xf32> // CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x2x3x3xf32>) -> tensor<3x3x2x2xf32>
// CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32> // CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32>
// CHECK: %c0 = arith.constant 0 : index // CHECK: %c0 = arith.constant 0 : index