mirror of https://github.com/llvm/torch-mlir
torch: [nfc] use `WalkResult::isInterrupted()` instead of booleans (#1081)
An upstream MLIR bug (that was recently fixed) caused the result to be ignored for Region- and Block-visitor functions. Now that the bug is fixed, we don't need an auxiliary boolean to track whether the visitor function has succeeded.pull/1089/head snapshot-20220720.539
parent
21f905afbe
commit
e06ee08506
|
@ -464,26 +464,17 @@ static LogicalResult verifyNnModuleValueUses(Value value) {
|
|||
// Verify that `func` conforms to the subset of allowable method bodies
|
||||
// that we can convert.
|
||||
static LogicalResult verifyFuncConformsToSubset(func::FuncOp func) {
|
||||
// TODO: Investingate why WalkResult::interrupt() doesn't propagate properly.
|
||||
LogicalResult ret = success();
|
||||
func.walk([&](Block *block) {
|
||||
for (Value arg : block->getArguments()) {
|
||||
if (failed(verifyNnModuleValueUses(arg))) {
|
||||
ret = failure();
|
||||
auto walkResult = func.walk([&](Block *block) {
|
||||
for (Value arg : block->getArguments())
|
||||
if (failed(verifyNnModuleValueUses(arg)))
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
for (Operation &op : *block) {
|
||||
for (Value result : op.getResults()) {
|
||||
if (failed(verifyNnModuleValueUses(result))) {
|
||||
ret = failure();
|
||||
for (Operation &op : *block)
|
||||
for (Value result : op.getResults())
|
||||
if (failed(verifyNnModuleValueUses(result)))
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return ret;
|
||||
return success(!walkResult.wasInterrupted());
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
|
|
|
@ -30,28 +30,20 @@ class VerifyConversionToValueSemanticsPass
|
|||
: public VerifyConversionToValueSemanticsBase<
|
||||
VerifyConversionToValueSemanticsPass> {
|
||||
void runOnOperation() override {
|
||||
bool didFail = false;
|
||||
auto walkResult = getOperation().walk([&](Block *block) {
|
||||
for (BlockArgument arg : block->getArguments()) {
|
||||
if (failed(checkValueType(block->getParentOp(), arg))) {
|
||||
didFail = true;
|
||||
for (BlockArgument arg : block->getArguments())
|
||||
if (failed(checkValueType(block->getParentOp(), arg)))
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
for (Operation &op : *block) {
|
||||
for (OpResult result : op.getResults()) {
|
||||
if (failed(checkValueType(&op, result))) {
|
||||
didFail = true;
|
||||
for (Operation &op : *block)
|
||||
for (OpResult result : op.getResults())
|
||||
if (failed(checkValueType(&op, result)))
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (didFail || walkResult.wasInterrupted())
|
||||
if (walkResult.wasInterrupted())
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -42,19 +42,13 @@ class VerifyInvariantsBeforeBackendLoweringPass
|
|||
: public VerifyInvariantsBeforeBackendLoweringBase<
|
||||
VerifyInvariantsBeforeBackendLoweringPass> {
|
||||
void runOnOperation() override {
|
||||
// TODO: It seems that the walkers over blocks are not correctly
|
||||
// propagating `walkResult.wasInterrupted()` so use a manual `didFail`
|
||||
// boolean.
|
||||
bool didFail = false;
|
||||
getOperation().walk([&](Block *block) {
|
||||
auto walkResult = getOperation().walk([&](Block *block) {
|
||||
// Check invariants on all the Value's in the program.
|
||||
// That is, check all BlockArgument's and OpResult's.
|
||||
for (BlockArgument arg : block->getArguments()) {
|
||||
if (failed(checkValueInvariants(block->getParentOp(), arg))) {
|
||||
didFail = true;
|
||||
for (BlockArgument arg : block->getArguments())
|
||||
if (failed(checkValueInvariants(block->getParentOp(), arg)))
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
for (Operation &op : *block) {
|
||||
if (isa<Torch::OperatorOp>(op)) {
|
||||
op.emitError()
|
||||
|
@ -62,19 +56,16 @@ class VerifyInvariantsBeforeBackendLoweringPass
|
|||
.attachNote()
|
||||
.append("this is likely due to a missing op that needs to be "
|
||||
"generated by torch_ods_gen.py");
|
||||
didFail = true;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
for (OpResult result : op.getResults()) {
|
||||
if (failed(checkValueInvariants(&op, result))) {
|
||||
didFail = true;
|
||||
|
||||
for (OpResult result : op.getResults())
|
||||
if (failed(checkValueInvariants(&op, result)))
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
if (didFail)
|
||||
if (walkResult.wasInterrupted())
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue