Bump forward and refactor inline global slots to no longer track via
symlinks. This appears to make the tests past until we manage to remove
torchscript work.
pull/3700/head
Rob Suderman 2024-09-10 08:57:15 -07:00 committed by GitHub
parent b35675a78e
commit 6934ab81b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 52 additions and 74 deletions

@ -1 +1 @@
Subproject commit f9031f00f2c90bc0af274b45ec3e169b5250a688
Subproject commit b6603e1bf11dee4761e49af6581c8b8f074b705d

View File

@ -49,16 +49,15 @@ using namespace mlir::torch::Torch;
/// a single module. If we had to support complex nested symbol references, we
/// would probably want to go through the effort to indirect through the symbol
/// tables to make things clearer.
class FlatSymbolRefProgramPoint
: public GenericProgramPointBase<FlatSymbolRefProgramPoint,
FlatSymbolRefAttr> {
class FlatSymbolRefLatticeAnchor
: public GenericLatticeAnchorBase<FlatSymbolRefLatticeAnchor, Operation *> {
public:
using Base::Base;
void print(raw_ostream &os) const override {
os << "FlatSymbolRefProgramPoint(" << getValue() << ")";
os << "FlatSymbolRefLatticeAnchor(" << getValue() << ")";
}
Location getLoc() const override {
return UnknownLoc::get(getValue().getContext());
return UnknownLoc::get(getValue()->getContext());
}
};
@ -84,7 +83,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
/// State tracking if an IR construct is "safe".
///
/// This state is tracked on Value's and also on global slots (via a
/// FlatSymbolRefProgramPoint).
/// FlatSymbolRefLatticeAnchor).
///
/// In this context, "safe" means that the object is safe to inline.
/// This covers a few concepts
@ -93,7 +92,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
/// unsafe
class InlineGlobalSlotsAnalysisState : public AnalysisState {
public:
InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {
InlineGlobalSlotsAnalysisState(LatticeAnchor point) : AnalysisState(point) {
(void)setSafe();
}
@ -147,33 +146,33 @@ private:
InlineGlobalSlotsAnalysis::InlineGlobalSlotsAnalysis(DataFlowSolver &solver)
: DataFlowAnalysis(solver) {
registerPointKind<FlatSymbolRefProgramPoint>();
registerAnchorKind<FlatSymbolRefLatticeAnchor>();
}
LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
auto walkResult = top->walk([this](Operation *op) {
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
getProgramPoint<FlatSymbolRefProgramPoint>(
FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())));
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
propagateIfChanged(state,
state->setSafe(globalSlot.getVisibility() !=
SymbolTable::Visibility::Public));
}
if (auto globalSlotSet = dyn_cast<Torch::GlobalSlotSetOp>(op)) {
auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
globalSlotSet, globalSlotSet.getSlotAttr());
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
getProgramPoint<FlatSymbolRefProgramPoint>(
globalSlotSet.getSlotAttr()));
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
propagateIfChanged(state, state->setSafe(false));
}
// Save the InitializeGlobalSlotsOp for later referencee
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
initializeGlobalSlotsOp = initialize;
}
for (Value result : op->getResults()) {
if (failed(visit(result)))
return WalkResult::interrupt();
}
if (failed(visit(op)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
@ -182,50 +181,32 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
}
LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
if (Value value = dyn_cast<Value>(point)) {
bool isSafe = isValueSafeTransferFunction(value);
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
propagateIfChanged(state, state->setSafe(isSafe));
if (auto op = dyn_cast<Operation *>(point)) {
for (auto value : op->getResults()) {
bool isSafe = isValueSafeTransferFunction(value);
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
propagateIfChanged(state, state->setSafe(isSafe));
// Handle GlobalSlotGetOp's.
if (auto opResult = dyn_cast<OpResult>(value)) {
if (auto globalSlotGet =
dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
auto *flatSymbolRefPoint = getProgramPoint<FlatSymbolRefProgramPoint>(
globalSlotGet.getSlotAttr());
auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
flatSymbolRefPoint, globalSlotGet.getResult());
auto *globalState =
getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
propagateIfChanged(globalState,
globalState->incorporateSafetyOfUse(valueState));
// Handle GlobalSlotGetOp's.
if (auto opResult = dyn_cast<OpResult>(value)) {
if (auto globalSlotGet =
dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
globalSlotGet, globalSlotGet.getSlotAttr());
auto *flatSymbolRefPoint =
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
globalSlot, globalSlotGet.getResult());
auto *globalState =
getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
propagateIfChanged(globalState,
globalState->incorporateSafetyOfUse(valueState));
}
}
}
}
return success();
}
if (auto *genericProgramPoint = dyn_cast<GenericProgramPoint *>(point)) {
if (auto *flatSymbolRefPoint =
dyn_cast<FlatSymbolRefProgramPoint>(genericProgramPoint)) {
if (initializeGlobalSlotsOp) {
auto it =
llvm::find(initializeGlobalSlotsOp.getSlotSymNames(),
static_cast<Attribute>(flatSymbolRefPoint->getValue()));
Value value = initializeGlobalSlotsOp->getOperand(std::distance(
initializeGlobalSlotsOp.getSlotSymNames().begin(), it));
auto *flatSymbolRefState =
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value,
flatSymbolRefPoint);
auto *valueState = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
propagateIfChanged(valueState,
valueState->setSafe(flatSymbolRefState->isSafe));
}
return success();
}
}
LLVM_DEBUG(
{ llvm::dbgs() << "visit failing because of: " << point << "\n"; });
return failure();
return success();
}
// This is only a member function to access protected get* functions.
@ -241,16 +222,20 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
// safe. This covers, for example, view-like ops that create aliases.
if ((op->hasTrait<Torch::OpTrait::ReadOnly>() || isMemoryEffectFree(op)) &&
llvm::all_of(op->getResults(), [&](Value result) {
auto *state =
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value, result);
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
value.getDefiningOp(), result);
return state->isSafe;
}))
continue;
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
auto symName = cast<FlatSymbolRefAttr>(
initialize.getSlotSymNames()[use.getOperandNumber()]);
auto globalSlot =
SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(op, symName);
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
value, getProgramPoint<FlatSymbolRefProgramPoint>(symName));
value.getDefiningOp(),
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
if (state->isSafe)
continue;
}
@ -299,8 +284,7 @@ class InlineGlobalSlotsPass
module->walk([&](Operation *op) {
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>(
solver.getProgramPoint<FlatSymbolRefProgramPoint>(
FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())));
solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
state->print(llvm::dbgs());
llvm::dbgs() << ": "
<< FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())
@ -334,13 +318,16 @@ class InlineGlobalSlotsPass
auto slotSymName =
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
Value operand = initialize.getOperand(i);
auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>(
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]));
auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
initialize, slotSymName);
auto symbolRefPoint =
solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
auto *state =
solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint);
// We roll the analysis of whether a slot is set or public into the
// main dataflow analysis, so we need to check the slot's
// FlatSymbolRefProgramPoint itself to see if it is safe to inline.
// FlatSymbolRefLatticeAnchor itself to see if it is safe to inline.
// For example, a public !torch.int is not safe to inline, even though
// it is a value-semantic type and so the actual initializer value
// itself is conceptually safe to inline.

View File

@ -259,7 +259,6 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
// CHECK: %[[T_5:.*]] = torch.constant.int 1
// CHECK: %[[T_6:.*]] = torch.constant.int 4
// CHECK: %[[T_7:.*]] = torch.constant.int 3
// CHECK: %[[T_8:.*]] = arith.constant 3 : i64
// CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
@ -295,7 +294,6 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %int4 = torch.constant.int 4
// CHECK: %[[T_3:.*]] = arith.constant 3 : i64
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@ -336,7 +334,6 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar
// CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32>
@ -367,7 +364,6 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,
// CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
// CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@ -402,7 +398,6 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7
// CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
// CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@ -438,10 +433,6 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_2:.*]] = arith.constant 2 : i64
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x2x3x3xf32>) -> tensor<3x3x2x2xf32>
// CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32>
// CHECK: %c0 = arith.constant 0 : index