From 498979ad2886398f35bef100280eff0ea6fa086b Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Wed, 17 Feb 2021 15:50:13 -0800 Subject: [PATCH] Add MLIR diagnostic handler that prints to `sys.stderr`. This is needed so that output shows up properly in a Jupyter notebook. --- .../pytorch/csrc/builder/module_builder.cpp | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/frontends/pytorch/csrc/builder/module_builder.cpp b/frontends/pytorch/csrc/builder/module_builder.cpp index 13e718378..213835ea9 100644 --- a/frontends/pytorch/csrc/builder/module_builder.cpp +++ b/frontends/pytorch/csrc/builder/module_builder.cpp @@ -13,6 +13,7 @@ #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" +#include "mlir-c/Diagnostics.h" #include "mlir-c/Registration.h" #include "npcomp-c/Registration.h" @@ -56,6 +57,32 @@ static MlirModule createEmptyModule(MlirContext context) { return mlirModuleCreateEmpty(loc); } +// Register a diagnostic handler that will redirect output to `sys.stderr` +// instead of a C/C++-level file abstraction. This ensures, for example, +// that mlir diagnostics emitted are correctly routed in Jupyter notebooks. +static void registerPythonSysStderrDiagnosticHandler(MlirContext context) { + auto diagnosticHandler = [](MlirDiagnostic diagnostic, + void *) -> MlirLogicalResult { + std::stringstream ss; + auto stringCallback = + [](MlirStringRef s, void *stringCallbackUserData) { + auto *ssp = static_cast(stringCallbackUserData); + ssp->write(s.data, s.length); + }; + mlirDiagnosticPrint(diagnostic, stringCallback, static_cast(&ss)); + // Use pybind11's print: + // https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html + using namespace pybind11::literals; + py::print(ss.str(), "file"_a = py::module_::import("sys").attr("stderr")); + return mlirLogicalResultSuccess(); + }; + MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler( + context, diagnosticHandler, nullptr, [](void *) { return; }); + // Ignore the ID. We intend to keep this handler for the entire lifetime + // of this context. + (void)id; +} + ModuleBuilder::ModuleBuilder(pybind11::object contextObj) : contextObj(createPythonContextIfNone(std::move(contextObj))), context(castPythonObjectToMlirContext(this->contextObj)), @@ -67,6 +94,8 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj) mlirRegisterAllDialects(context); npcompRegisterAllDialects(context); + registerPythonSysStderrDiagnosticHandler(context); + // Terminator will always be the first op of an empty module. terminator = mlirBlockGetFirstOperation(getBodyBlock()); }