torch: [nfc] use `WalkResult::isInterrupted()` instead of booleans (#1081)

An upstream MLIR bug (that was recently fixed) caused the result to be
ignored for Region- and Block-visitor functions.  Now that the bug is
fixed, we don't need an auxiliary boolean to track whether the visitor
function has succeeded.
pull/1089/head snapshot-20220720.539
Ashay Rane 2022-07-19 10:17:57 -07:00 committed by GitHub
parent 21f905afbe
commit e06ee08506
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 47 deletions

View File

@ -464,26 +464,17 @@ static LogicalResult verifyNnModuleValueUses(Value value) {
// Verify that `func` conforms to the subset of allowable method bodies
// that we can convert.
static LogicalResult verifyFuncConformsToSubset(func::FuncOp func) {
// TODO: Investingate why WalkResult::interrupt() doesn't propagate properly.
LogicalResult ret = success();
func.walk([&](Block *block) {
for (Value arg : block->getArguments()) {
if (failed(verifyNnModuleValueUses(arg))) {
ret = failure();
auto walkResult = func.walk([&](Block *block) {
for (Value arg : block->getArguments())
if (failed(verifyNnModuleValueUses(arg)))
return WalkResult::interrupt();
}
}
for (Operation &op : *block) {
for (Value result : op.getResults()) {
if (failed(verifyNnModuleValueUses(result))) {
ret = failure();
for (Operation &op : *block)
for (Value result : op.getResults())
if (failed(verifyNnModuleValueUses(result)))
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});
return ret;
return success(!walkResult.wasInterrupted());
}
static LogicalResult

View File

@ -30,28 +30,20 @@ class VerifyConversionToValueSemanticsPass
: public VerifyConversionToValueSemanticsBase<
VerifyConversionToValueSemanticsPass> {
void runOnOperation() override {
bool didFail = false;
auto walkResult = getOperation().walk([&](Block *block) {
for (BlockArgument arg : block->getArguments()) {
if (failed(checkValueType(block->getParentOp(), arg))) {
didFail = true;
for (BlockArgument arg : block->getArguments())
if (failed(checkValueType(block->getParentOp(), arg)))
return WalkResult::interrupt();
}
}
for (Operation &op : *block) {
for (OpResult result : op.getResults()) {
if (failed(checkValueType(&op, result))) {
didFail = true;
for (Operation &op : *block)
for (OpResult result : op.getResults())
if (failed(checkValueType(&op, result)))
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});
if (didFail || walkResult.wasInterrupted())
if (walkResult.wasInterrupted())
signalPassFailure();
}
};

View File

@ -42,19 +42,13 @@ class VerifyInvariantsBeforeBackendLoweringPass
: public VerifyInvariantsBeforeBackendLoweringBase<
VerifyInvariantsBeforeBackendLoweringPass> {
void runOnOperation() override {
// TODO: It seems that the walkers over blocks are not correctly
// propagating `walkResult.wasInterrupted()` so use a manual `didFail`
// boolean.
bool didFail = false;
getOperation().walk([&](Block *block) {
auto walkResult = getOperation().walk([&](Block *block) {
// Check invariants on all the Value's in the program.
// That is, check all BlockArgument's and OpResult's.
for (BlockArgument arg : block->getArguments()) {
if (failed(checkValueInvariants(block->getParentOp(), arg))) {
didFail = true;
for (BlockArgument arg : block->getArguments())
if (failed(checkValueInvariants(block->getParentOp(), arg)))
return WalkResult::interrupt();
}
}
for (Operation &op : *block) {
if (isa<Torch::OperatorOp>(op)) {
op.emitError()
@ -62,19 +56,16 @@ class VerifyInvariantsBeforeBackendLoweringPass
.attachNote()
.append("this is likely due to a missing op that needs to be "
"generated by torch_ods_gen.py");
didFail = true;
return WalkResult::interrupt();
}
for (OpResult result : op.getResults()) {
if (failed(checkValueInvariants(&op, result))) {
didFail = true;
for (OpResult result : op.getResults())
if (failed(checkValueInvariants(&op, result)))
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});
if (didFail)
if (walkResult.wasInterrupted())
return signalPassFailure();
}
};