From d50ea8d31e7b641dccf090bc8e5b62c7680b8126 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Tue, 18 May 2021 12:17:44 -0700 Subject: [PATCH] Improve diagnostic handler It wasn't printing notes or putting the "error:" in front. --- .../pytorch/csrc/builder/module_builder.cpp | 46 ++++++++++++++----- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/frontends/pytorch/csrc/builder/module_builder.cpp b/frontends/pytorch/csrc/builder/module_builder.cpp index 54e5454d8..6c34c4e4f 100644 --- a/frontends/pytorch/csrc/builder/module_builder.cpp +++ b/frontends/pytorch/csrc/builder/module_builder.cpp @@ -55,23 +55,47 @@ static MlirModule createEmptyModule(MlirContext context) { return mlirModuleCreateEmpty(loc); } +static std::string +stringifyMlirDiagnosticSeverity(MlirDiagnosticSeverity severity) { + switch (severity) { + case MlirDiagnosticError: + return "error"; + case MlirDiagnosticWarning: + return "warning"; + case MlirDiagnosticNote: + return "note"; + case MlirDiagnosticRemark: + return "remark"; + default: + return ""; + } +} + +static void printDiagnostic(MlirDiagnostic diagnostic) { + std::stringstream ss; + ss << stringifyMlirDiagnosticSeverity(mlirDiagnosticGetSeverity(diagnostic)) + << ": "; + auto stringCallback = [](MlirStringRef s, void *stringCallbackUserData) { + auto *ssp = static_cast(stringCallbackUserData); + ssp->write(s.data, s.length); + }; + mlirDiagnosticPrint(diagnostic, stringCallback, static_cast(&ss)); + // Use pybind11's print: + // https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html + py::print(ss.str(), + py::arg("file") = py::module_::import("sys").attr("stderr")); +} + // Register a diagnostic handler that will redirect output to `sys.stderr` // instead of a C/C++-level file abstraction. This ensures, for example, // that mlir diagnostics emitted are correctly routed in Jupyter notebooks. static void registerPythonSysStderrDiagnosticHandler(MlirContext context) { auto diagnosticHandler = [](MlirDiagnostic diagnostic, void *) -> MlirLogicalResult { - std::stringstream ss; - auto stringCallback = - [](MlirStringRef s, void *stringCallbackUserData) { - auto *ssp = static_cast(stringCallbackUserData); - ssp->write(s.data, s.length); - }; - mlirDiagnosticPrint(diagnostic, stringCallback, static_cast(&ss)); - // Use pybind11's print: - // https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html - using namespace pybind11::literals; - py::print(ss.str(), "file"_a = py::module_::import("sys").attr("stderr")); + printDiagnostic(diagnostic); + for (int i = 0, e = mlirDiagnosticGetNumNotes(diagnostic); i != e; i++) { + printDiagnostic(mlirDiagnosticGetNote(diagnostic, i)); + } return mlirLogicalResultSuccess(); }; MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(