mirror of https://github.com/llvm/torch-mlir
Add MLIR diagnostic handler that prints to `sys.stderr`.
This is needed so that output shows up properly in a Jupyter notebook.pull/160/head
parent
572163dfde
commit
498979ad28
|
@ -13,6 +13,7 @@
|
|||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "npcomp-c/Registration.h"
|
||||
|
||||
|
@ -56,6 +57,32 @@ static MlirModule createEmptyModule(MlirContext context) {
|
|||
return mlirModuleCreateEmpty(loc);
|
||||
}
|
||||
|
||||
// Register a diagnostic handler that will redirect output to `sys.stderr`
|
||||
// instead of a C/C++-level file abstraction. This ensures, for example,
|
||||
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks.
|
||||
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
|
||||
auto diagnosticHandler = [](MlirDiagnostic diagnostic,
|
||||
void *) -> MlirLogicalResult {
|
||||
std::stringstream ss;
|
||||
auto stringCallback =
|
||||
[](MlirStringRef s, void *stringCallbackUserData) {
|
||||
auto *ssp = static_cast<std::stringstream *>(stringCallbackUserData);
|
||||
ssp->write(s.data, s.length);
|
||||
};
|
||||
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void *>(&ss));
|
||||
// Use pybind11's print:
|
||||
// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
|
||||
using namespace pybind11::literals;
|
||||
py::print(ss.str(), "file"_a = py::module_::import("sys").attr("stderr"));
|
||||
return mlirLogicalResultSuccess();
|
||||
};
|
||||
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
|
||||
context, diagnosticHandler, nullptr, [](void *) { return; });
|
||||
// Ignore the ID. We intend to keep this handler for the entire lifetime
|
||||
// of this context.
|
||||
(void)id;
|
||||
}
|
||||
|
||||
ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
|
||||
: contextObj(createPythonContextIfNone(std::move(contextObj))),
|
||||
context(castPythonObjectToMlirContext(this->contextObj)),
|
||||
|
@ -67,6 +94,8 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
|
|||
mlirRegisterAllDialects(context);
|
||||
npcompRegisterAllDialects(context);
|
||||
|
||||
registerPythonSysStderrDiagnosticHandler(context);
|
||||
|
||||
// Terminator will always be the first op of an empty module.
|
||||
terminator = mlirBlockGetFirstOperation(getBodyBlock());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue