mirror of https://github.com/llvm/torch-mlir
Adapt to RegionBranchPoint changes: https://reviews.llvm.org/D159116
parent
86b16ee84c
commit
37edb9f26e
|
@ -301,21 +301,20 @@ LogicalResult ClassTypeOp::verify() {
|
|||
// PrimLoopOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OperandRange
|
||||
PrimLoopOp::getEntrySuccessorOperands(std::optional<unsigned int> index) {
|
||||
assert(index.has_value() && index.value() == 0);
|
||||
OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) {
|
||||
assert(point == getRegion());
|
||||
return getIterArgsInit();
|
||||
}
|
||||
|
||||
void PrimLoopOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
|
||||
if (!index.has_value()) {
|
||||
regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1));
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
Region ®ion = getRegion();
|
||||
if (!point.getRegionOrNull()) {
|
||||
regions.emplace_back(®ion, region.getArguments().slice(1));
|
||||
return;
|
||||
}
|
||||
assert(*index == 0);
|
||||
regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1));
|
||||
assert(point == region);
|
||||
regions.emplace_back(®ion, region.getArguments().slice(1));
|
||||
regions.emplace_back(getResults());
|
||||
}
|
||||
|
||||
|
@ -328,8 +327,8 @@ bool PrimLoopOp::isForLike() {
|
|||
// PrimLoopConditionOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MutableOperandRange PrimLoopConditionOp::getMutableSuccessorOperands(
|
||||
std::optional<unsigned> index) {
|
||||
MutableOperandRange
|
||||
PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
|
||||
// Pass all operands except the condition to the successor which is the
|
||||
// parent loop op.
|
||||
return getIterArgsMutable();
|
||||
|
@ -378,10 +377,10 @@ void PrimIfOp::print(OpAsmPrinter &p) {
|
|||
p.printOptionalAttrDict((*this)->getAttrs());
|
||||
}
|
||||
|
||||
void PrimIfOp::getSuccessorRegions(std::optional<unsigned> index,
|
||||
void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
// The `then` and the `else` region branch back to the parent operation.
|
||||
if (index.has_value()) {
|
||||
if (point.getRegionOrNull()) {
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
return;
|
||||
}
|
||||
|
@ -2768,26 +2767,26 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
|
|||
|
||||
template <typename CalculateOp>
|
||||
static void
|
||||
getSuccessorRegionsForCalculateOp(CalculateOp op, std::optional<unsigned> index,
|
||||
getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (!index.has_value()) {
|
||||
if (!point.getRegionOrNull()) {
|
||||
// First thing the op does is branch into the calculation.
|
||||
regions.emplace_back(&op.getCalculation());
|
||||
return;
|
||||
}
|
||||
if (*index == 0) {
|
||||
if (point == op.getBody()) {
|
||||
// Body returns control to the outer op, passing through results.
|
||||
regions.emplace_back(op.getResults());
|
||||
return;
|
||||
}
|
||||
assert(*index == 1);
|
||||
assert(point == op.getCalculation());
|
||||
// Calculation branches to the body.
|
||||
regions.emplace_back(&op.getBody());
|
||||
}
|
||||
|
||||
void ShapeCalculateOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
getSuccessorRegionsForCalculateOp(*this, index, regions);
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
getSuccessorRegionsForCalculateOp(*this, point, regions);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2795,8 +2794,8 @@ void ShapeCalculateOp::getSuccessorRegions(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void DtypeCalculateOp::getSuccessorRegions(
|
||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
getSuccessorRegionsForCalculateOp(*this, index, regions);
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
getSuccessorRegionsForCalculateOp(*this, point, regions);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2804,7 +2803,7 @@ void DtypeCalculateOp::getSuccessorRegions(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
|
||||
std::optional<unsigned> index) {
|
||||
RegionBranchPoint point) {
|
||||
// The shape operands don't get forwarded to the body.
|
||||
// MutableOperandRange always has an owning operation, even if empty, so
|
||||
// create a 0-length range.
|
||||
|
@ -2823,7 +2822,7 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
|
||||
std::optional<unsigned> index) {
|
||||
RegionBranchPoint point) {
|
||||
// The dtype operands don't get forwarded to the body.
|
||||
// MutableOperandRange always has an owning operation, even if empty, so
|
||||
// create a 0-length range.
|
||||
|
|
Loading…
Reference in New Issue