diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index 8ca07604d..6ca39f379 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -464,26 +464,17 @@ static LogicalResult verifyNnModuleValueUses(Value value) { // Verify that `func` conforms to the subset of allowable method bodies // that we can convert. static LogicalResult verifyFuncConformsToSubset(func::FuncOp func) { - // TODO: Investingate why WalkResult::interrupt() doesn't propagate properly. - LogicalResult ret = success(); - func.walk([&](Block *block) { - for (Value arg : block->getArguments()) { - if (failed(verifyNnModuleValueUses(arg))) { - ret = failure(); + auto walkResult = func.walk([&](Block *block) { + for (Value arg : block->getArguments()) + if (failed(verifyNnModuleValueUses(arg))) return WalkResult::interrupt(); - } - } - for (Operation &op : *block) { - for (Value result : op.getResults()) { - if (failed(verifyNnModuleValueUses(result))) { - ret = failure(); + for (Operation &op : *block) + for (Value result : op.getResults()) + if (failed(verifyNnModuleValueUses(result))) return WalkResult::interrupt(); - } - } - } return WalkResult::advance(); }); - return ret; + return success(!walkResult.wasInterrupted()); } static LogicalResult diff --git a/lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp b/lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp index 0bbce7320..7a055ece3 100644 --- a/lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp @@ -30,28 +30,20 @@ class VerifyConversionToValueSemanticsPass : public VerifyConversionToValueSemanticsBase< VerifyConversionToValueSemanticsPass> { void runOnOperation() override { - bool didFail = false; auto walkResult = getOperation().walk([&](Block *block) { - for (BlockArgument arg : block->getArguments()) { - if (failed(checkValueType(block->getParentOp(), arg))) { - didFail = true; + for (BlockArgument arg : block->getArguments()) + if (failed(checkValueType(block->getParentOp(), arg))) return WalkResult::interrupt(); - } - } - for (Operation &op : *block) { - for (OpResult result : op.getResults()) { - if (failed(checkValueType(&op, result))) { - didFail = true; + for (Operation &op : *block) + for (OpResult result : op.getResults()) + if (failed(checkValueType(&op, result))) return WalkResult::interrupt(); - } - } - } return WalkResult::advance(); }); - if (didFail || walkResult.wasInterrupted()) + if (walkResult.wasInterrupted()) signalPassFailure(); } }; diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyInvariantsBeforeBackendLowering.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyInvariantsBeforeBackendLowering.cpp index 5e0406ee5..0b822ec1a 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyInvariantsBeforeBackendLowering.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyInvariantsBeforeBackendLowering.cpp @@ -42,19 +42,13 @@ class VerifyInvariantsBeforeBackendLoweringPass : public VerifyInvariantsBeforeBackendLoweringBase< VerifyInvariantsBeforeBackendLoweringPass> { void runOnOperation() override { - // TODO: It seems that the walkers over blocks are not correctly - // propagating `walkResult.wasInterrupted()` so use a manual `didFail` - // boolean. - bool didFail = false; - getOperation().walk([&](Block *block) { + auto walkResult = getOperation().walk([&](Block *block) { // Check invariants on all the Value's in the program. // That is, check all BlockArgument's and OpResult's. - for (BlockArgument arg : block->getArguments()) { - if (failed(checkValueInvariants(block->getParentOp(), arg))) { - didFail = true; + for (BlockArgument arg : block->getArguments()) + if (failed(checkValueInvariants(block->getParentOp(), arg))) return WalkResult::interrupt(); - } - } + for (Operation &op : *block) { if (isa(op)) { op.emitError() @@ -62,19 +56,16 @@ class VerifyInvariantsBeforeBackendLoweringPass .attachNote() .append("this is likely due to a missing op that needs to be " "generated by torch_ods_gen.py"); - didFail = true; return WalkResult::interrupt(); } - for (OpResult result : op.getResults()) { - if (failed(checkValueInvariants(&op, result))) { - didFail = true; + + for (OpResult result : op.getResults()) + if (failed(checkValueInvariants(&op, result))) return WalkResult::interrupt(); - } - } } return WalkResult::advance(); }); - if (didFail) + if (walkResult.wasInterrupted()) return signalPassFailure(); } };