Improve diagnostic handler

It wasn't printing notes or putting the "error:" in front.
pull/217/head
Sean Silva 2021-05-18 12:17:44 -07:00
parent 2453805f7f
commit d50ea8d31e
1 changed files with 35 additions and 11 deletions

View File

@ -55,23 +55,47 @@ static MlirModule createEmptyModule(MlirContext context) {
return mlirModuleCreateEmpty(loc); return mlirModuleCreateEmpty(loc);
} }
// Register a diagnostic handler that will redirect output to `sys.stderr` static std::string
// instead of a C/C++-level file abstraction. This ensures, for example, stringifyMlirDiagnosticSeverity(MlirDiagnosticSeverity severity) {
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks. switch (severity) {
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) { case MlirDiagnosticError:
auto diagnosticHandler = [](MlirDiagnostic diagnostic, return "error";
void *) -> MlirLogicalResult { case MlirDiagnosticWarning:
return "warning";
case MlirDiagnosticNote:
return "note";
case MlirDiagnosticRemark:
return "remark";
default:
return "<unknown severity>";
}
}
static void printDiagnostic(MlirDiagnostic diagnostic) {
std::stringstream ss; std::stringstream ss;
auto stringCallback = ss << stringifyMlirDiagnosticSeverity(mlirDiagnosticGetSeverity(diagnostic))
[](MlirStringRef s, void *stringCallbackUserData) { << ": ";
auto stringCallback = [](MlirStringRef s, void *stringCallbackUserData) {
auto *ssp = static_cast<std::stringstream *>(stringCallbackUserData); auto *ssp = static_cast<std::stringstream *>(stringCallbackUserData);
ssp->write(s.data, s.length); ssp->write(s.data, s.length);
}; };
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void *>(&ss)); mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void *>(&ss));
// Use pybind11's print: // Use pybind11's print:
// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html // https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
using namespace pybind11::literals; py::print(ss.str(),
py::print(ss.str(), "file"_a = py::module_::import("sys").attr("stderr")); py::arg("file") = py::module_::import("sys").attr("stderr"));
}
// Register a diagnostic handler that will redirect output to `sys.stderr`
// instead of a C/C++-level file abstraction. This ensures, for example,
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks.
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
auto diagnosticHandler = [](MlirDiagnostic diagnostic,
void *) -> MlirLogicalResult {
printDiagnostic(diagnostic);
for (int i = 0, e = mlirDiagnosticGetNumNotes(diagnostic); i != e; i++) {
printDiagnostic(mlirDiagnosticGetNote(diagnostic, i));
}
return mlirLogicalResultSuccess(); return mlirLogicalResultSuccess();
}; };
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler( MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(