Add MLIR diagnostic handler that prints to `sys.stderr`.

This is needed so that output shows up properly in a Jupyter notebook.
pull/160/head
Sean Silva 2021-02-17 15:50:13 -08:00
parent 572163dfde
commit 498979ad28
1 changed files with 29 additions and 0 deletions

View File

@ -13,6 +13,7 @@
#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h" #include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/Registration.h" #include "mlir-c/Registration.h"
#include "npcomp-c/Registration.h" #include "npcomp-c/Registration.h"
@ -56,6 +57,32 @@ static MlirModule createEmptyModule(MlirContext context) {
return mlirModuleCreateEmpty(loc); return mlirModuleCreateEmpty(loc);
} }
// Register a diagnostic handler that will redirect output to `sys.stderr`
// instead of a C/C++-level file abstraction. This ensures, for example,
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks.
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
auto diagnosticHandler = [](MlirDiagnostic diagnostic,
void *) -> MlirLogicalResult {
std::stringstream ss;
auto stringCallback =
[](MlirStringRef s, void *stringCallbackUserData) {
auto *ssp = static_cast<std::stringstream *>(stringCallbackUserData);
ssp->write(s.data, s.length);
};
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void *>(&ss));
// Use pybind11's print:
// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
using namespace pybind11::literals;
py::print(ss.str(), "file"_a = py::module_::import("sys").attr("stderr"));
return mlirLogicalResultSuccess();
};
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
context, diagnosticHandler, nullptr, [](void *) { return; });
// Ignore the ID. We intend to keep this handler for the entire lifetime
// of this context.
(void)id;
}
ModuleBuilder::ModuleBuilder(pybind11::object contextObj) ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
: contextObj(createPythonContextIfNone(std::move(contextObj))), : contextObj(createPythonContextIfNone(std::move(contextObj))),
context(castPythonObjectToMlirContext(this->contextObj)), context(castPythonObjectToMlirContext(this->contextObj)),
@ -67,6 +94,8 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
mlirRegisterAllDialects(context); mlirRegisterAllDialects(context);
npcompRegisterAllDialects(context); npcompRegisterAllDialects(context);
registerPythonSysStderrDiagnosticHandler(context);
// Terminator will always be the first op of an empty module. // Terminator will always be the first op of an empty module.
terminator = mlirBlockGetFirstOperation(getBodyBlock()); terminator = mlirBlockGetFirstOperation(getBodyBlock());
} }