mirror of https://github.com/llvm/torch-mlir
Improve diagnostic handler
It wasn't printing notes or putting the "error:" in front.pull/217/head
parent
2453805f7f
commit
d50ea8d31e
|
@ -55,23 +55,47 @@ static MlirModule createEmptyModule(MlirContext context) {
|
|||
return mlirModuleCreateEmpty(loc);
|
||||
}
|
||||
|
||||
static std::string
|
||||
stringifyMlirDiagnosticSeverity(MlirDiagnosticSeverity severity) {
|
||||
switch (severity) {
|
||||
case MlirDiagnosticError:
|
||||
return "error";
|
||||
case MlirDiagnosticWarning:
|
||||
return "warning";
|
||||
case MlirDiagnosticNote:
|
||||
return "note";
|
||||
case MlirDiagnosticRemark:
|
||||
return "remark";
|
||||
default:
|
||||
return "<unknown severity>";
|
||||
}
|
||||
}
|
||||
|
||||
static void printDiagnostic(MlirDiagnostic diagnostic) {
|
||||
std::stringstream ss;
|
||||
ss << stringifyMlirDiagnosticSeverity(mlirDiagnosticGetSeverity(diagnostic))
|
||||
<< ": ";
|
||||
auto stringCallback = [](MlirStringRef s, void *stringCallbackUserData) {
|
||||
auto *ssp = static_cast<std::stringstream *>(stringCallbackUserData);
|
||||
ssp->write(s.data, s.length);
|
||||
};
|
||||
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void *>(&ss));
|
||||
// Use pybind11's print:
|
||||
// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
|
||||
py::print(ss.str(),
|
||||
py::arg("file") = py::module_::import("sys").attr("stderr"));
|
||||
}
|
||||
|
||||
// Register a diagnostic handler that will redirect output to `sys.stderr`
|
||||
// instead of a C/C++-level file abstraction. This ensures, for example,
|
||||
// that mlir diagnostics emitted are correctly routed in Jupyter notebooks.
|
||||
static void registerPythonSysStderrDiagnosticHandler(MlirContext context) {
|
||||
auto diagnosticHandler = [](MlirDiagnostic diagnostic,
|
||||
void *) -> MlirLogicalResult {
|
||||
std::stringstream ss;
|
||||
auto stringCallback =
|
||||
[](MlirStringRef s, void *stringCallbackUserData) {
|
||||
auto *ssp = static_cast<std::stringstream *>(stringCallbackUserData);
|
||||
ssp->write(s.data, s.length);
|
||||
};
|
||||
mlirDiagnosticPrint(diagnostic, stringCallback, static_cast<void *>(&ss));
|
||||
// Use pybind11's print:
|
||||
// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html
|
||||
using namespace pybind11::literals;
|
||||
py::print(ss.str(), "file"_a = py::module_::import("sys").attr("stderr"));
|
||||
printDiagnostic(diagnostic);
|
||||
for (int i = 0, e = mlirDiagnosticGetNumNotes(diagnostic); i != e; i++) {
|
||||
printDiagnostic(mlirDiagnosticGetNote(diagnostic, i));
|
||||
}
|
||||
return mlirLogicalResultSuccess();
|
||||
};
|
||||
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
|
||||
|
|
Loading…
Reference in New Issue