mirror of https://github.com/llvm/torch-mlir
Bump forward and refactor inline global slots to no longer track via symlinks. This appears to make the tests past until we manage to remove torchscript work.pull/3700/head
parent
b35675a78e
commit
6934ab81b0
|
@ -1 +1 @@
|
|||
Subproject commit f9031f00f2c90bc0af274b45ec3e169b5250a688
|
||||
Subproject commit b6603e1bf11dee4761e49af6581c8b8f074b705d
|
|
@ -49,16 +49,15 @@ using namespace mlir::torch::Torch;
|
|||
/// a single module. If we had to support complex nested symbol references, we
|
||||
/// would probably want to go through the effort to indirect through the symbol
|
||||
/// tables to make things clearer.
|
||||
class FlatSymbolRefProgramPoint
|
||||
: public GenericProgramPointBase<FlatSymbolRefProgramPoint,
|
||||
FlatSymbolRefAttr> {
|
||||
class FlatSymbolRefLatticeAnchor
|
||||
: public GenericLatticeAnchorBase<FlatSymbolRefLatticeAnchor, Operation *> {
|
||||
public:
|
||||
using Base::Base;
|
||||
void print(raw_ostream &os) const override {
|
||||
os << "FlatSymbolRefProgramPoint(" << getValue() << ")";
|
||||
os << "FlatSymbolRefLatticeAnchor(" << getValue() << ")";
|
||||
}
|
||||
Location getLoc() const override {
|
||||
return UnknownLoc::get(getValue().getContext());
|
||||
return UnknownLoc::get(getValue()->getContext());
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -84,7 +83,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
|
|||
/// State tracking if an IR construct is "safe".
|
||||
///
|
||||
/// This state is tracked on Value's and also on global slots (via a
|
||||
/// FlatSymbolRefProgramPoint).
|
||||
/// FlatSymbolRefLatticeAnchor).
|
||||
///
|
||||
/// In this context, "safe" means that the object is safe to inline.
|
||||
/// This covers a few concepts
|
||||
|
@ -93,7 +92,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
|
|||
/// unsafe
|
||||
class InlineGlobalSlotsAnalysisState : public AnalysisState {
|
||||
public:
|
||||
InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {
|
||||
InlineGlobalSlotsAnalysisState(LatticeAnchor point) : AnalysisState(point) {
|
||||
(void)setSafe();
|
||||
}
|
||||
|
||||
|
@ -147,33 +146,33 @@ private:
|
|||
|
||||
InlineGlobalSlotsAnalysis::InlineGlobalSlotsAnalysis(DataFlowSolver &solver)
|
||||
: DataFlowAnalysis(solver) {
|
||||
registerPointKind<FlatSymbolRefProgramPoint>();
|
||||
registerAnchorKind<FlatSymbolRefLatticeAnchor>();
|
||||
}
|
||||
|
||||
LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
|
||||
auto walkResult = top->walk([this](Operation *op) {
|
||||
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
|
||||
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
|
||||
getProgramPoint<FlatSymbolRefProgramPoint>(
|
||||
FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())));
|
||||
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
|
||||
propagateIfChanged(state,
|
||||
state->setSafe(globalSlot.getVisibility() !=
|
||||
SymbolTable::Visibility::Public));
|
||||
}
|
||||
if (auto globalSlotSet = dyn_cast<Torch::GlobalSlotSetOp>(op)) {
|
||||
auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
|
||||
globalSlotSet, globalSlotSet.getSlotAttr());
|
||||
|
||||
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
|
||||
getProgramPoint<FlatSymbolRefProgramPoint>(
|
||||
globalSlotSet.getSlotAttr()));
|
||||
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
|
||||
propagateIfChanged(state, state->setSafe(false));
|
||||
}
|
||||
// Save the InitializeGlobalSlotsOp for later referencee
|
||||
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
|
||||
initializeGlobalSlotsOp = initialize;
|
||||
}
|
||||
for (Value result : op->getResults()) {
|
||||
if (failed(visit(result)))
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
if (failed(visit(op)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
if (walkResult.wasInterrupted())
|
||||
|
@ -182,50 +181,32 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
|
|||
}
|
||||
|
||||
LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
|
||||
if (Value value = dyn_cast<Value>(point)) {
|
||||
bool isSafe = isValueSafeTransferFunction(value);
|
||||
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
|
||||
propagateIfChanged(state, state->setSafe(isSafe));
|
||||
if (auto op = dyn_cast<Operation *>(point)) {
|
||||
for (auto value : op->getResults()) {
|
||||
bool isSafe = isValueSafeTransferFunction(value);
|
||||
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
|
||||
propagateIfChanged(state, state->setSafe(isSafe));
|
||||
|
||||
// Handle GlobalSlotGetOp's.
|
||||
if (auto opResult = dyn_cast<OpResult>(value)) {
|
||||
if (auto globalSlotGet =
|
||||
dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
|
||||
auto *flatSymbolRefPoint = getProgramPoint<FlatSymbolRefProgramPoint>(
|
||||
globalSlotGet.getSlotAttr());
|
||||
auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
|
||||
flatSymbolRefPoint, globalSlotGet.getResult());
|
||||
auto *globalState =
|
||||
getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
|
||||
propagateIfChanged(globalState,
|
||||
globalState->incorporateSafetyOfUse(valueState));
|
||||
// Handle GlobalSlotGetOp's.
|
||||
if (auto opResult = dyn_cast<OpResult>(value)) {
|
||||
if (auto globalSlotGet =
|
||||
dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
|
||||
auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
|
||||
globalSlotGet, globalSlotGet.getSlotAttr());
|
||||
auto *flatSymbolRefPoint =
|
||||
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
|
||||
auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
|
||||
globalSlot, globalSlotGet.getResult());
|
||||
auto *globalState =
|
||||
getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
|
||||
propagateIfChanged(globalState,
|
||||
globalState->incorporateSafetyOfUse(valueState));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
if (auto *genericProgramPoint = dyn_cast<GenericProgramPoint *>(point)) {
|
||||
if (auto *flatSymbolRefPoint =
|
||||
dyn_cast<FlatSymbolRefProgramPoint>(genericProgramPoint)) {
|
||||
if (initializeGlobalSlotsOp) {
|
||||
auto it =
|
||||
llvm::find(initializeGlobalSlotsOp.getSlotSymNames(),
|
||||
static_cast<Attribute>(flatSymbolRefPoint->getValue()));
|
||||
Value value = initializeGlobalSlotsOp->getOperand(std::distance(
|
||||
initializeGlobalSlotsOp.getSlotSymNames().begin(), it));
|
||||
auto *flatSymbolRefState =
|
||||
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value,
|
||||
flatSymbolRefPoint);
|
||||
auto *valueState = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
|
||||
propagateIfChanged(valueState,
|
||||
valueState->setSafe(flatSymbolRefState->isSafe));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
}
|
||||
LLVM_DEBUG(
|
||||
{ llvm::dbgs() << "visit failing because of: " << point << "\n"; });
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
// This is only a member function to access protected get* functions.
|
||||
|
@ -241,16 +222,20 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
|
|||
// safe. This covers, for example, view-like ops that create aliases.
|
||||
if ((op->hasTrait<Torch::OpTrait::ReadOnly>() || isMemoryEffectFree(op)) &&
|
||||
llvm::all_of(op->getResults(), [&](Value result) {
|
||||
auto *state =
|
||||
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value, result);
|
||||
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
|
||||
value.getDefiningOp(), result);
|
||||
return state->isSafe;
|
||||
}))
|
||||
continue;
|
||||
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
|
||||
auto symName = cast<FlatSymbolRefAttr>(
|
||||
initialize.getSlotSymNames()[use.getOperandNumber()]);
|
||||
auto globalSlot =
|
||||
SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(op, symName);
|
||||
|
||||
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
|
||||
value, getProgramPoint<FlatSymbolRefProgramPoint>(symName));
|
||||
value.getDefiningOp(),
|
||||
getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
|
||||
if (state->isSafe)
|
||||
continue;
|
||||
}
|
||||
|
@ -299,8 +284,7 @@ class InlineGlobalSlotsPass
|
|||
module->walk([&](Operation *op) {
|
||||
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
|
||||
auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>(
|
||||
solver.getProgramPoint<FlatSymbolRefProgramPoint>(
|
||||
FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())));
|
||||
solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
|
||||
state->print(llvm::dbgs());
|
||||
llvm::dbgs() << ": "
|
||||
<< FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())
|
||||
|
@ -334,13 +318,16 @@ class InlineGlobalSlotsPass
|
|||
auto slotSymName =
|
||||
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
|
||||
Value operand = initialize.getOperand(i);
|
||||
auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>(
|
||||
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]));
|
||||
auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
|
||||
initialize, slotSymName);
|
||||
|
||||
auto symbolRefPoint =
|
||||
solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
|
||||
auto *state =
|
||||
solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint);
|
||||
// We roll the analysis of whether a slot is set or public into the
|
||||
// main dataflow analysis, so we need to check the slot's
|
||||
// FlatSymbolRefProgramPoint itself to see if it is safe to inline.
|
||||
// FlatSymbolRefLatticeAnchor itself to see if it is safe to inline.
|
||||
// For example, a public !torch.int is not safe to inline, even though
|
||||
// it is a value-semantic type and so the actual initializer value
|
||||
// itself is conceptually safe to inline.
|
||||
|
|
|
@ -259,7 +259,6 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
|
|||
// CHECK: %[[T_5:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[T_6:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[T_7:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[T_8:.*]] = arith.constant 3 : i64
|
||||
// CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
|
@ -295,7 +294,6 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
|
|||
// CHECK: %int2 = torch.constant.int 2
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %int4 = torch.constant.int 4
|
||||
// CHECK: %[[T_3:.*]] = arith.constant 3 : i64
|
||||
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
|
@ -336,7 +334,6 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar
|
|||
// CHECK: %none = torch.constant.none
|
||||
// CHECK: %int0 = torch.constant.int 0
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
|
||||
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_5:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32>
|
||||
|
@ -367,7 +364,6 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,
|
|||
// CHECK: %none = torch.constant.none
|
||||
// CHECK: %int0 = torch.constant.int 0
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
|
||||
// CHECK: %int2 = torch.constant.int 2
|
||||
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
|
@ -402,7 +398,6 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7
|
|||
// CHECK: %none = torch.constant.none
|
||||
// CHECK: %int0 = torch.constant.int 0
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
|
||||
// CHECK: %int2 = torch.constant.int 2
|
||||
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
|
@ -438,10 +433,6 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
|
|||
// CHECK: %int0 = torch.constant.int 0
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %int2 = torch.constant.int 2
|
||||
// CHECK: %[[T_2:.*]] = arith.constant 2 : i64
|
||||
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x2x3x3xf32>) -> tensor<3x3x2x2xf32>
|
||||
// CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32>
|
||||
// CHECK: %c0 = arith.constant 0 : index
|
||||
|
|
Loading…
Reference in New Issue