Shore up error reporting for TorchScript import.

This code was not exception safe -- it would leave an operation
unattached to anything, which breaks MLIR's C++ data structure
invariants (e.g. it cannot safely erase ops).

Also, print out both the exception and any diagnostics, since they can
both contain useful information.
pull/217/head
Sean Silva 2021-05-18 12:48:22 -07:00
parent d50ea8d31e
commit 0c89296075
2 changed files with 12 additions and 4 deletions

View File

@ -18,6 +18,8 @@
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/Types.h"
#include "caffe2/core/scope_guard.h"
using namespace torch_mlir;
// Hashing functionality for IValue's.
@ -146,7 +148,8 @@ private:
};
} // namespace
MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
MlirValue
IValueImporter::importModule(torch::jit::Module currentModule) {
// TODO: Can we do better?
MlirLocation loc = mlirLocationUnknownGet(context);
@ -170,6 +173,10 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
auto inserter = caffe2::MakeGuard([&]() {
mlirBlockInsertOwnedOperationBefore(
importBlock, mlirBlockGetTerminator(importBlock), nnModule);
});
if (!rootModuleName.has_value()) {
rootModuleName = moduleTypeName;
@ -198,8 +205,6 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
}
createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc);
mlirBlockInsertOwnedOperationBefore(
importBlock, mlirBlockGetTerminator(importBlock), nnModule);
return mlirOperationGetResult(nnModule, 0);
}

View File

@ -39,7 +39,10 @@ class RefBackendTestConfig(TestConfig):
mb.import_module(scripted._c, class_annotator)
except Exception as e:
raise Exception(f"""
PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with the following diagnostics:
PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with:
Exception:
{e}
Diagnostics:
{sys.stderr.getvalue()}
""") from None
finally: