mirror of https://github.com/llvm/torch-mlir
Adapt to RegionBranchPoint changes: https://reviews.llvm.org/D159116
parent
86b16ee84c
commit
37edb9f26e
|
@ -301,21 +301,20 @@ LogicalResult ClassTypeOp::verify() {
|
||||||
// PrimLoopOp
|
// PrimLoopOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OperandRange
|
OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) {
|
||||||
PrimLoopOp::getEntrySuccessorOperands(std::optional<unsigned int> index) {
|
assert(point == getRegion());
|
||||||
assert(index.has_value() && index.value() == 0);
|
|
||||||
return getIterArgsInit();
|
return getIterArgsInit();
|
||||||
}
|
}
|
||||||
|
|
||||||
void PrimLoopOp::getSuccessorRegions(
|
void PrimLoopOp::getSuccessorRegions(
|
||||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||||
|
Region ®ion = getRegion();
|
||||||
if (!index.has_value()) {
|
if (!point.getRegionOrNull()) {
|
||||||
regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1));
|
regions.emplace_back(®ion, region.getArguments().slice(1));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
assert(*index == 0);
|
assert(point == region);
|
||||||
regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1));
|
regions.emplace_back(®ion, region.getArguments().slice(1));
|
||||||
regions.emplace_back(getResults());
|
regions.emplace_back(getResults());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -328,8 +327,8 @@ bool PrimLoopOp::isForLike() {
|
||||||
// PrimLoopConditionOp
|
// PrimLoopConditionOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
MutableOperandRange PrimLoopConditionOp::getMutableSuccessorOperands(
|
MutableOperandRange
|
||||||
std::optional<unsigned> index) {
|
PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
|
||||||
// Pass all operands except the condition to the successor which is the
|
// Pass all operands except the condition to the successor which is the
|
||||||
// parent loop op.
|
// parent loop op.
|
||||||
return getIterArgsMutable();
|
return getIterArgsMutable();
|
||||||
|
@ -378,10 +377,10 @@ void PrimIfOp::print(OpAsmPrinter &p) {
|
||||||
p.printOptionalAttrDict((*this)->getAttrs());
|
p.printOptionalAttrDict((*this)->getAttrs());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PrimIfOp::getSuccessorRegions(std::optional<unsigned> index,
|
void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
|
||||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||||
// The `then` and the `else` region branch back to the parent operation.
|
// The `then` and the `else` region branch back to the parent operation.
|
||||||
if (index.has_value()) {
|
if (point.getRegionOrNull()) {
|
||||||
regions.push_back(RegionSuccessor(getResults()));
|
regions.push_back(RegionSuccessor(getResults()));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -2768,26 +2767,26 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
|
||||||
|
|
||||||
template <typename CalculateOp>
|
template <typename CalculateOp>
|
||||||
static void
|
static void
|
||||||
getSuccessorRegionsForCalculateOp(CalculateOp op, std::optional<unsigned> index,
|
getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point,
|
||||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||||
if (!index.has_value()) {
|
if (!point.getRegionOrNull()) {
|
||||||
// First thing the op does is branch into the calculation.
|
// First thing the op does is branch into the calculation.
|
||||||
regions.emplace_back(&op.getCalculation());
|
regions.emplace_back(&op.getCalculation());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (*index == 0) {
|
if (point == op.getBody()) {
|
||||||
// Body returns control to the outer op, passing through results.
|
// Body returns control to the outer op, passing through results.
|
||||||
regions.emplace_back(op.getResults());
|
regions.emplace_back(op.getResults());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
assert(*index == 1);
|
assert(point == op.getCalculation());
|
||||||
// Calculation branches to the body.
|
// Calculation branches to the body.
|
||||||
regions.emplace_back(&op.getBody());
|
regions.emplace_back(&op.getBody());
|
||||||
}
|
}
|
||||||
|
|
||||||
void ShapeCalculateOp::getSuccessorRegions(
|
void ShapeCalculateOp::getSuccessorRegions(
|
||||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||||
getSuccessorRegionsForCalculateOp(*this, index, regions);
|
getSuccessorRegionsForCalculateOp(*this, point, regions);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -2795,8 +2794,8 @@ void ShapeCalculateOp::getSuccessorRegions(
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void DtypeCalculateOp::getSuccessorRegions(
|
void DtypeCalculateOp::getSuccessorRegions(
|
||||||
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
|
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||||
getSuccessorRegionsForCalculateOp(*this, index, regions);
|
getSuccessorRegionsForCalculateOp(*this, point, regions);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -2804,7 +2803,7 @@ void DtypeCalculateOp::getSuccessorRegions(
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
|
MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
|
||||||
std::optional<unsigned> index) {
|
RegionBranchPoint point) {
|
||||||
// The shape operands don't get forwarded to the body.
|
// The shape operands don't get forwarded to the body.
|
||||||
// MutableOperandRange always has an owning operation, even if empty, so
|
// MutableOperandRange always has an owning operation, even if empty, so
|
||||||
// create a 0-length range.
|
// create a 0-length range.
|
||||||
|
@ -2823,7 +2822,7 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
|
MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
|
||||||
std::optional<unsigned> index) {
|
RegionBranchPoint point) {
|
||||||
// The dtype operands don't get forwarded to the body.
|
// The dtype operands don't get forwarded to the body.
|
||||||
// MutableOperandRange always has an owning operation, even if empty, so
|
// MutableOperandRange always has an owning operation, even if empty, so
|
||||||
// create a 0-length range.
|
// create a 0-length range.
|
||||||
|
|
Loading…
Reference in New Issue