LowerToBackendContract: Explicitly error out on unimplemented operator (#1947)

* LowerToBackendContract: Explicitly error out on unimplemented operator

But only reject torch.operator when results are invalid.
Otherwise it might be a custom op that the backend supports.
pull/1955/head
Matthias Gehre 2023-03-20 16:27:08 +01:00 committed by GitHub
parent 6eeed46060
commit aa5bcb3cf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 0 deletions

View File

@ -197,6 +197,24 @@ static bool satisfiesBackendContract(ModuleOp module,
if (walkResult0.wasInterrupted())
return false;
// Check for unimplemented operators first to give more direct diagnostics.
walkResult0 = module.walk([&](Torch::OperatorOp op) {
if (llvm::all_of(op.getResults(), [&op](auto res) {
return succeeded(
checkType(op.getOperation(), res.getType(), /*actuallyEmitDiagnostics=*/false));
})) {
return WalkResult::advance();
}
if (actuallyEmitDiagnostics) {
op->emitError("unsupported by backend contract: Unimplemented operator '"
+ op.getName() + "'");
}
return WalkResult::interrupt();
});
if (walkResult0.wasInterrupted())
return false;
// Check all the types of all Value's in the program and the legality of all
// the ops.
//

View File

@ -0,0 +1,10 @@
// RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s
func.func @forward(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor {
%none = torch.constant.none
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[3,5],f32> to !torch.vtensor<*,f32>
%1 = torch.copy.to_tensor %0 : !torch.tensor<*,f32>
// expected-error @+1 {{unsupported by backend contract: Unimplemented operator 'an.unimplemented.op'}}
%2 = torch.operator "an.unimplemented.op"(%1, %1, %none) : (!torch.tensor<*,f32>, !torch.tensor<*,f32>, !torch.none) -> !torch.tensor
%3 = torch.copy.to_vtensor %2 : !torch.vtensor
return %3 : !torch.vtensor
}