mirror of https://github.com/llvm/torch-mlir
Add enough python bindings to build functions and ufunc calls.
parent
ba0c96b51a
commit
78a8e6ec9e
|
@ -2,14 +2,72 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from npcomp.native.mlir.ir import *
|
||||
from npcomp.native.mlir import ir
|
||||
|
||||
__all__ = [
|
||||
"load_builtin_module",
|
||||
"Types",
|
||||
]
|
||||
|
||||
|
||||
def load_builtin_module(context = None):
|
||||
class Ops(ir.Ops):
|
||||
r"""Dialect ops.
|
||||
|
||||
>>> c = ir.MLIRContext()
|
||||
>>> t = Types(c)
|
||||
>>> m = c.new_module()
|
||||
>>> tensor_type = t.tensor(t.f32)
|
||||
>>> ops = Ops(c)
|
||||
>>> ops.builder.insert_block_start(m.first_block)
|
||||
>>> f = ops.func_op("foobar", t.function(
|
||||
... [tensor_type, tensor_type], [tensor_type]),
|
||||
... create_entry_block=True)
|
||||
>>> uf = ops.numpy_ufunc_call("numpy.add", tensor_type,
|
||||
... *f.first_block.args)
|
||||
>>> _ = ops.return_op(uf.results)
|
||||
>>> print(m.to_asm())
|
||||
<BLANKLINE>
|
||||
<BLANKLINE>
|
||||
module {
|
||||
func @foobar(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = numpy.ufunc_call @numpy.add(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
}
|
||||
"""
|
||||
def numpy_ufunc_call(self, callee_symbol, result_type, *args):
|
||||
"""Creates a numpy.ufunc_call op."""
|
||||
c = self.context
|
||||
attrs = c.dictionary_attr({
|
||||
"ufunc_ref": c.flat_symbol_ref_attr(callee_symbol)
|
||||
})
|
||||
return self.op("numpy.ufunc_call", [result_type], args, attrs)
|
||||
|
||||
|
||||
class Types(ir.Types):
|
||||
"""Container/factory for dialect types.
|
||||
|
||||
>>> t = Types(ir.MLIRContext())
|
||||
>>> t.any_dtype
|
||||
!numpy.any_dtype
|
||||
>>> t.tensor(t.any_dtype, [1, 2, 3])
|
||||
tensor<1x2x3x!numpy.any_dtype>
|
||||
>>> t.tensor(t.any_dtype)
|
||||
tensor<*x!numpy.any_dtype>
|
||||
>>> t.tensor(t.any_dtype, [-1, 2])
|
||||
tensor<?x2x!numpy.any_dtype>
|
||||
>>> t.tensor(t.f32)
|
||||
tensor<*xf32>
|
||||
>>> t.function([t.i32], [t.f32])
|
||||
(i32) -> f32
|
||||
|
||||
"""
|
||||
def __init__(self, context):
|
||||
super().__init__(context)
|
||||
self.any_dtype = context.parse_type("!numpy.any_dtype")
|
||||
|
||||
|
||||
def load_builtin_module(context=None):
|
||||
"""Loads a module populated with numpy built-ins.
|
||||
|
||||
This is not a long-term solution but overcomes some bootstrapping
|
||||
|
@ -17,15 +75,17 @@ def load_builtin_module(context = None):
|
|||
|
||||
>>> m = load_builtin_module()
|
||||
>>> op = m.region(0).blocks.front.operations.front
|
||||
>>> print(op.name)
|
||||
numpy.generic_ufunc
|
||||
>>> op.is_registered
|
||||
True
|
||||
>>> op.name
|
||||
'numpy.generic_ufunc'
|
||||
|
||||
Args:
|
||||
context: The MLIRContext to use (None to create a new one).
|
||||
Returns:
|
||||
A ModuleOp.
|
||||
"""
|
||||
if context is None: context = MLIRContext()
|
||||
if context is None: context = ir.MLIRContext()
|
||||
return context.parse_asm(_BUILTIN_MODULE_ASM)
|
||||
|
||||
|
||||
|
|
|
@ -8,8 +8,12 @@
|
|||
|
||||
#include "mlir_ir.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
@ -112,12 +116,184 @@ private:
|
|||
mlir::DiagnosticEngine::HandlerID handler_id;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Python only classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class PyOps {
|
||||
public:
|
||||
PyOps(std::shared_ptr<PyContext> context)
|
||||
: pyOpBuilder(*context), context(std::move(context)) {}
|
||||
static void bind(py::module m) {
|
||||
py::class_<PyOps>(m, "Ops")
|
||||
.def(py::init<std::shared_ptr<PyContext>>())
|
||||
.def_property_readonly(
|
||||
"builder",
|
||||
[](PyOps &self) -> PyBaseOpBuilder & { return self.pyOpBuilder; })
|
||||
.def_property_readonly("context",
|
||||
[](PyOps &self) -> std::shared_ptr<PyContext> {
|
||||
return self.context;
|
||||
})
|
||||
.def("op",
|
||||
[](PyOps &self, const std::string &opNameStr,
|
||||
std::vector<PyType> pyResultTypes,
|
||||
std::vector<PyValue> pyOperands,
|
||||
llvm::Optional<PyAttribute> attrs) -> PyOperationRef {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false);
|
||||
Location loc = UnknownLoc::get(opBuilder.getContext());
|
||||
OperationName opName(opNameStr, opBuilder.getContext());
|
||||
SmallVector<Type, 4> types(pyResultTypes.begin(),
|
||||
pyResultTypes.end());
|
||||
SmallVector<Value, 4> operands(pyOperands.begin(),
|
||||
pyOperands.end());
|
||||
NamedAttributeList attrList;
|
||||
if (attrs) {
|
||||
auto dictAttrs = attrs->attr.dyn_cast<DictionaryAttr>();
|
||||
if (!dictAttrs) {
|
||||
throw py::raiseValueError(
|
||||
"Expected `attrs` to be a DictionaryAttr");
|
||||
}
|
||||
attrList = NamedAttributeList(dictAttrs);
|
||||
}
|
||||
Operation *op =
|
||||
Operation::create(loc, opName, types, operands, attrList);
|
||||
opBuilder.insert(op);
|
||||
return op;
|
||||
},
|
||||
py::arg("op_name"), py::arg("result_types"), py::arg("operands"),
|
||||
py::arg("attrs") = llvm::Optional<PyAttribute>())
|
||||
.def("func_op",
|
||||
[](PyOps &self, const std::string &name, PyType type,
|
||||
bool createEntryBlock) {
|
||||
auto functionType = type.type.dyn_cast_or_null<FunctionType>();
|
||||
if (!functionType) {
|
||||
throw py::raiseValueError("Illegal function type");
|
||||
}
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = UnknownLoc::get(opBuilder.getContext());
|
||||
// TODO: Add function and arg/result attributes.
|
||||
FuncOp op = opBuilder.create<FuncOp>(
|
||||
loc, StringRef(name), functionType,
|
||||
/*attrs=*/ArrayRef<NamedAttribute>());
|
||||
if (createEntryBlock) {
|
||||
Block *entryBlock = new Block();
|
||||
entryBlock->addArguments(functionType.getInputs());
|
||||
op.getBody().push_back(entryBlock);
|
||||
opBuilder.setInsertionPointToStart(entryBlock);
|
||||
}
|
||||
return PyOperationRef(op);
|
||||
},
|
||||
py::arg("name"), py::arg("type"),
|
||||
py::arg("create_entry_block") = false,
|
||||
R"(Creates a new `func` op, optionally creating an entry block.
|
||||
If an entry block is created, the builder will be positioned
|
||||
to its start.)")
|
||||
.def("return_op", [](PyOps &self, std::vector<PyValue> pyOperands) {
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = UnknownLoc::get(opBuilder.getContext());
|
||||
SmallVector<Value, 4> operands(pyOperands.begin(), pyOperands.end());
|
||||
return PyOperationRef(opBuilder.create<ReturnOp>(loc, operands));
|
||||
});
|
||||
}
|
||||
PyOpBuilder pyOpBuilder;
|
||||
std::shared_ptr<PyContext> context;
|
||||
};
|
||||
|
||||
class PyTypes {
|
||||
public:
|
||||
PyTypes(std::shared_ptr<PyContext> context) : context(std::move(context)) {}
|
||||
static void bind(py::module m) {
|
||||
py::class_<PyTypes>(m, "Types")
|
||||
.def(py::init<std::shared_ptr<PyContext>>())
|
||||
.def_property_readonly("context",
|
||||
[](PyTypes &self) { return self.context; })
|
||||
.def("integer",
|
||||
[](PyTypes &self, unsigned width) {
|
||||
return PyType(IntegerType::get(width, &self.context->context));
|
||||
},
|
||||
py::arg("width") = 32)
|
||||
.def_property_readonly(
|
||||
"i1",
|
||||
[](PyTypes &self) {
|
||||
return PyType(IntegerType::get(1, &self.context->context));
|
||||
})
|
||||
.def_property_readonly(
|
||||
"i16",
|
||||
[](PyTypes &self) {
|
||||
return PyType(IntegerType::get(32, &self.context->context));
|
||||
})
|
||||
.def_property_readonly(
|
||||
"i32",
|
||||
[](PyTypes &self) {
|
||||
return PyType(IntegerType::get(32, &self.context->context));
|
||||
})
|
||||
.def_property_readonly(
|
||||
"i64",
|
||||
[](PyTypes &self) {
|
||||
return PyType(IntegerType::get(64, &self.context->context));
|
||||
})
|
||||
.def_property_readonly(
|
||||
"f32",
|
||||
[](PyTypes &self) {
|
||||
return PyType(
|
||||
FloatType::get(StandardTypes::F32, &self.context->context));
|
||||
})
|
||||
.def_property_readonly(
|
||||
"f64",
|
||||
[](PyTypes &self) {
|
||||
return PyType(
|
||||
FloatType::get(StandardTypes::F64, &self.context->context));
|
||||
})
|
||||
.def("tensor",
|
||||
[](PyTypes &self, PyType elementType,
|
||||
llvm::Optional<std::vector<int64_t>> shape) {
|
||||
if (!elementType.type) {
|
||||
throw py::raiseValueError("Null element type");
|
||||
}
|
||||
if (shape) {
|
||||
return PyType(RankedTensorType::get(*shape, elementType.type));
|
||||
} else {
|
||||
return PyType(UnrankedTensorType::get(elementType.type));
|
||||
}
|
||||
},
|
||||
py::arg("element_type"),
|
||||
py::arg("shape") = llvm::Optional<std::vector<int64_t>>())
|
||||
.def("function", [](PyTypes &self, std::vector<PyType> inputs,
|
||||
std::vector<PyType> results) {
|
||||
llvm::SmallVector<Type, 4> inputTypes;
|
||||
llvm::SmallVector<Type, 1> resultTypes;
|
||||
for (auto input : inputs) {
|
||||
inputTypes.push_back(input.type);
|
||||
}
|
||||
for (auto result : results) {
|
||||
resultTypes.push_back(result.type);
|
||||
}
|
||||
return PyType(FunctionType::get(inputTypes, resultTypes,
|
||||
&self.context->context));
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<PyContext> context;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module initialization
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void defineMlirIrModule(py::module m) {
|
||||
m.doc() = "Python bindings for constructs in the mlir/IR library";
|
||||
|
||||
// Python only types.
|
||||
PyOps::bind(m);
|
||||
PyTypes::bind(m);
|
||||
|
||||
// Utility types.
|
||||
PyBlockList::bind(m, "BlockList");
|
||||
PyOperationList::bind(m, "OperationList");
|
||||
|
||||
// Wrapper types.
|
||||
PyAttribute::bind(m);
|
||||
PyBaseOperation::bind(m);
|
||||
PyBaseOpBuilder::bind(m);
|
||||
PyBlockRef::bind(m);
|
||||
|
@ -126,6 +302,9 @@ void defineMlirIrModule(py::module m) {
|
|||
PyOperationRef::bind(m);
|
||||
PyOpBuilder::bind(m);
|
||||
PyRegionRef::bind(m);
|
||||
PySymbolTable::bind(m);
|
||||
PyType::bind(m);
|
||||
PyValue::bind(m);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -139,10 +318,58 @@ void PyContext::bind(py::module m) {
|
|||
return std::make_shared<PyContext>();
|
||||
}))
|
||||
.def("new_module",
|
||||
[&](PyContext &context) -> PyModuleOp {
|
||||
return PyModuleOp(context.shared_from_this(), {});
|
||||
[&](PyContext &self) -> PyModuleOp {
|
||||
Location loc = UnknownLoc::get(&self.context);
|
||||
auto m = ModuleOp::create(loc);
|
||||
return PyModuleOp(self.shared_from_this(), m);
|
||||
})
|
||||
.def("parse_asm", &PyContext::parseAsm);
|
||||
.def("parse_asm", &PyContext::parseAsm)
|
||||
.def("new_builder",
|
||||
[](PyContext &self) {
|
||||
// Note: we collapse the Builder and OpBuilder into one because
|
||||
// there is little reason to expose the inheritance hierarchy to
|
||||
// Python.
|
||||
return PyOpBuilder(self);
|
||||
})
|
||||
// Salient functions from Builder.
|
||||
.def("parse_type",
|
||||
[](PyContext &self, const std::string &asmText) {
|
||||
Type t = parseType(asmText, &self.context);
|
||||
if (!t) {
|
||||
std::string message = "Unable to parse MLIR type: ";
|
||||
message.append(asmText);
|
||||
throw py::raiseValueError(message);
|
||||
}
|
||||
return PyType(t);
|
||||
})
|
||||
.def("string_attr",
|
||||
[](PyContext &self, const std::string &s) -> PyAttribute {
|
||||
return StringAttr::get(s, &self.context);
|
||||
})
|
||||
.def("bytes_attr",
|
||||
[](PyContext &self, py::bytes bytes) -> PyAttribute {
|
||||
char *buffer;
|
||||
ssize_t length;
|
||||
if (PYBIND11_BYTES_AS_STRING_AND_SIZE(bytes.ptr(), &buffer,
|
||||
&length)) {
|
||||
throw py::raiseValueError("Cannot extract bytes");
|
||||
}
|
||||
return StringAttr::get(StringRef(buffer, length), &self.context);
|
||||
})
|
||||
.def("flat_symbol_ref_attr",
|
||||
[](PyContext &self, const std::string &s) -> PyAttribute {
|
||||
return FlatSymbolRefAttr::get(s, &self.context);
|
||||
})
|
||||
.def("dictionary_attr", [](PyContext &self, py::dict d) -> PyAttribute {
|
||||
SmallVector<NamedAttribute, 4> attrs;
|
||||
for (auto &it : d) {
|
||||
auto key = it.first.cast<std::string>();
|
||||
auto value = it.second.cast<PyAttribute>();
|
||||
auto keyIdent = Identifier::get(key, &self.context);
|
||||
attrs.emplace_back(keyIdent, value.attr);
|
||||
}
|
||||
return DictionaryAttr::get(attrs, &self.context);
|
||||
});
|
||||
}
|
||||
|
||||
PyModuleOp PyContext::parseAsm(const std::string &asm_text) {
|
||||
|
@ -183,13 +410,33 @@ void PyBaseOperation::bind(py::module m) {
|
|||
[](PyBaseOperation &self) {
|
||||
return self.getOperation()->getNumRegions();
|
||||
})
|
||||
.def("region", [](PyBaseOperation &self, int index) {
|
||||
auto *op = self.getOperation();
|
||||
if (index < 0 || index >= op->getNumRegions()) {
|
||||
throw py::raisePyError(PyExc_IndexError,
|
||||
"Region index out of bounds");
|
||||
.def_property_readonly("results",
|
||||
[](PyBaseOperation &self) {
|
||||
auto *op = self.getOperation();
|
||||
std::vector<PyValue> results(op->result_begin(),
|
||||
op->result_end());
|
||||
return results;
|
||||
})
|
||||
.def("region",
|
||||
[](PyBaseOperation &self, int index) {
|
||||
auto *op = self.getOperation();
|
||||
if (index < 0 || index >= op->getNumRegions()) {
|
||||
throw py::raisePyError(PyExc_IndexError,
|
||||
"Region index out of bounds");
|
||||
}
|
||||
return PyRegionRef(op->getRegion(index));
|
||||
})
|
||||
.def_property_readonly("first_block", [](PyBaseOperation &self) {
|
||||
Operation *op = self.getOperation();
|
||||
assert(op);
|
||||
if (op->getNumRegions() == 0) {
|
||||
throw py::raiseValueError("Op has no regions");
|
||||
}
|
||||
return PyRegionRef(op->getRegion(index));
|
||||
auto ®ion = op->getRegion(0);
|
||||
if (region.empty()) {
|
||||
throw py::raiseValueError("Op has no blocks");
|
||||
}
|
||||
return PyBlockRef(region.front());
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -211,6 +458,8 @@ Operation *PyOperationRef::getOperation() { return operation; }
|
|||
PyModuleOp::~PyModuleOp() = default;
|
||||
void PyModuleOp::bind(py::module m) {
|
||||
py::class_<PyModuleOp, PyBaseOperation>(m, "ModuleOp")
|
||||
.def_property_readonly("context",
|
||||
[](PyModuleOp &self) { return self.context; })
|
||||
.def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false,
|
||||
py::arg("pretty") = false, py::arg("large_element_limit") = -1);
|
||||
}
|
||||
|
@ -313,6 +562,25 @@ void printLocation(Location loc, raw_ostream &out) {
|
|||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PySymbolTable
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PySymbolTable::bind(py::module m) {
|
||||
py::class_<PySymbolTable>(m, "SymbolTable")
|
||||
.def_property_readonly_static("symbol_attr_name",
|
||||
[](const py::object &) {
|
||||
auto sr =
|
||||
SymbolTable::getSymbolAttrName();
|
||||
return py::str(sr.data(), sr.size());
|
||||
})
|
||||
.def_property_readonly_static(
|
||||
"visibility_attr_name", [](const py::object &) {
|
||||
auto sr = SymbolTable::getVisibilityAttrName();
|
||||
return py::str(sr.data(), sr.size());
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DiagnosticCapture
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -364,8 +632,14 @@ DiagnosticCapture::consumeDiagnosticsAsString(const char *error_message) {
|
|||
|
||||
void PyBlockRef::bind(py::module m) {
|
||||
py::class_<PyBlockRef>(m, "BlockRef")
|
||||
.def_property_readonly("operations", [](PyBlockRef &self) {
|
||||
return PyOperationList(self.block.getOperations());
|
||||
.def_property_readonly("operations",
|
||||
[](PyBlockRef &self) {
|
||||
return PyOperationList(
|
||||
self.block.getOperations());
|
||||
})
|
||||
.def_property_readonly("args", [](PyBlockRef &self) {
|
||||
return std::vector<PyValue>(self.block.args_begin(),
|
||||
self.block.args_end());
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -380,13 +654,62 @@ void PyRegionRef::bind(py::module m) {
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PyType::bind(py::module m) {
|
||||
py::class_<PyType>(m, "Type").def("__repr__",
|
||||
[](PyType &self) -> std::string {
|
||||
if (!self.type)
|
||||
return "<undefined type>";
|
||||
std::string res;
|
||||
llvm::raw_string_ostream os(res);
|
||||
self.type.print(os);
|
||||
return res;
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyValue
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PyValue::bind(py::module m) {
|
||||
py::class_<PyValue>(m, "Value").def("__repr__", [](PyValue &self) {
|
||||
std::string res;
|
||||
llvm::raw_string_ostream os(res);
|
||||
os << self.value;
|
||||
return res;
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyAttribute
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PyAttribute::bind(py::module m) {
|
||||
py::class_<PyAttribute>(m, "Attribute")
|
||||
.def("__repr__", [](PyAttribute &self) {
|
||||
std::string res;
|
||||
llvm::raw_string_ostream os(res);
|
||||
os << self.attr;
|
||||
return res;
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpBuilder implementations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PyBaseOpBuilder::~PyBaseOpBuilder() = default;
|
||||
PyOpBuilder::~PyOpBuilder() = default;
|
||||
OpBuilder &PyOpBuilder::getBuilder() { return builder; }
|
||||
|
||||
OpBuilder &PyOpBuilder::getBuilder(bool requirePosition) {
|
||||
if (!builder.getBlock()) {
|
||||
throw py::raisePyError(PyExc_IndexError, "Insertion point not set");
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
||||
void PyBaseOpBuilder::bind(py::module m) {
|
||||
py::class_<PyBaseOpBuilder>(m, "BaseOpBuilder");
|
||||
|
@ -394,7 +717,40 @@ void PyBaseOpBuilder::bind(py::module m) {
|
|||
|
||||
void PyOpBuilder::bind(py::module m) {
|
||||
py::class_<PyOpBuilder, PyBaseOpBuilder>(m, "OpBuilder")
|
||||
.def(py::init<PyContext &>());
|
||||
.def(py::init<PyContext &>())
|
||||
.def("clear_insertion_point",
|
||||
[](PyOpBuilder &self) { self.builder.clearInsertionPoint(); })
|
||||
.def("insert_op_before",
|
||||
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
|
||||
Operation *op = pyOp.getOperation();
|
||||
self.builder.setInsertionPoint(op);
|
||||
},
|
||||
"Sets the insertion point to just before the specified op.")
|
||||
.def("insert_op_after",
|
||||
[](PyOpBuilder &self, PyBaseOperation &pyOp) {
|
||||
Operation *op = pyOp.getOperation();
|
||||
self.builder.setInsertionPointAfter(op);
|
||||
},
|
||||
"Sets the insertion point to just after the specified op.")
|
||||
.def("insert_block_start",
|
||||
[](PyOpBuilder &self, PyBlockRef block) {
|
||||
self.builder.setInsertionPointToStart(&block.block);
|
||||
},
|
||||
"Sets the insertion point to the start of the block.")
|
||||
.def("insert_block_end",
|
||||
[](PyOpBuilder &self, PyBlockRef block) {
|
||||
self.builder.setInsertionPointToEnd(&block.block);
|
||||
},
|
||||
"Sets the insertion point to the end of the block.")
|
||||
.def("insert_before_terminator",
|
||||
[](PyOpBuilder &self, PyBlockRef block) {
|
||||
auto *terminator = block.block.getTerminator();
|
||||
if (!terminator) {
|
||||
throw py::raiseValueError("Block has no terminator");
|
||||
}
|
||||
self.builder.setInsertionPoint(terminator);
|
||||
},
|
||||
"Sets the insertion point to just before the block terminator.");
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Region.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
@ -50,7 +51,9 @@ struct PyBaseOperation {
|
|||
/// Wrapper around Module, capturing a PyContext reference.
|
||||
struct PyModuleOp : PyBaseOperation {
|
||||
PyModuleOp(std::shared_ptr<PyContext> context, ModuleOp moduleOp)
|
||||
: context(context), moduleOp(moduleOp) {}
|
||||
: context(context), moduleOp(moduleOp) {
|
||||
assert(moduleOp);
|
||||
}
|
||||
~PyModuleOp();
|
||||
static void bind(py::module m);
|
||||
Operation *getOperation() override;
|
||||
|
@ -63,7 +66,9 @@ struct PyModuleOp : PyBaseOperation {
|
|||
|
||||
/// Wrapper around an Operation*.
|
||||
struct PyOperationRef : PyBaseOperation {
|
||||
PyOperationRef(Operation *operation) : operation(operation) {}
|
||||
PyOperationRef(Operation *operation) : operation(operation) {
|
||||
assert(operation);
|
||||
}
|
||||
PyOperationRef(Operation &operation) : operation(&operation) {}
|
||||
~PyOperationRef();
|
||||
static void bind(py::module m);
|
||||
|
@ -72,6 +77,28 @@ struct PyOperationRef : PyBaseOperation {
|
|||
Operation *operation;
|
||||
};
|
||||
|
||||
/// Wrapper around SymbolTable.
|
||||
struct PySymbolTable {
|
||||
PySymbolTable(SymbolTable &symbolTable) : symbolTable(symbolTable) {}
|
||||
static void bind(py::module m);
|
||||
SymbolTable &symbolTable;
|
||||
};
|
||||
|
||||
/// Wrapper around Value.
|
||||
struct PyValue {
|
||||
PyValue(Value value) : value(value) { assert(value); }
|
||||
static void bind(py::module m);
|
||||
operator Value() { return value; }
|
||||
Value value;
|
||||
};
|
||||
|
||||
/// Wrapper around Attribute.
|
||||
struct PyAttribute {
|
||||
PyAttribute(Attribute attr) : attr(attr) { assert(attr); }
|
||||
static void bind(py::module m);
|
||||
Attribute attr;
|
||||
};
|
||||
|
||||
/// Wrapper around MLIRContext.
|
||||
struct PyContext : std::enable_shared_from_this<PyContext> {
|
||||
static void bind(py::module m);
|
||||
|
@ -93,6 +120,14 @@ struct PyRegionRef {
|
|||
Region ®ion;
|
||||
};
|
||||
|
||||
struct PyType {
|
||||
PyType() = default;
|
||||
PyType(Type type) : type(type) {}
|
||||
static void bind(py::module m);
|
||||
operator Type() { return type; }
|
||||
Type type;
|
||||
};
|
||||
|
||||
/// Wrapper around an OpBuilder reference.
|
||||
/// This class is inherently dangerous because it does not track ownership
|
||||
/// of IR objects that it may be operating on and incorrect usage can cause
|
||||
|
@ -103,7 +138,7 @@ class PyBaseOpBuilder {
|
|||
public:
|
||||
virtual ~PyBaseOpBuilder();
|
||||
static void bind(py::module m);
|
||||
virtual OpBuilder &getBuilder() = 0;
|
||||
virtual OpBuilder &getBuilder(bool requirePosition = false) = 0;
|
||||
};
|
||||
|
||||
/// Wrapper around an instance of an OpBuilder.
|
||||
|
@ -112,7 +147,7 @@ public:
|
|||
PyOpBuilder(PyContext &context) : builder(&context.context) {}
|
||||
~PyOpBuilder() override;
|
||||
static void bind(py::module m);
|
||||
OpBuilder &getBuilder() override;
|
||||
OpBuilder &getBuilder(bool requirePosition = false) override;
|
||||
|
||||
private:
|
||||
OpBuilder builder;
|
||||
|
|
Loading…
Reference in New Issue