torch: [nfc] use `WalkResult::isInterrupted()` instead of booleans (#1081)

An upstream MLIR bug (that was recently fixed) caused the result to be
ignored for Region- and Block-visitor functions.  Now that the bug is
fixed, we don't need an auxiliary boolean to track whether the visitor
function has succeeded.
pull/1089/head snapshot-20220720.539
Ashay Rane 2022-07-19 10:17:57 -07:00 committed by GitHub
parent 21f905afbe
commit e06ee08506
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 47 deletions

View File

@ -464,26 +464,17 @@ static LogicalResult verifyNnModuleValueUses(Value value) {
// Verify that `func` conforms to the subset of allowable method bodies // Verify that `func` conforms to the subset of allowable method bodies
// that we can convert. // that we can convert.
static LogicalResult verifyFuncConformsToSubset(func::FuncOp func) { static LogicalResult verifyFuncConformsToSubset(func::FuncOp func) {
// TODO: Investingate why WalkResult::interrupt() doesn't propagate properly. auto walkResult = func.walk([&](Block *block) {
LogicalResult ret = success(); for (Value arg : block->getArguments())
func.walk([&](Block *block) { if (failed(verifyNnModuleValueUses(arg)))
for (Value arg : block->getArguments()) {
if (failed(verifyNnModuleValueUses(arg))) {
ret = failure();
return WalkResult::interrupt(); return WalkResult::interrupt();
} for (Operation &op : *block)
} for (Value result : op.getResults())
for (Operation &op : *block) { if (failed(verifyNnModuleValueUses(result)))
for (Value result : op.getResults()) {
if (failed(verifyNnModuleValueUses(result))) {
ret = failure();
return WalkResult::interrupt(); return WalkResult::interrupt();
}
}
}
return WalkResult::advance(); return WalkResult::advance();
}); });
return ret; return success(!walkResult.wasInterrupted());
} }
static LogicalResult static LogicalResult

View File

@ -30,28 +30,20 @@ class VerifyConversionToValueSemanticsPass
: public VerifyConversionToValueSemanticsBase< : public VerifyConversionToValueSemanticsBase<
VerifyConversionToValueSemanticsPass> { VerifyConversionToValueSemanticsPass> {
void runOnOperation() override { void runOnOperation() override {
bool didFail = false;
auto walkResult = getOperation().walk([&](Block *block) { auto walkResult = getOperation().walk([&](Block *block) {
for (BlockArgument arg : block->getArguments()) { for (BlockArgument arg : block->getArguments())
if (failed(checkValueType(block->getParentOp(), arg))) { if (failed(checkValueType(block->getParentOp(), arg)))
didFail = true;
return WalkResult::interrupt(); return WalkResult::interrupt();
}
}
for (Operation &op : *block) { for (Operation &op : *block)
for (OpResult result : op.getResults()) { for (OpResult result : op.getResults())
if (failed(checkValueType(&op, result))) { if (failed(checkValueType(&op, result)))
didFail = true;
return WalkResult::interrupt(); return WalkResult::interrupt();
}
}
}
return WalkResult::advance(); return WalkResult::advance();
}); });
if (didFail || walkResult.wasInterrupted()) if (walkResult.wasInterrupted())
signalPassFailure(); signalPassFailure();
} }
}; };

View File

@ -42,19 +42,13 @@ class VerifyInvariantsBeforeBackendLoweringPass
: public VerifyInvariantsBeforeBackendLoweringBase< : public VerifyInvariantsBeforeBackendLoweringBase<
VerifyInvariantsBeforeBackendLoweringPass> { VerifyInvariantsBeforeBackendLoweringPass> {
void runOnOperation() override { void runOnOperation() override {
// TODO: It seems that the walkers over blocks are not correctly auto walkResult = getOperation().walk([&](Block *block) {
// propagating `walkResult.wasInterrupted()` so use a manual `didFail`
// boolean.
bool didFail = false;
getOperation().walk([&](Block *block) {
// Check invariants on all the Value's in the program. // Check invariants on all the Value's in the program.
// That is, check all BlockArgument's and OpResult's. // That is, check all BlockArgument's and OpResult's.
for (BlockArgument arg : block->getArguments()) { for (BlockArgument arg : block->getArguments())
if (failed(checkValueInvariants(block->getParentOp(), arg))) { if (failed(checkValueInvariants(block->getParentOp(), arg)))
didFail = true;
return WalkResult::interrupt(); return WalkResult::interrupt();
}
}
for (Operation &op : *block) { for (Operation &op : *block) {
if (isa<Torch::OperatorOp>(op)) { if (isa<Torch::OperatorOp>(op)) {
op.emitError() op.emitError()
@ -62,19 +56,16 @@ class VerifyInvariantsBeforeBackendLoweringPass
.attachNote() .attachNote()
.append("this is likely due to a missing op that needs to be " .append("this is likely due to a missing op that needs to be "
"generated by torch_ods_gen.py"); "generated by torch_ods_gen.py");
didFail = true;
return WalkResult::interrupt(); return WalkResult::interrupt();
} }
for (OpResult result : op.getResults()) {
if (failed(checkValueInvariants(&op, result))) { for (OpResult result : op.getResults())
didFail = true; if (failed(checkValueInvariants(&op, result)))
return WalkResult::interrupt(); return WalkResult::interrupt();
} }
}
}
return WalkResult::advance(); return WalkResult::advance();
}); });
if (didFail) if (walkResult.wasInterrupted())
return signalPassFailure(); return signalPassFailure();
} }
}; };