From ec0f6b4b22cc0e98a7576a0a1504220be667ea6e Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 30 Apr 2020 17:14:03 -0700 Subject: [PATCH] Add MLIRContext and ModuleOp python bindings with asm parse/print and diagnostics. --- python/npcomp/exp/extractor.py | 6 +- python/npcomp/mlir_ir.cpp | 256 +++++++++++++++++++++++++++++- python/npcomp/mlir_ir_test.py | 32 ++++ python/npcomp/native.cpp | 4 +- python/npcomp/pybind_utils.cpp | 8 +- python/npcomp/pybind_utils.h | 8 +- python/npcomp/utils/test_utils.py | 67 +++++--- python/run_tests.py | 3 +- 8 files changed, 341 insertions(+), 43 deletions(-) create mode 100644 python/npcomp/mlir_ir_test.py diff --git a/python/npcomp/exp/extractor.py b/python/npcomp/exp/extractor.py index 245607a30..b3d7a4dff 100644 --- a/python/npcomp/exp/extractor.py +++ b/python/npcomp/exp/extractor.py @@ -45,7 +45,7 @@ class EmitterRegistry: # Emit op. mlir_m = pft.mlir_module op_result_types = [mlir_m.make_type("tensor<*x!numpy.any_dtype>")] - op_result = edsc.op("numpy.generic_ufunc", op_inputs, op_result_types, + op_result = edsc.op("numpy.tmp_generic_ufunc", op_inputs, op_result_types, ufunc_name=mlir_m.stringAttr(function_name)) # Wrap returns. @@ -104,8 +104,8 @@ class PyFuncTrace: >>> print(pft.mlir_module.get_ir().strip()) module { func @simple_mul(%arg0: tensor, %arg1: tensor<1xf32>) -> tensor { - %0 = "numpy.generic_ufunc"(%arg0, %arg1) {ufunc_name = "numpy.multiply"} : (tensor, tensor<1xf32>) -> tensor<*x!numpy.any_dtype> - %1 = "numpy.generic_ufunc"(%0, %arg0) {ufunc_name = "numpy.add"} : (tensor<*x!numpy.any_dtype>, tensor) -> tensor<*x!numpy.any_dtype> + %0 = "numpy.tmp_generic_ufunc"(%arg0, %arg1) {ufunc_name = "numpy.multiply"} : (tensor, tensor<1xf32>) -> tensor<*x!numpy.any_dtype> + %1 = "numpy.tmp_generic_ufunc"(%0, %arg0) {ufunc_name = "numpy.add"} : (tensor<*x!numpy.any_dtype>, tensor) -> tensor<*x!numpy.any_dtype> %2 = "numpy.narrow"(%1) : (tensor<*x!numpy.any_dtype>) -> tensor return %2 : tensor } diff --git a/python/npcomp/mlir_ir.cpp b/python/npcomp/mlir_ir.cpp index eb3c4ff53..4826bf8b4 100644 --- a/python/npcomp/mlir_ir.cpp +++ b/python/npcomp/mlir_ir.cpp @@ -6,16 +6,264 @@ // //===----------------------------------------------------------------------===// -#include -#include -#include +#include "pybind_utils.h" -namespace py = pybind11; +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" namespace mlir { +//===----------------------------------------------------------------------===// +// Forward declarations +//===----------------------------------------------------------------------===// +struct PyContext; + +/// Parses an MLIR module from a string. +/// For maximum efficiency, the `contents` should be zero terminated. +static OwningModuleRef parseMLIRModuleFromString(StringRef contents, + MLIRContext *context); + +//===----------------------------------------------------------------------===// +// Diagnostics +//===----------------------------------------------------------------------===// + +/// RAII class to capture diagnostics for later reporting back to the python +/// layer. +class DiagnosticCapture { +public: + DiagnosticCapture(mlir::MLIRContext *mlir_context) + : mlir_context(mlir_context) { + handler_id = mlir_context->getDiagEngine().registerHandler( + [&](Diagnostic &d) -> LogicalResult { + diagnostics.push_back(std::move(d)); + return success(); + }); + } + + ~DiagnosticCapture() { + if (mlir_context) { + mlir_context->getDiagEngine().eraseHandler(handler_id); + } + } + DiagnosticCapture(DiagnosticCapture &&other) { + mlir_context = other.mlir_context; + diagnostics.swap(other.diagnostics); + handler_id = other.handler_id; + other.mlir_context = nullptr; + } + + std::vector &getDiagnostics() { return diagnostics; } + + // Consumes/clears diagnostics. + std::string consumeDiagnosticsAsString(const char *error_message); + void clearDiagnostics() { diagnostics.clear(); } + +private: + MLIRContext *mlir_context; + std::vector diagnostics; + mlir::DiagnosticEngine::HandlerID handler_id; +}; + +/// Wrapper around Module, capturing a PyContext reference. +struct PyModuleOp { + static void bind(py::module m) { + py::class_(m, "ModuleOp") + .def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false, + py::arg("pretty") = false, py::arg("large_element_limit") = -1); + } + + std::string toAsm(bool enableDebugInfo, bool prettyForm, + int64_t largeElementLimit) { + // Print to asm. + std::string asmOutput; + llvm::raw_string_ostream sout(asmOutput); + OpPrintingFlags printFlags; + if (enableDebugInfo) { + printFlags.enableDebugInfo(prettyForm); + } + if (largeElementLimit >= 0) { + printFlags.elideLargeElementsAttrs(largeElementLimit); + } + module_op.print(sout, printFlags); + return sout.str(); + } + + std::shared_ptr context; + ModuleOp module_op; +}; + +/// Wrapper around MLIRContext. +/// Unlike most, this is enforced to be a shared_ptr since arbitrary other +/// types can capture it. +struct PyContext : std::enable_shared_from_this { + static void bind(py::module m) { + py::class_>(m, "MLIRContext") + .def(py::init<>([]() { + // Need explicit make_shared to avoid UB with enable_shared_from_this. + return std::make_shared(); + })) + .def("new_module", + [&](PyContext &context) -> PyModuleOp { + return PyModuleOp{context.shared_from_this()}; + }) + .def("parse_asm", &PyContext::parseAsm); + } + + PyModuleOp parseAsm(const std::string &asm_text) { + // Arrange to get a view that includes a terminating null to avoid + // additional copy. + // TODO: Consider using the buffer protocol to access and avoid more copies. + const char *asm_chars = asm_text.c_str(); + StringRef asm_sr(asm_chars, asm_text.size() + 1); + + // TODO: Output non failure diagnostics (somewhere) + DiagnosticCapture diag_capture(&context); + auto module_ref = parseMLIRModuleFromString(asm_sr, &context); + if (!module_ref) { + throw py::raiseValueError( + diag_capture.consumeDiagnosticsAsString("Error parsing ASM")); + } + return PyModuleOp{shared_from_this(), module_ref.release()}; + } + + MLIRContext context; +}; + void defineMlirIrModule(py::module m) { m.doc() = "Python bindings for constructs in the mlir/IR library"; + + PyContext::bind(m); + PyModuleOp::bind(m); +} + +//===----------------------------------------------------------------------===// +// Detail definitions +//===----------------------------------------------------------------------===// + +static OwningModuleRef parseMLIRModuleFromString(StringRef contents, + MLIRContext *context) { + std::unique_ptr contents_buffer; + if (contents.back() == 0) { + // If it has a nul terminator, just use as-is. + contents_buffer = llvm::MemoryBuffer::getMemBuffer(contents.drop_back()); + } else { + // Otherwise, make a copy. + contents_buffer = llvm::MemoryBuffer::getMemBufferCopy(contents, "EMBED"); + } + + llvm::SourceMgr source_mgr; + source_mgr.AddNewSourceBuffer(std::move(contents_buffer), llvm::SMLoc()); + OwningModuleRef mlir_module = parseSourceFile(source_mgr, context); + return mlir_module; +} + +// Custom location printer that prints prettier, multi-line file output +// suitable for human readable error messages. The standard printer just prints +// a long nested expression not particularly human friendly). Note that there +// is a location pretty printer in the MLIR AsmPrinter. It is private and +// doesn't do any path shortening, which seems to make long Python stack traces +// a bit easier to scan. +// TODO: Upstream this. +void printLocation(Location loc, raw_ostream &out) { + switch (loc->getKind()) { + case StandardAttributes::OpaqueLocation: + printLocation(loc.cast().getFallbackLocation(), out); + break; + case StandardAttributes::UnknownLocation: + out << " [unknown location]\n"; + break; + case StandardAttributes::FileLineColLocation: { + auto line_col_loc = loc.cast(); + StringRef this_filename = line_col_loc.getFilename(); + auto slash_pos = this_filename.find_last_of("/\\"); + // We print both the basename and extended names with a structure like + // `foo.py:35:4`. Even though technically the line/col + // information is redundant to include in both names, having it on both + // makes it easier to paste the paths into an editor and jump to the exact + // location. + std::string line_col_suffix = ":" + std::to_string(line_col_loc.getLine()) + + ":" + + std::to_string(line_col_loc.getColumn()); + bool has_basename = false; + StringRef basename = this_filename; + if (slash_pos != StringRef::npos) { + has_basename = true; + basename = this_filename.substr(slash_pos + 1); + } + out << " at: " << basename << line_col_suffix; + if (has_basename) { + StringRef extended_name = this_filename; + // Print out two tabs, as basenames usually vary in length by more than + // one tab width. + out << "\t\t( " << extended_name << line_col_suffix << " )"; + } + out << "\n"; + break; + } + case StandardAttributes::NameLocation: { + auto nameLoc = loc.cast(); + out << " @'" << nameLoc.getName() << "':\n"; + auto childLoc = nameLoc.getChildLoc(); + if (!childLoc.isa()) { + out << "(...\n"; + printLocation(childLoc, out); + out << ")\n"; + } + break; + } + case StandardAttributes::CallSiteLocation: { + auto call_site = loc.cast(); + printLocation(call_site.getCaller(), out); + printLocation(call_site.getCallee(), out); + break; + } + } +} + +std::string +DiagnosticCapture::consumeDiagnosticsAsString(const char *error_message) { + std::string s; + llvm::raw_string_ostream sout(s); + bool first = true; + if (error_message) { + sout << error_message; + first = false; + } + for (auto &d : diagnostics) { + if (!first) { + sout << "\n\n"; + } else { + first = false; + } + + switch (d.getSeverity()) { + case DiagnosticSeverity::Note: + sout << "[NOTE]"; + break; + case DiagnosticSeverity::Warning: + sout << "[WARNING]"; + break; + case DiagnosticSeverity::Error: + sout << "[ERROR]"; + break; + case DiagnosticSeverity::Remark: + sout << "[REMARK]"; + break; + default: + sout << "[UNKNOWN]"; + } + // Message. + sout << ": " << d << "\n"; + printLocation(d.getLocation(), sout); + } + + diagnostics.clear(); + return sout.str(); } } // namespace mlir diff --git a/python/npcomp/mlir_ir_test.py b/python/npcomp/mlir_ir_test.py new file mode 100644 index 000000000..0777c33c4 --- /dev/null +++ b/python/npcomp/mlir_ir_test.py @@ -0,0 +1,32 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Test for the MLIR IR Python bindings""" + +from npcomp.native.mlir import ir +from npcomp.utils import test_utils + +test_utils.start_filecheck_test() +c = ir.MLIRContext() + +# CHECK-LABEL: module @parseSuccess +m = c.parse_asm(r""" +module @parseSuccess { + func @f() { + return + } +} +""") +print(m.to_asm()) + +# CHECK-LABEL: PARSE_FAILURE +print("PARSE_FAILURE") +try: + m = c.parse_asm("{{ILLEGAL SYNTAX}}") +except ValueError as e: + # CHECK: [ERROR]: expected operation name in quotes + print(e) + + +test_utils.end_filecheck_test(__file__) diff --git a/python/npcomp/native.cpp b/python/npcomp/native.cpp index 191cb2f62..5372b33ae 100644 --- a/python/npcomp/native.cpp +++ b/python/npcomp/native.cpp @@ -34,7 +34,7 @@ void defineLLVMModule(pybind11::module m) { if (found_it == options_map.end()) { std::string message = "Unknown LLVM option: "; message.append(name); - throw raiseValueError(message.c_str()); + throw py::raiseValueError(message.c_str()); } std::string value_sr = value ? *value : ""; @@ -48,7 +48,7 @@ void defineLLVMModule(pybind11::module m) { if (found_it == options_map.end()) { std::string message = "Unknown LLVM option: "; message.append(name); - throw raiseValueError(message.c_str()); + throw py::raiseValueError(message.c_str()); } found_it->getValue()->setDefault(); }, diff --git a/python/npcomp/pybind_utils.cpp b/python/npcomp/pybind_utils.cpp index 523745da5..afcbfc281 100644 --- a/python/npcomp/pybind_utils.cpp +++ b/python/npcomp/pybind_utils.cpp @@ -8,9 +8,7 @@ #include "pybind_utils.h" -namespace mlir { -namespace npcomp { -namespace python { +namespace pybind11 { pybind11::error_already_set raisePyError(PyObject *exc_class, const char *message) { @@ -18,6 +16,4 @@ pybind11::error_already_set raisePyError(PyObject *exc_class, return pybind11::error_already_set(); } -} // namespace python -} // namespace npcomp -} // namespace mlir +} // namespace pybind11 diff --git a/python/npcomp/pybind_utils.h b/python/npcomp/pybind_utils.h index 94ff4d9c0..4b05bec05 100644 --- a/python/npcomp/pybind_utils.h +++ b/python/npcomp/pybind_utils.h @@ -25,9 +25,7 @@ struct type_caster> : optional_caster> {}; } // namespace detail } // namespace pybind11 -namespace mlir { -namespace npcomp { -namespace python { +namespace pybind11 { /// Raises a python exception with the given message. /// Correct usage: @@ -49,6 +47,4 @@ inline pybind11::error_already_set raiseValueError(const std::string &message) { return raisePyError(PyExc_ValueError, message.c_str()); } -} // namespace python -} // namespace npcomp -} // namespace mlir +} // namespace pybind11 diff --git a/python/npcomp/utils/test_utils.py b/python/npcomp/utils/test_utils.py index 8efdda8f6..7427c3585 100644 --- a/python/npcomp/utils/test_utils.py +++ b/python/npcomp/utils/test_utils.py @@ -8,6 +8,46 @@ import os import subprocess import sys +_disable_var = "NPCOMP_DISABLE_FILECHECK" +_filecheck_binary_var = "FILECHECK_BINARY" + +def is_filecheck_disabled(): + return _disable_var in os.environ + + +def start_filecheck_test(): + if is_filecheck_disabled(): + print("WARNING:FileCheck disabled due to", _disable_var, + "in the environment", file=sys.stderr) + return + global _redirect_io + global _redirect_context + _redirect_io = io.StringIO() + _redirect_context = contextlib.redirect_stdout(_redirect_io) + _redirect_context.__enter__() + + +def end_filecheck_test(main_file): + if is_filecheck_disabled(): return + global _redirect_io + global _redirect_context + _redirect_context.__exit__(None, None, None) + _redirect_context = None + _redirect_io.flush() + filecheck_input = _redirect_io.getvalue() + _redirect_io = None + filecheck_binary = "FileCheck" + if _filecheck_binary_var in os.environ: + filecheck_binary = os.environ[_filecheck_binary_var] + print("Using FileCheck binary", filecheck_binary, + "(customize by setting", _filecheck_binary_var, ")", file=sys.stderr) + filecheck_args = [filecheck_binary, main_file, "--dump-input=fail"] + print("LAUNCHING FILECHECK:", filecheck_args, file=sys.stderr) + p = subprocess.Popen(filecheck_args, stdin=subprocess.PIPE) + p.communicate(filecheck_input.encode("UTF-8")) + sys.exit(p.returncode) + + def run_under_filecheck(main_file, callback, disable_filecheck=False): """Runs a callback under a FileCheck sub-process. @@ -20,29 +60,14 @@ def run_under_filecheck(main_file, callback, disable_filecheck=False): callback: The no-argument callback to invoke. disable_filecheck: Whether to disable filecheck. """ - disable_var = "NPCOMP_DISABLE_FILECHECK" - filecheck_binary_var = "FILECHECK_BINARY" - if "NPCOMP_DISABLE_FILECHECK" in os.environ: - print("WARNING:FileCheck disabled due to", disable_var, + if disable_filecheck or is_filecheck_disabled(): + print("WARNING:FileCheck disabled due to", _disable_var, "in the environment", file=sys.stderr) - disable_filecheck = True - if disable_filecheck: callback() sys.exit(0) - # Redirect through FileCheck - filecheck_capture_io = io.StringIO() - with contextlib.redirect_stdout(filecheck_capture_io): + try: + start_filecheck_test() callback() - filecheck_capture_io.flush() - filecheck_input = filecheck_capture_io.getvalue() - filecheck_binary = "FileCheck" - if filecheck_binary_var in os.environ: - filecheck_binary = os.environ[filecheck_binary_var] - print("Using FileCheck binary", filecheck_binary, - "(customize by setting", filecheck_binary_var, ")", file=sys.stderr) - filecheck_args = [filecheck_binary, main_file, "--dump-input=fail"] - print("LAUNCHING FILECHECK:", filecheck_args, file=sys.stderr) - p = subprocess.Popen(filecheck_args, stdin=subprocess.PIPE) - p.communicate(filecheck_input.encode("UTF-8")) - sys.exit(p.returncode) + finally: + end_filecheck_test(main_file) diff --git a/python/run_tests.py b/python/run_tests.py index 8cae87083..dbebd71ad 100755 --- a/python/run_tests.py +++ b/python/run_tests.py @@ -6,6 +6,7 @@ import sys TEST_MODULES = ( + "npcomp.mlir_ir_test", "npcomp.edsc_test", "npcomp.tracing.context", "npcomp.tracing.mlir_trace", @@ -15,7 +16,7 @@ TEST_MODULES = ( ) # Compute PYTHONPATH for sub processes. -DIRSEP = ":" if os.path.pathsep == "/" else ";" +DIRSEP = ":" if os.path.sep == "/" else ";" PYTHONPATH = os.path.abspath(os.path.dirname(__file__)) if "PYTHONPATH" in os.environ: PYTHONPATH = PYTHONPATH + DIRSEP + os.environ["PYTHONPATH"]