mirror of https://github.com/llvm/torch-mlir
Start splitting Py* types into a header so that further C++ interop can be built.
parent
ec0f6b4b22
commit
c8740fd866
|
@ -1,4 +1,4 @@
|
||||||
//===- mlir_if.cpp - MLIR IR Bindings -------------------------------------===//
|
//===- mlir_ir.cpp - MLIR IR Bindings -------------------------------------===//
|
||||||
//
|
//
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
@ -6,12 +6,10 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "pybind_utils.h"
|
#include "mlir_ir.h"
|
||||||
|
|
||||||
#include "mlir/IR/Diagnostics.h"
|
#include "mlir/IR/Diagnostics.h"
|
||||||
#include "mlir/IR/Location.h"
|
#include "mlir/IR/Location.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
|
||||||
#include "mlir/IR/Module.h"
|
|
||||||
#include "mlir/Parser.h"
|
#include "mlir/Parser.h"
|
||||||
#include "llvm/Support/SourceMgr.h"
|
#include "llvm/Support/SourceMgr.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
@ -69,82 +67,110 @@ private:
|
||||||
mlir::DiagnosticEngine::HandlerID handler_id;
|
mlir::DiagnosticEngine::HandlerID handler_id;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Wrapper around Module, capturing a PyContext reference.
|
|
||||||
struct PyModuleOp {
|
|
||||||
static void bind(py::module m) {
|
|
||||||
py::class_<PyModuleOp>(m, "ModuleOp")
|
|
||||||
.def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false,
|
|
||||||
py::arg("pretty") = false, py::arg("large_element_limit") = -1);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string toAsm(bool enableDebugInfo, bool prettyForm,
|
|
||||||
int64_t largeElementLimit) {
|
|
||||||
// Print to asm.
|
|
||||||
std::string asmOutput;
|
|
||||||
llvm::raw_string_ostream sout(asmOutput);
|
|
||||||
OpPrintingFlags printFlags;
|
|
||||||
if (enableDebugInfo) {
|
|
||||||
printFlags.enableDebugInfo(prettyForm);
|
|
||||||
}
|
|
||||||
if (largeElementLimit >= 0) {
|
|
||||||
printFlags.elideLargeElementsAttrs(largeElementLimit);
|
|
||||||
}
|
|
||||||
module_op.print(sout, printFlags);
|
|
||||||
return sout.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<PyContext> context;
|
|
||||||
ModuleOp module_op;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Wrapper around MLIRContext.
|
|
||||||
/// Unlike most, this is enforced to be a shared_ptr since arbitrary other
|
|
||||||
/// types can capture it.
|
|
||||||
struct PyContext : std::enable_shared_from_this<PyContext> {
|
|
||||||
static void bind(py::module m) {
|
|
||||||
py::class_<PyContext, std::shared_ptr<PyContext>>(m, "MLIRContext")
|
|
||||||
.def(py::init<>([]() {
|
|
||||||
// Need explicit make_shared to avoid UB with enable_shared_from_this.
|
|
||||||
return std::make_shared<PyContext>();
|
|
||||||
}))
|
|
||||||
.def("new_module",
|
|
||||||
[&](PyContext &context) -> PyModuleOp {
|
|
||||||
return PyModuleOp{context.shared_from_this()};
|
|
||||||
})
|
|
||||||
.def("parse_asm", &PyContext::parseAsm);
|
|
||||||
}
|
|
||||||
|
|
||||||
PyModuleOp parseAsm(const std::string &asm_text) {
|
|
||||||
// Arrange to get a view that includes a terminating null to avoid
|
|
||||||
// additional copy.
|
|
||||||
// TODO: Consider using the buffer protocol to access and avoid more copies.
|
|
||||||
const char *asm_chars = asm_text.c_str();
|
|
||||||
StringRef asm_sr(asm_chars, asm_text.size() + 1);
|
|
||||||
|
|
||||||
// TODO: Output non failure diagnostics (somewhere)
|
|
||||||
DiagnosticCapture diag_capture(&context);
|
|
||||||
auto module_ref = parseMLIRModuleFromString(asm_sr, &context);
|
|
||||||
if (!module_ref) {
|
|
||||||
throw py::raiseValueError(
|
|
||||||
diag_capture.consumeDiagnosticsAsString("Error parsing ASM"));
|
|
||||||
}
|
|
||||||
return PyModuleOp{shared_from_this(), module_ref.release()};
|
|
||||||
}
|
|
||||||
|
|
||||||
MLIRContext context;
|
|
||||||
};
|
|
||||||
|
|
||||||
void defineMlirIrModule(py::module m) {
|
void defineMlirIrModule(py::module m) {
|
||||||
m.doc() = "Python bindings for constructs in the mlir/IR library";
|
m.doc() = "Python bindings for constructs in the mlir/IR library";
|
||||||
|
|
||||||
PyContext::bind(m);
|
PyContext::bind(m);
|
||||||
|
PyBaseOperation::bind(m);
|
||||||
PyModuleOp::bind(m);
|
PyModuleOp::bind(m);
|
||||||
|
PyRegionRef::bind(m);
|
||||||
|
PyBaseOpBuilder::bind(m);
|
||||||
|
PyOpBuilder::bind(m);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Detail definitions
|
// PyContext
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void PyContext::bind(py::module m) {
|
||||||
|
py::class_<PyContext, std::shared_ptr<PyContext>>(m, "MLIRContext")
|
||||||
|
.def(py::init<>([]() {
|
||||||
|
// Need explicit make_shared to avoid UB with enable_shared_from_this.
|
||||||
|
return std::make_shared<PyContext>();
|
||||||
|
}))
|
||||||
|
.def("new_module",
|
||||||
|
[&](PyContext &context) -> PyModuleOp {
|
||||||
|
return PyModuleOp(context.shared_from_this(), {});
|
||||||
|
})
|
||||||
|
.def("parse_asm", &PyContext::parseAsm);
|
||||||
|
}
|
||||||
|
|
||||||
|
PyModuleOp PyContext::parseAsm(const std::string &asm_text) {
|
||||||
|
// Arrange to get a view that includes a terminating null to avoid
|
||||||
|
// additional copy.
|
||||||
|
// TODO: Consider using the buffer protocol to access and avoid more copies.
|
||||||
|
const char *asm_chars = asm_text.c_str();
|
||||||
|
StringRef asm_sr(asm_chars, asm_text.size() + 1);
|
||||||
|
|
||||||
|
// TODO: Output non failure diagnostics (somewhere)
|
||||||
|
DiagnosticCapture diag_capture(&context);
|
||||||
|
auto module_ref = parseMLIRModuleFromString(asm_sr, &context);
|
||||||
|
if (!module_ref) {
|
||||||
|
throw py::raiseValueError(
|
||||||
|
diag_capture.consumeDiagnosticsAsString("Error parsing ASM"));
|
||||||
|
}
|
||||||
|
return PyModuleOp{shared_from_this(), module_ref.release()};
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PyBaseOperation
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
PyBaseOperation::~PyBaseOperation() = default;
|
||||||
|
|
||||||
|
void PyBaseOperation::bind(py::module m) {
|
||||||
|
py::class_<PyBaseOperation>(m, "BaseOperation")
|
||||||
|
.def_property_readonly(
|
||||||
|
"name",
|
||||||
|
[](PyBaseOperation *self) {
|
||||||
|
return std::string(self->getOperation()->getName().getStringRef());
|
||||||
|
})
|
||||||
|
.def_property_readonly("is_registered",
|
||||||
|
[](PyBaseOperation *self) {
|
||||||
|
return self->getOperation()->isRegistered();
|
||||||
|
})
|
||||||
|
.def_property_readonly("num_regions",
|
||||||
|
[](PyBaseOperation *self) {
|
||||||
|
return self->getOperation()->getNumRegions();
|
||||||
|
})
|
||||||
|
.def("region", [](PyBaseOperation *self, int index) {
|
||||||
|
auto *op = self->getOperation();
|
||||||
|
if (index < 0 || index >= op->getNumRegions()) {
|
||||||
|
throw py::raisePyError(PyExc_IndexError,
|
||||||
|
"Region index out of bounds");
|
||||||
|
}
|
||||||
|
return PyRegionRef(op->getRegion(index));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PyModuleOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void PyModuleOp::bind(py::module m) {
|
||||||
|
py::class_<PyModuleOp, PyBaseOperation>(m, "ModuleOp")
|
||||||
|
.def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false,
|
||||||
|
py::arg("pretty") = false, py::arg("large_element_limit") = -1);
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation *PyModuleOp::getOperation() { return moduleOp; }
|
||||||
|
|
||||||
|
std::string PyModuleOp::toAsm(bool enableDebugInfo, bool prettyForm,
|
||||||
|
int64_t largeElementLimit) {
|
||||||
|
// Print to asm.
|
||||||
|
std::string asmOutput;
|
||||||
|
llvm::raw_string_ostream sout(asmOutput);
|
||||||
|
OpPrintingFlags printFlags;
|
||||||
|
if (enableDebugInfo) {
|
||||||
|
printFlags.enableDebugInfo(prettyForm);
|
||||||
|
}
|
||||||
|
if (largeElementLimit >= 0) {
|
||||||
|
printFlags.elideLargeElementsAttrs(largeElementLimit);
|
||||||
|
}
|
||||||
|
moduleOp.print(sout, printFlags);
|
||||||
|
return sout.str();
|
||||||
|
}
|
||||||
|
|
||||||
static OwningModuleRef parseMLIRModuleFromString(StringRef contents,
|
static OwningModuleRef parseMLIRModuleFromString(StringRef contents,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
std::unique_ptr<llvm::MemoryBuffer> contents_buffer;
|
std::unique_ptr<llvm::MemoryBuffer> contents_buffer;
|
||||||
|
@ -225,6 +251,10 @@ void printLocation(Location loc, raw_ostream &out) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// DiagnosticCapture
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
std::string
|
std::string
|
||||||
DiagnosticCapture::consumeDiagnosticsAsString(const char *error_message) {
|
DiagnosticCapture::consumeDiagnosticsAsString(const char *error_message) {
|
||||||
std::string s;
|
std::string s;
|
||||||
|
@ -266,4 +296,29 @@ DiagnosticCapture::consumeDiagnosticsAsString(const char *error_message) {
|
||||||
return sout.str();
|
return sout.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PyRegionRef
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void PyRegionRef::bind(py::module m) {
|
||||||
|
py::class_<PyRegionRef>(m, "RegionRef");
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OpBuilder implementations
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
PyBaseOpBuilder::~PyBaseOpBuilder() = default;
|
||||||
|
PyOpBuilder::~PyOpBuilder() = default;
|
||||||
|
OpBuilder &PyOpBuilder::getBuilder() { return builder; }
|
||||||
|
|
||||||
|
void PyBaseOpBuilder::bind(py::module m) {
|
||||||
|
py::class_<PyBaseOpBuilder>(m, "BaseOpBuilder");
|
||||||
|
}
|
||||||
|
|
||||||
|
void PyOpBuilder::bind(py::module m) {
|
||||||
|
py::class_<PyOpBuilder, PyBaseOpBuilder>(m, "OpBuilder")
|
||||||
|
.def(py::init<PyContext &>());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
//===- mlir_ir.h - MLIR IR Bindings -------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_PYTHON_MLIR_IR_H
|
||||||
|
#define NPCOMP_PYTHON_MLIR_IR_H
|
||||||
|
|
||||||
|
#include "pybind_utils.h"
|
||||||
|
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "mlir/IR/Module.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Region.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
struct PyContext;
|
||||||
|
|
||||||
|
/// Wrapper around an Operation*.
|
||||||
|
struct PyBaseOperation {
|
||||||
|
virtual ~PyBaseOperation();
|
||||||
|
static void bind(py::module m);
|
||||||
|
virtual Operation *getOperation() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Wrapper around Module, capturing a PyContext reference.
|
||||||
|
struct PyModuleOp : PyBaseOperation {
|
||||||
|
PyModuleOp(std::shared_ptr<PyContext> context, ModuleOp moduleOp)
|
||||||
|
: context(context), moduleOp(moduleOp) {}
|
||||||
|
static void bind(py::module m);
|
||||||
|
Operation *getOperation() override;
|
||||||
|
std::string toAsm(bool enableDebugInfo, bool prettyForm,
|
||||||
|
int64_t largeElementLimit);
|
||||||
|
|
||||||
|
std::shared_ptr<PyContext> context;
|
||||||
|
ModuleOp moduleOp;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Wrapper around MLIRContext.
|
||||||
|
struct PyContext : std::enable_shared_from_this<PyContext> {
|
||||||
|
static void bind(py::module m);
|
||||||
|
PyModuleOp parseAsm(const std::string &asm_text);
|
||||||
|
MLIRContext context;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Wrapper around a Region&.
|
||||||
|
struct PyRegionRef {
|
||||||
|
PyRegionRef(Region ®ion) : region(region) {}
|
||||||
|
static void bind(py::module m);
|
||||||
|
Region ®ion;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Wrapper around a Block&.
|
||||||
|
struct PyBlockRef {
|
||||||
|
PyBlockRef(Block &block) : block(block) {}
|
||||||
|
static void bind(py::module m);
|
||||||
|
Block █
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Wrapper around an OpBuilder reference.
|
||||||
|
/// This class is inherently dangerous because it does not track ownership
|
||||||
|
/// of IR objects that it may be operating on and incorrect usage can cause
|
||||||
|
/// memory access errors, just as it can in C++. It is intended for use by
|
||||||
|
/// higher level constructs that are specifically coded to satisfy object
|
||||||
|
/// lifetime needs.
|
||||||
|
class PyBaseOpBuilder {
|
||||||
|
public:
|
||||||
|
virtual ~PyBaseOpBuilder();
|
||||||
|
static void bind(py::module m);
|
||||||
|
virtual OpBuilder &getBuilder() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Wrapper around an instance of an OpBuilder.
|
||||||
|
class PyOpBuilder : public PyBaseOpBuilder {
|
||||||
|
public:
|
||||||
|
PyOpBuilder(PyContext &context) : builder(&context.context) {}
|
||||||
|
~PyOpBuilder() override;
|
||||||
|
static void bind(py::module m);
|
||||||
|
OpBuilder &getBuilder() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
OpBuilder builder;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_PYTHON_MLIR_IR_H
|
|
@ -18,7 +18,13 @@ module @parseSuccess {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
""")
|
""")
|
||||||
|
# CHECK: func @f
|
||||||
print(m.to_asm())
|
print(m.to_asm())
|
||||||
|
# CHECK: OP NAME: module
|
||||||
|
print("OP NAME:", m.name)
|
||||||
|
# CHECK: NUM_REGIONS: 1
|
||||||
|
print("NUM_REGIONS:", m.num_regions)
|
||||||
|
region = m.region(0)
|
||||||
|
|
||||||
# CHECK-LABEL: PARSE_FAILURE
|
# CHECK-LABEL: PARSE_FAILURE
|
||||||
print("PARSE_FAILURE")
|
print("PARSE_FAILURE")
|
||||||
|
|
|
@ -6,6 +6,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_PYTHON_PYBIND_UTILS_H
|
||||||
|
#define NPCOMP_PYTHON_PYBIND_UTILS_H
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
|
@ -48,3 +51,5 @@ inline pybind11::error_already_set raiseValueError(const std::string &message) {
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace pybind11
|
} // namespace pybind11
|
||||||
|
|
||||||
|
#endif // NPCOMP_PYTHON_PYBIND_UTILS_H
|
||||||
|
|
Loading…
Reference in New Issue