mirror of https://github.com/llvm/torch-mlir
Add MLIR diagnostic handler that prints to `sys.stderr`.
This is needed so that output shows up properly in a Jupyter notebook.pull/160/head
parent
572163dfde
commit
498979ad28
|
@ -13,6 +13,7 @@
|
||||||
#include "mlir-c/Bindings/Python/Interop.h"
|
#include "mlir-c/Bindings/Python/Interop.h"
|
||||||
#include "mlir-c/BuiltinAttributes.h"
|
#include "mlir-c/BuiltinAttributes.h"
|
||||||
#include "mlir-c/BuiltinTypes.h"
|
#include "mlir-c/BuiltinTypes.h"
|
||||||
|
#include "mlir-c/Diagnostics.h"
|
||||||
#include "mlir-c/Registration.h"
|
#include "mlir-c/Registration.h"
|
||||||
#include "npcomp-c/Registration.h"
|
#include "npcomp-c/Registration.h"
|
||||||
|
|
||||||
|
@ -56,6 +57,32 @@ static MlirModule createEmptyModule(MlirContext context) {
|
||||||
return mlirModuleCreateEmpty(loc);
|
return mlirModuleCreateEmpty(loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Register a diagnostic handler that will redirect output to `sys.stderr`
|
||||||
|
// instead of a C/C++-level file abstraction. This ensures, for example,
|
||||||
|
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks.
|
||||||
|
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
|
||||||
|
auto diagnosticHandler = [](MlirDiagnostic diagnostic,
|
||||||
|
void *) -> MlirLogicalResult {
|
||||||
|
std::stringstream ss;
|
||||||
|
auto stringCallback =
|
||||||
|
[](MlirStringRef s, void *stringCallbackUserData) {
|
||||||
|
auto *ssp = static_cast<std::stringstream *>(stringCallbackUserData);
|
||||||
|
ssp->write(s.data, s.length);
|
||||||
|
};
|
||||||
|
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void *>(&ss));
|
||||||
|
// Use pybind11's print:
|
||||||
|
// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
|
||||||
|
using namespace pybind11::literals;
|
||||||
|
py::print(ss.str(), "file"_a = py::module_::import("sys").attr("stderr"));
|
||||||
|
return mlirLogicalResultSuccess();
|
||||||
|
};
|
||||||
|
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
|
||||||
|
context, diagnosticHandler, nullptr, [](void *) { return; });
|
||||||
|
// Ignore the ID. We intend to keep this handler for the entire lifetime
|
||||||
|
// of this context.
|
||||||
|
(void)id;
|
||||||
|
}
|
||||||
|
|
||||||
ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
|
ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
|
||||||
: contextObj(createPythonContextIfNone(std::move(contextObj))),
|
: contextObj(createPythonContextIfNone(std::move(contextObj))),
|
||||||
context(castPythonObjectToMlirContext(this->contextObj)),
|
context(castPythonObjectToMlirContext(this->contextObj)),
|
||||||
|
@ -67,6 +94,8 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
|
||||||
mlirRegisterAllDialects(context);
|
mlirRegisterAllDialects(context);
|
||||||
npcompRegisterAllDialects(context);
|
npcompRegisterAllDialects(context);
|
||||||
|
|
||||||
|
registerPythonSysStderrDiagnosticHandler(context);
|
||||||
|
|
||||||
// Terminator will always be the first op of an empty module.
|
// Terminator will always be the first op of an empty module.
|
||||||
terminator = mlirBlockGetFirstOperation(getBodyBlock());
|
terminator = mlirBlockGetFirstOperation(getBodyBlock());
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue