mirror of https://github.com/llvm/torch-mlir
Shore up error reporting for TorchScript import.
This code was not exception safe -- it would leave an operation unattached to anything, which breaks MLIR's C++ data structure invariants (e.g. it cannot safely erase ops). Also, print out both the exception and any diagnostics, since they can both contain useful information.pull/217/head
parent
d50ea8d31e
commit
0c89296075
|
@ -18,6 +18,8 @@
|
||||||
#include "mlir-c/Diagnostics.h"
|
#include "mlir-c/Diagnostics.h"
|
||||||
#include "npcomp-c/Types.h"
|
#include "npcomp-c/Types.h"
|
||||||
|
|
||||||
|
#include "caffe2/core/scope_guard.h"
|
||||||
|
|
||||||
using namespace torch_mlir;
|
using namespace torch_mlir;
|
||||||
|
|
||||||
// Hashing functionality for IValue's.
|
// Hashing functionality for IValue's.
|
||||||
|
@ -146,7 +148,8 @@ private:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
MlirValue
|
||||||
|
IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||||
// TODO: Can we do better?
|
// TODO: Can we do better?
|
||||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
|
|
||||||
|
@ -170,6 +173,10 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||||
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
||||||
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
|
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
|
||||||
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
|
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
|
||||||
|
auto inserter = caffe2::MakeGuard([&]() {
|
||||||
|
mlirBlockInsertOwnedOperationBefore(
|
||||||
|
importBlock, mlirBlockGetTerminator(importBlock), nnModule);
|
||||||
|
});
|
||||||
|
|
||||||
if (!rootModuleName.has_value()) {
|
if (!rootModuleName.has_value()) {
|
||||||
rootModuleName = moduleTypeName;
|
rootModuleName = moduleTypeName;
|
||||||
|
@ -198,8 +205,6 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||||
}
|
}
|
||||||
|
|
||||||
createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc);
|
createMlirOperationAtEnd(nnModuleBody, "torch.nn_module_terminator", loc);
|
||||||
mlirBlockInsertOwnedOperationBefore(
|
|
||||||
importBlock, mlirBlockGetTerminator(importBlock), nnModule);
|
|
||||||
return mlirOperationGetResult(nnModule, 0);
|
return mlirOperationGetResult(nnModule, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,10 @@ class RefBackendTestConfig(TestConfig):
|
||||||
mb.import_module(scripted._c, class_annotator)
|
mb.import_module(scripted._c, class_annotator)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"""
|
raise Exception(f"""
|
||||||
PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with the following diagnostics:
|
PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with:
|
||||||
|
Exception:
|
||||||
|
{e}
|
||||||
|
Diagnostics:
|
||||||
{sys.stderr.getvalue()}
|
{sys.stderr.getvalue()}
|
||||||
""") from None
|
""") from None
|
||||||
finally:
|
finally:
|
||||||
|
|
Loading…
Reference in New Issue