mirror of https://github.com/llvm/torch-mlir
Add MLIRContext and ModuleOp python bindings with asm parse/print and diagnostics.
parent
67d38db1e2
commit
ec0f6b4b22
|
@ -45,7 +45,7 @@ class EmitterRegistry:
|
|||
# Emit op.
|
||||
mlir_m = pft.mlir_module
|
||||
op_result_types = [mlir_m.make_type("tensor<*x!numpy.any_dtype>")]
|
||||
op_result = edsc.op("numpy.generic_ufunc", op_inputs, op_result_types,
|
||||
op_result = edsc.op("numpy.tmp_generic_ufunc", op_inputs, op_result_types,
|
||||
ufunc_name=mlir_m.stringAttr(function_name))
|
||||
|
||||
# Wrap returns.
|
||||
|
@ -104,8 +104,8 @@ class PyFuncTrace:
|
|||
>>> print(pft.mlir_module.get_ir().strip())
|
||||
module {
|
||||
func @simple_mul(%arg0: tensor<?x4xf32>, %arg1: tensor<1xf32>) -> tensor<?x4xf32> {
|
||||
%0 = "numpy.generic_ufunc"(%arg0, %arg1) {ufunc_name = "numpy.multiply"} : (tensor<?x4xf32>, tensor<1xf32>) -> tensor<*x!numpy.any_dtype>
|
||||
%1 = "numpy.generic_ufunc"(%0, %arg0) {ufunc_name = "numpy.add"} : (tensor<*x!numpy.any_dtype>, tensor<?x4xf32>) -> tensor<*x!numpy.any_dtype>
|
||||
%0 = "numpy.tmp_generic_ufunc"(%arg0, %arg1) {ufunc_name = "numpy.multiply"} : (tensor<?x4xf32>, tensor<1xf32>) -> tensor<*x!numpy.any_dtype>
|
||||
%1 = "numpy.tmp_generic_ufunc"(%0, %arg0) {ufunc_name = "numpy.add"} : (tensor<*x!numpy.any_dtype>, tensor<?x4xf32>) -> tensor<*x!numpy.any_dtype>
|
||||
%2 = "numpy.narrow"(%1) : (tensor<*x!numpy.any_dtype>) -> tensor<?x4xf32>
|
||||
return %2 : tensor<?x4xf32>
|
||||
}
|
||||
|
|
|
@ -6,16 +6,264 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include "pybind_utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Forward declarations
|
||||
//===----------------------------------------------------------------------===//
|
||||
struct PyContext;
|
||||
|
||||
/// Parses an MLIR module from a string.
|
||||
/// For maximum efficiency, the `contents` should be zero terminated.
|
||||
static OwningModuleRef parseMLIRModuleFromString(StringRef contents,
|
||||
MLIRContext *context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Diagnostics
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// RAII class to capture diagnostics for later reporting back to the python
|
||||
/// layer.
|
||||
class DiagnosticCapture {
|
||||
public:
|
||||
DiagnosticCapture(mlir::MLIRContext *mlir_context)
|
||||
: mlir_context(mlir_context) {
|
||||
handler_id = mlir_context->getDiagEngine().registerHandler(
|
||||
[&](Diagnostic &d) -> LogicalResult {
|
||||
diagnostics.push_back(std::move(d));
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
~DiagnosticCapture() {
|
||||
if (mlir_context) {
|
||||
mlir_context->getDiagEngine().eraseHandler(handler_id);
|
||||
}
|
||||
}
|
||||
DiagnosticCapture(DiagnosticCapture &&other) {
|
||||
mlir_context = other.mlir_context;
|
||||
diagnostics.swap(other.diagnostics);
|
||||
handler_id = other.handler_id;
|
||||
other.mlir_context = nullptr;
|
||||
}
|
||||
|
||||
std::vector<mlir::Diagnostic> &getDiagnostics() { return diagnostics; }
|
||||
|
||||
// Consumes/clears diagnostics.
|
||||
std::string consumeDiagnosticsAsString(const char *error_message);
|
||||
void clearDiagnostics() { diagnostics.clear(); }
|
||||
|
||||
private:
|
||||
MLIRContext *mlir_context;
|
||||
std::vector<mlir::Diagnostic> diagnostics;
|
||||
mlir::DiagnosticEngine::HandlerID handler_id;
|
||||
};
|
||||
|
||||
/// Wrapper around Module, capturing a PyContext reference.
|
||||
struct PyModuleOp {
|
||||
static void bind(py::module m) {
|
||||
py::class_<PyModuleOp>(m, "ModuleOp")
|
||||
.def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false,
|
||||
py::arg("pretty") = false, py::arg("large_element_limit") = -1);
|
||||
}
|
||||
|
||||
std::string toAsm(bool enableDebugInfo, bool prettyForm,
|
||||
int64_t largeElementLimit) {
|
||||
// Print to asm.
|
||||
std::string asmOutput;
|
||||
llvm::raw_string_ostream sout(asmOutput);
|
||||
OpPrintingFlags printFlags;
|
||||
if (enableDebugInfo) {
|
||||
printFlags.enableDebugInfo(prettyForm);
|
||||
}
|
||||
if (largeElementLimit >= 0) {
|
||||
printFlags.elideLargeElementsAttrs(largeElementLimit);
|
||||
}
|
||||
module_op.print(sout, printFlags);
|
||||
return sout.str();
|
||||
}
|
||||
|
||||
std::shared_ptr<PyContext> context;
|
||||
ModuleOp module_op;
|
||||
};
|
||||
|
||||
/// Wrapper around MLIRContext.
|
||||
/// Unlike most, this is enforced to be a shared_ptr since arbitrary other
|
||||
/// types can capture it.
|
||||
struct PyContext : std::enable_shared_from_this<PyContext> {
|
||||
static void bind(py::module m) {
|
||||
py::class_<PyContext, std::shared_ptr<PyContext>>(m, "MLIRContext")
|
||||
.def(py::init<>([]() {
|
||||
// Need explicit make_shared to avoid UB with enable_shared_from_this.
|
||||
return std::make_shared<PyContext>();
|
||||
}))
|
||||
.def("new_module",
|
||||
[&](PyContext &context) -> PyModuleOp {
|
||||
return PyModuleOp{context.shared_from_this()};
|
||||
})
|
||||
.def("parse_asm", &PyContext::parseAsm);
|
||||
}
|
||||
|
||||
PyModuleOp parseAsm(const std::string &asm_text) {
|
||||
// Arrange to get a view that includes a terminating null to avoid
|
||||
// additional copy.
|
||||
// TODO: Consider using the buffer protocol to access and avoid more copies.
|
||||
const char *asm_chars = asm_text.c_str();
|
||||
StringRef asm_sr(asm_chars, asm_text.size() + 1);
|
||||
|
||||
// TODO: Output non failure diagnostics (somewhere)
|
||||
DiagnosticCapture diag_capture(&context);
|
||||
auto module_ref = parseMLIRModuleFromString(asm_sr, &context);
|
||||
if (!module_ref) {
|
||||
throw py::raiseValueError(
|
||||
diag_capture.consumeDiagnosticsAsString("Error parsing ASM"));
|
||||
}
|
||||
return PyModuleOp{shared_from_this(), module_ref.release()};
|
||||
}
|
||||
|
||||
MLIRContext context;
|
||||
};
|
||||
|
||||
void defineMlirIrModule(py::module m) {
|
||||
m.doc() = "Python bindings for constructs in the mlir/IR library";
|
||||
|
||||
PyContext::bind(m);
|
||||
PyModuleOp::bind(m);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Detail definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static OwningModuleRef parseMLIRModuleFromString(StringRef contents,
|
||||
MLIRContext *context) {
|
||||
std::unique_ptr<llvm::MemoryBuffer> contents_buffer;
|
||||
if (contents.back() == 0) {
|
||||
// If it has a nul terminator, just use as-is.
|
||||
contents_buffer = llvm::MemoryBuffer::getMemBuffer(contents.drop_back());
|
||||
} else {
|
||||
// Otherwise, make a copy.
|
||||
contents_buffer = llvm::MemoryBuffer::getMemBufferCopy(contents, "EMBED");
|
||||
}
|
||||
|
||||
llvm::SourceMgr source_mgr;
|
||||
source_mgr.AddNewSourceBuffer(std::move(contents_buffer), llvm::SMLoc());
|
||||
OwningModuleRef mlir_module = parseSourceFile(source_mgr, context);
|
||||
return mlir_module;
|
||||
}
|
||||
|
||||
// Custom location printer that prints prettier, multi-line file output
|
||||
// suitable for human readable error messages. The standard printer just prints
|
||||
// a long nested expression not particularly human friendly). Note that there
|
||||
// is a location pretty printer in the MLIR AsmPrinter. It is private and
|
||||
// doesn't do any path shortening, which seems to make long Python stack traces
|
||||
// a bit easier to scan.
|
||||
// TODO: Upstream this.
|
||||
void printLocation(Location loc, raw_ostream &out) {
|
||||
switch (loc->getKind()) {
|
||||
case StandardAttributes::OpaqueLocation:
|
||||
printLocation(loc.cast<OpaqueLoc>().getFallbackLocation(), out);
|
||||
break;
|
||||
case StandardAttributes::UnknownLocation:
|
||||
out << " [unknown location]\n";
|
||||
break;
|
||||
case StandardAttributes::FileLineColLocation: {
|
||||
auto line_col_loc = loc.cast<FileLineColLoc>();
|
||||
StringRef this_filename = line_col_loc.getFilename();
|
||||
auto slash_pos = this_filename.find_last_of("/\\");
|
||||
// We print both the basename and extended names with a structure like
|
||||
// `foo.py:35:4`. Even though technically the line/col
|
||||
// information is redundant to include in both names, having it on both
|
||||
// makes it easier to paste the paths into an editor and jump to the exact
|
||||
// location.
|
||||
std::string line_col_suffix = ":" + std::to_string(line_col_loc.getLine()) +
|
||||
":" +
|
||||
std::to_string(line_col_loc.getColumn());
|
||||
bool has_basename = false;
|
||||
StringRef basename = this_filename;
|
||||
if (slash_pos != StringRef::npos) {
|
||||
has_basename = true;
|
||||
basename = this_filename.substr(slash_pos + 1);
|
||||
}
|
||||
out << " at: " << basename << line_col_suffix;
|
||||
if (has_basename) {
|
||||
StringRef extended_name = this_filename;
|
||||
// Print out two tabs, as basenames usually vary in length by more than
|
||||
// one tab width.
|
||||
out << "\t\t( " << extended_name << line_col_suffix << " )";
|
||||
}
|
||||
out << "\n";
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::NameLocation: {
|
||||
auto nameLoc = loc.cast<NameLoc>();
|
||||
out << " @'" << nameLoc.getName() << "':\n";
|
||||
auto childLoc = nameLoc.getChildLoc();
|
||||
if (!childLoc.isa<UnknownLoc>()) {
|
||||
out << "(...\n";
|
||||
printLocation(childLoc, out);
|
||||
out << ")\n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::CallSiteLocation: {
|
||||
auto call_site = loc.cast<CallSiteLoc>();
|
||||
printLocation(call_site.getCaller(), out);
|
||||
printLocation(call_site.getCallee(), out);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
DiagnosticCapture::consumeDiagnosticsAsString(const char *error_message) {
|
||||
std::string s;
|
||||
llvm::raw_string_ostream sout(s);
|
||||
bool first = true;
|
||||
if (error_message) {
|
||||
sout << error_message;
|
||||
first = false;
|
||||
}
|
||||
for (auto &d : diagnostics) {
|
||||
if (!first) {
|
||||
sout << "\n\n";
|
||||
} else {
|
||||
first = false;
|
||||
}
|
||||
|
||||
switch (d.getSeverity()) {
|
||||
case DiagnosticSeverity::Note:
|
||||
sout << "[NOTE]";
|
||||
break;
|
||||
case DiagnosticSeverity::Warning:
|
||||
sout << "[WARNING]";
|
||||
break;
|
||||
case DiagnosticSeverity::Error:
|
||||
sout << "[ERROR]";
|
||||
break;
|
||||
case DiagnosticSeverity::Remark:
|
||||
sout << "[REMARK]";
|
||||
break;
|
||||
default:
|
||||
sout << "[UNKNOWN]";
|
||||
}
|
||||
// Message.
|
||||
sout << ": " << d << "\n";
|
||||
printLocation(d.getLocation(), sout);
|
||||
}
|
||||
|
||||
diagnostics.clear();
|
||||
return sout.str();
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
"""Test for the MLIR IR Python bindings"""
|
||||
|
||||
from npcomp.native.mlir import ir
|
||||
from npcomp.utils import test_utils
|
||||
|
||||
test_utils.start_filecheck_test()
|
||||
c = ir.MLIRContext()
|
||||
|
||||
# CHECK-LABEL: module @parseSuccess
|
||||
m = c.parse_asm(r"""
|
||||
module @parseSuccess {
|
||||
func @f() {
|
||||
return
|
||||
}
|
||||
}
|
||||
""")
|
||||
print(m.to_asm())
|
||||
|
||||
# CHECK-LABEL: PARSE_FAILURE
|
||||
print("PARSE_FAILURE")
|
||||
try:
|
||||
m = c.parse_asm("{{ILLEGAL SYNTAX}}")
|
||||
except ValueError as e:
|
||||
# CHECK: [ERROR]: expected operation name in quotes
|
||||
print(e)
|
||||
|
||||
|
||||
test_utils.end_filecheck_test(__file__)
|
|
@ -34,7 +34,7 @@ void defineLLVMModule(pybind11::module m) {
|
|||
if (found_it == options_map.end()) {
|
||||
std::string message = "Unknown LLVM option: ";
|
||||
message.append(name);
|
||||
throw raiseValueError(message.c_str());
|
||||
throw py::raiseValueError(message.c_str());
|
||||
}
|
||||
|
||||
std::string value_sr = value ? *value : "";
|
||||
|
@ -48,7 +48,7 @@ void defineLLVMModule(pybind11::module m) {
|
|||
if (found_it == options_map.end()) {
|
||||
std::string message = "Unknown LLVM option: ";
|
||||
message.append(name);
|
||||
throw raiseValueError(message.c_str());
|
||||
throw py::raiseValueError(message.c_str());
|
||||
}
|
||||
found_it->getValue()->setDefault();
|
||||
},
|
||||
|
|
|
@ -8,9 +8,7 @@
|
|||
|
||||
#include "pybind_utils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace npcomp {
|
||||
namespace python {
|
||||
namespace pybind11 {
|
||||
|
||||
pybind11::error_already_set raisePyError(PyObject *exc_class,
|
||||
const char *message) {
|
||||
|
@ -18,6 +16,4 @@ pybind11::error_already_set raisePyError(PyObject *exc_class,
|
|||
return pybind11::error_already_set();
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
} // namespace npcomp
|
||||
} // namespace mlir
|
||||
} // namespace pybind11
|
||||
|
|
|
@ -25,9 +25,7 @@ struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
|
|||
} // namespace detail
|
||||
} // namespace pybind11
|
||||
|
||||
namespace mlir {
|
||||
namespace npcomp {
|
||||
namespace python {
|
||||
namespace pybind11 {
|
||||
|
||||
/// Raises a python exception with the given message.
|
||||
/// Correct usage:
|
||||
|
@ -49,6 +47,4 @@ inline pybind11::error_already_set raiseValueError(const std::string &message) {
|
|||
return raisePyError(PyExc_ValueError, message.c_str());
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
} // namespace npcomp
|
||||
} // namespace mlir
|
||||
} // namespace pybind11
|
||||
|
|
|
@ -8,6 +8,46 @@ import os
|
|||
import subprocess
|
||||
import sys
|
||||
|
||||
_disable_var = "NPCOMP_DISABLE_FILECHECK"
|
||||
_filecheck_binary_var = "FILECHECK_BINARY"
|
||||
|
||||
def is_filecheck_disabled():
|
||||
return _disable_var in os.environ
|
||||
|
||||
|
||||
def start_filecheck_test():
|
||||
if is_filecheck_disabled():
|
||||
print("WARNING:FileCheck disabled due to", _disable_var,
|
||||
"in the environment", file=sys.stderr)
|
||||
return
|
||||
global _redirect_io
|
||||
global _redirect_context
|
||||
_redirect_io = io.StringIO()
|
||||
_redirect_context = contextlib.redirect_stdout(_redirect_io)
|
||||
_redirect_context.__enter__()
|
||||
|
||||
|
||||
def end_filecheck_test(main_file):
|
||||
if is_filecheck_disabled(): return
|
||||
global _redirect_io
|
||||
global _redirect_context
|
||||
_redirect_context.__exit__(None, None, None)
|
||||
_redirect_context = None
|
||||
_redirect_io.flush()
|
||||
filecheck_input = _redirect_io.getvalue()
|
||||
_redirect_io = None
|
||||
filecheck_binary = "FileCheck"
|
||||
if _filecheck_binary_var in os.environ:
|
||||
filecheck_binary = os.environ[_filecheck_binary_var]
|
||||
print("Using FileCheck binary", filecheck_binary,
|
||||
"(customize by setting", _filecheck_binary_var, ")", file=sys.stderr)
|
||||
filecheck_args = [filecheck_binary, main_file, "--dump-input=fail"]
|
||||
print("LAUNCHING FILECHECK:", filecheck_args, file=sys.stderr)
|
||||
p = subprocess.Popen(filecheck_args, stdin=subprocess.PIPE)
|
||||
p.communicate(filecheck_input.encode("UTF-8"))
|
||||
sys.exit(p.returncode)
|
||||
|
||||
|
||||
def run_under_filecheck(main_file, callback, disable_filecheck=False):
|
||||
"""Runs a callback under a FileCheck sub-process.
|
||||
|
||||
|
@ -20,29 +60,14 @@ def run_under_filecheck(main_file, callback, disable_filecheck=False):
|
|||
callback: The no-argument callback to invoke.
|
||||
disable_filecheck: Whether to disable filecheck.
|
||||
"""
|
||||
disable_var = "NPCOMP_DISABLE_FILECHECK"
|
||||
filecheck_binary_var = "FILECHECK_BINARY"
|
||||
if "NPCOMP_DISABLE_FILECHECK" in os.environ:
|
||||
print("WARNING:FileCheck disabled due to", disable_var,
|
||||
if disable_filecheck or is_filecheck_disabled():
|
||||
print("WARNING:FileCheck disabled due to", _disable_var,
|
||||
"in the environment", file=sys.stderr)
|
||||
disable_filecheck = True
|
||||
if disable_filecheck:
|
||||
callback()
|
||||
sys.exit(0)
|
||||
|
||||
# Redirect through FileCheck
|
||||
filecheck_capture_io = io.StringIO()
|
||||
with contextlib.redirect_stdout(filecheck_capture_io):
|
||||
try:
|
||||
start_filecheck_test()
|
||||
callback()
|
||||
filecheck_capture_io.flush()
|
||||
filecheck_input = filecheck_capture_io.getvalue()
|
||||
filecheck_binary = "FileCheck"
|
||||
if filecheck_binary_var in os.environ:
|
||||
filecheck_binary = os.environ[filecheck_binary_var]
|
||||
print("Using FileCheck binary", filecheck_binary,
|
||||
"(customize by setting", filecheck_binary_var, ")", file=sys.stderr)
|
||||
filecheck_args = [filecheck_binary, main_file, "--dump-input=fail"]
|
||||
print("LAUNCHING FILECHECK:", filecheck_args, file=sys.stderr)
|
||||
p = subprocess.Popen(filecheck_args, stdin=subprocess.PIPE)
|
||||
p.communicate(filecheck_input.encode("UTF-8"))
|
||||
sys.exit(p.returncode)
|
||||
finally:
|
||||
end_filecheck_test(main_file)
|
||||
|
|
|
@ -6,6 +6,7 @@ import sys
|
|||
|
||||
|
||||
TEST_MODULES = (
|
||||
"npcomp.mlir_ir_test",
|
||||
"npcomp.edsc_test",
|
||||
"npcomp.tracing.context",
|
||||
"npcomp.tracing.mlir_trace",
|
||||
|
@ -15,7 +16,7 @@ TEST_MODULES = (
|
|||
)
|
||||
|
||||
# Compute PYTHONPATH for sub processes.
|
||||
DIRSEP = ":" if os.path.pathsep == "/" else ";"
|
||||
DIRSEP = ":" if os.path.sep == "/" else ";"
|
||||
PYTHONPATH = os.path.abspath(os.path.dirname(__file__))
|
||||
if "PYTHONPATH" in os.environ:
|
||||
PYTHONPATH = PYTHONPATH + DIRSEP + os.environ["PYTHONPATH"]
|
||||
|
|
Loading…
Reference in New Issue