mirror of https://github.com/llvm/torch-mlir
Bump forward and refactor InlineGlobalSlots to no longer track via symbol references. This appears to make the tests pass until we manage to remove the TorchScript work.
pull/3700/head
parent b35675a78e
commit 6934ab81b0
@@ -1 +1 @@
-Subproject commit f9031f00f2c90bc0af274b45ec3e169b5250a688
+Subproject commit b6603e1bf11dee4761e49af6581c8b8f074b705d
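This first hunk bumps the pinned LLVM submodule commit. The remaining C++ hunks adapt the InlineGlobalSlots pass to the newer MLIR dataflow framework, which keys analysis state on lattice anchors (GenericLatticeAnchorBase, getLatticeAnchor, registerAnchorKind) instead of program points (GenericProgramPointBase, getProgramPoint, registerPointKind).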
@@ -49,16 +49,15 @@ using namespace mlir::torch::Torch;
 /// a single module. If we had to support complex nested symbol references, we
 /// would probably want to go through the effort to indirect through the symbol
 /// tables to make things clearer.
-class FlatSymbolRefProgramPoint
-    : public GenericProgramPointBase<FlatSymbolRefProgramPoint,
-                                     FlatSymbolRefAttr> {
+class FlatSymbolRefLatticeAnchor
+    : public GenericLatticeAnchorBase<FlatSymbolRefLatticeAnchor, Operation *> {
 public:
   using Base::Base;
   void print(raw_ostream &os) const override {
-    os << "FlatSymbolRefProgramPoint(" << getValue() << ")";
+    os << "FlatSymbolRefLatticeAnchor(" << getValue() << ")";
   }
   Location getLoc() const override {
-    return UnknownLoc::get(getValue().getContext());
+    return UnknownLoc::get(getValue()->getContext());
   }
 };
|
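In essence, the anchor's key type changes from a FlatSymbolRefAttr to the GlobalSlotOp operation itself, which is why getValue().getContext() becomes getValue()->getContext(). A minimal before/after sketch of how the anchor is obtained and state attached to it, assuming we are inside a member function of the analysis and globalSlot is a Torch::GlobalSlotOp (condensed from the hunks below, not new API):

// Before: state keyed on the slot's symbol name.
auto *point = getProgramPoint<FlatSymbolRefProgramPoint>(
    FlatSymbolRefAttr::get(globalSlot.getSymNameAttr()));
auto *oldState = getOrCreate<InlineGlobalSlotsAnalysisState>(point);

// After: state keyed on the slot op itself.
auto *anchor = getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
auto *newState = getOrCreate<InlineGlobalSlotsAnalysisState>(anchor);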
@@ -84,7 +83,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
 /// State tracking if an IR construct is "safe".
 ///
 /// This state is tracked on Value's and also on global slots (via a
-/// FlatSymbolRefProgramPoint).
+/// FlatSymbolRefLatticeAnchor).
 ///
 /// In this context, "safe" means that the object is safe to inline.
 /// This covers a few concepts
|
@@ -93,7 +92,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
 /// unsafe
 class InlineGlobalSlotsAnalysisState : public AnalysisState {
 public:
-  InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {
+  InlineGlobalSlotsAnalysisState(LatticeAnchor point) : AnalysisState(point) {
     (void)setSafe();
   }
|
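Only the constructor's parameter type changes here. For context, the state's update methods used throughout the patch behave roughly as follows; this is a sketch inferred from the call sites visible in this diff (setSafe, incorporateSafetyOfUse, isSafe), not code copied from the file:

// isSafe is the single lattice bit tracked per Value / global slot.
[[nodiscard]] ChangeResult setSafe(bool newIsSafe = true) {
  if (isSafe == newIsSafe)
    return ChangeResult::NoChange; // nothing for the solver to propagate
  isSafe = newIsSafe;
  return ChangeResult::Change; // tells the solver to re-visit dependents
}
[[nodiscard]] ChangeResult
incorporateSafetyOfUse(const InlineGlobalSlotsAnalysisState *useState) {
  // An object stays safe only while every use is safe.
  return useState->isSafe ? ChangeResult::NoChange : setSafe(false);
}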
@@ -147,33 +146,33 @@ private:
 
 InlineGlobalSlotsAnalysis::InlineGlobalSlotsAnalysis(DataFlowSolver &solver)
     : DataFlowAnalysis(solver) {
-  registerPointKind<FlatSymbolRefProgramPoint>();
+  registerAnchorKind<FlatSymbolRefLatticeAnchor>();
 }
 
 LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
   auto walkResult = top->walk([this](Operation *op) {
     if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
       auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
-          getProgramPoint<FlatSymbolRefProgramPoint>(
-              FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())));
+          getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
       propagateIfChanged(state,
                          state->setSafe(globalSlot.getVisibility() !=
                                         SymbolTable::Visibility::Public));
     }
     if (auto globalSlotSet = dyn_cast<Torch::GlobalSlotSetOp>(op)) {
+      auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
+          globalSlotSet, globalSlotSet.getSlotAttr());
+
       auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
-          getProgramPoint<FlatSymbolRefProgramPoint>(
-              globalSlotSet.getSlotAttr()));
+          getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
       propagateIfChanged(state, state->setSafe(false));
     }
     // Save the InitializeGlobalSlotsOp for later reference
     if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
       initializeGlobalSlotsOp = initialize;
     }
-    for (Value result : op->getResults()) {
-      if (failed(visit(result)))
-        return WalkResult::interrupt();
-    }
+    if (failed(visit(op)))
+      return WalkResult::interrupt();
+
     return WalkResult::advance();
   });
   if (walkResult.wasInterrupted())
|
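The recurring pattern above resolves a slot's symbol to its defining GlobalSlotOp before creating the anchor, since anchors are now keyed on operations rather than symbol names. SymbolTable::lookupNearestSymbolFrom searches the nearest enclosing symbol table; the patch assumes the lookup succeeds, so the null check in this sketch is an added hedge:

// Resolve the slot symbol to its defining op, then anchor state on the op.
auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
    globalSlotSet, globalSlotSet.getSlotAttr());
if (!globalSlot) // not present in the patch; lookup is assumed to succeed
  return WalkResult::interrupt();
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
    getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));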
@@ -182,50 +181,32 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
 }
 
 LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
-  if (Value value = dyn_cast<Value>(point)) {
-    bool isSafe = isValueSafeTransferFunction(value);
-    auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
-    propagateIfChanged(state, state->setSafe(isSafe));
-
-    // Handle GlobalSlotGetOp's.
-    if (auto opResult = dyn_cast<OpResult>(value)) {
-      if (auto globalSlotGet =
-              dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
-        auto *flatSymbolRefPoint = getProgramPoint<FlatSymbolRefProgramPoint>(
-            globalSlotGet.getSlotAttr());
-        auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
-            flatSymbolRefPoint, globalSlotGet.getResult());
-        auto *globalState =
-            getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
-        propagateIfChanged(globalState,
-                           globalState->incorporateSafetyOfUse(valueState));
-      }
-    }
-    return success();
-  }
-  if (auto *genericProgramPoint = dyn_cast<GenericProgramPoint *>(point)) {
-    if (auto *flatSymbolRefPoint =
-            dyn_cast<FlatSymbolRefProgramPoint>(genericProgramPoint)) {
-      if (initializeGlobalSlotsOp) {
-        auto it =
-            llvm::find(initializeGlobalSlotsOp.getSlotSymNames(),
-                       static_cast<Attribute>(flatSymbolRefPoint->getValue()));
-        Value value = initializeGlobalSlotsOp->getOperand(std::distance(
-            initializeGlobalSlotsOp.getSlotSymNames().begin(), it));
-        auto *flatSymbolRefState =
-            getOrCreateFor<InlineGlobalSlotsAnalysisState>(value,
-                                                           flatSymbolRefPoint);
-        auto *valueState = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
-        propagateIfChanged(valueState,
-                           valueState->setSafe(flatSymbolRefState->isSafe));
-      }
-      return success();
-    }
-  }
-  LLVM_DEBUG(
-      { llvm::dbgs() << "visit failing because of: " << point << "\n"; });
-  return failure();
+  if (auto op = dyn_cast<Operation *>(point)) {
+    for (auto value : op->getResults()) {
+      bool isSafe = isValueSafeTransferFunction(value);
+      auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
+      propagateIfChanged(state, state->setSafe(isSafe));
+
+      // Handle GlobalSlotGetOp's.
+      if (auto opResult = dyn_cast<OpResult>(value)) {
+        if (auto globalSlotGet =
+                dyn_cast<Torch::GlobalSlotGetOp>(opResult.getOwner())) {
+          auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
+              globalSlotGet, globalSlotGet.getSlotAttr());
+          auto *flatSymbolRefPoint =
+              getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
+          auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
+              globalSlot, globalSlotGet.getResult());
+          auto *globalState =
+              getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
+          propagateIfChanged(globalState,
+                             globalState->incorporateSafetyOfUse(valueState));
+        }
+      }
+    }
+  }
+
+  return success();
 }
 
 // This is only a member function to access protected get* functions.
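visit still takes a ProgramPoint, but in the new framework a program point carries an operation here, so the function fans out over op->getResults() itself instead of being invoked once per Value, and the FlatSymbolRefProgramPoint branch is dropped since lattice anchors are no longer delivered to visit. For orientation, this analysis is driven by the usual MLIR solver loop; a hedged sketch of the standard wiring, assuming module is the ModuleOp being processed (the pass presumably does the equivalent, but this exact snippet is not part of the diff):

DataFlowSolver solver;
solver.load<InlineGlobalSlotsAnalysis>();
// initializeAndRun calls initialize(top), then visit() until a fixpoint.
if (failed(solver.initializeAndRun(module)))
  return signalPassFailure();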
@@ -241,16 +222,20 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
     // safe. This covers, for example, view-like ops that create aliases.
     if ((op->hasTrait<Torch::OpTrait::ReadOnly>() || isMemoryEffectFree(op)) &&
         llvm::all_of(op->getResults(), [&](Value result) {
-          auto *state =
-              getOrCreateFor<InlineGlobalSlotsAnalysisState>(value, result);
+          auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
+              value.getDefiningOp(), result);
           return state->isSafe;
         }))
       continue;
     if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
       auto symName = cast<FlatSymbolRefAttr>(
           initialize.getSlotSymNames()[use.getOperandNumber()]);
+      auto globalSlot =
+          SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(op, symName);
+
       auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
-          value, getProgramPoint<FlatSymbolRefProgramPoint>(symName));
+          value.getDefiningOp(),
+          getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
       if (state->isSafe)
         continue;
     }
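The other recurring edit swaps the dependent argument of getOrCreateFor from a Value to value.getDefiningOp(): the first argument names the program point to re-visit when the queried state changes, and after this bump it must be an operation rather than a value. Schematically (usage as in the hunk above; the inline parameter comments are annotations, not source):

auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
    /*dependent=*/value.getDefiningOp(), // re-visited when the state changes
    /*anchor=*/result);                  // the state being queried
return state->isSafe;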
@@ -299,8 +284,7 @@ class InlineGlobalSlotsPass
     module->walk([&](Operation *op) {
       if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
         auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>(
-            solver.getProgramPoint<FlatSymbolRefProgramPoint>(
-                FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())));
+            solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
         state->print(llvm::dbgs());
         llvm::dbgs() << ": "
                      << FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())
@@ -334,13 +318,16 @@ class InlineGlobalSlotsPass
       auto slotSymName =
           cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
       Value operand = initialize.getOperand(i);
-      auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>(
-          cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]));
+      auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
+          initialize, slotSymName);
+
+      auto symbolRefPoint =
+          solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
       auto *state =
           solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint);
       // We roll the analysis of whether a slot is set or public into the
       // main dataflow analysis, so we need to check the slot's
-      // FlatSymbolRefProgramPoint itself to see if it is safe to inline.
+      // FlatSymbolRefLatticeAnchor itself to see if it is safe to inline.
       // For example, a public !torch.int is not safe to inline, even though
       // it is a value-semantic type and so the actual initializer value
       // itself is conceptually safe to inline.
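The remaining hunks update FileCheck expectations in StableHLO conversion tests: each removes a CHECK line for an arith.constant ... : i64 that the bumped lowering apparently no longer produces at that point (the last hunk also drops adjacent ListConstruct checks).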
@@ -259,7 +259,6 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
 // CHECK: %[[T_5:.*]] = torch.constant.int 1
 // CHECK: %[[T_6:.*]] = torch.constant.int 4
 // CHECK: %[[T_7:.*]] = torch.constant.int 3
-// CHECK: %[[T_8:.*]] = arith.constant 3 : i64
 // CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
|
@@ -295,7 +294,6 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
 // CHECK: %int2 = torch.constant.int 2
 // CHECK: %int1 = torch.constant.int 1
 // CHECK: %int4 = torch.constant.int 4
-// CHECK: %[[T_3:.*]] = arith.constant 3 : i64
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -336,7 +334,6 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar
 // CHECK: %none = torch.constant.none
 // CHECK: %int0 = torch.constant.int 0
 // CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_5:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32>
@@ -367,7 +364,6 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,
 // CHECK: %none = torch.constant.none
 // CHECK: %int0 = torch.constant.int 0
 // CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
 // CHECK: %int2 = torch.constant.int 2
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -402,7 +398,6 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7
 // CHECK: %none = torch.constant.none
 // CHECK: %int0 = torch.constant.int 0
 // CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
 // CHECK: %int2 = torch.constant.int 2
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -438,10 +433,6 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
 // CHECK: %int0 = torch.constant.int 0
 // CHECK: %int1 = torch.constant.int 1
 // CHECK: %int2 = torch.constant.int 2
-// CHECK: %[[T_2:.*]] = arith.constant 2 : i64
-// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x2x3x3xf32>) -> tensor<3x3x2x2xf32>
 // CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32>
 // CHECK: %c0 = arith.constant 0 : index