Add wrappers for block and operation iteration.

I don't technically need this now but adding while the train of thought is fresh.
pull/1/head
Stella Laurenzo 2020-05-01 10:16:19 -07:00
parent c8740fd866
commit 23a9ffaabe
3 changed files with 124 additions and 19 deletions

View File

@ -26,6 +26,45 @@ struct PyContext;
static OwningModuleRef parseMLIRModuleFromString(StringRef contents, static OwningModuleRef parseMLIRModuleFromString(StringRef contents,
MLIRContext *context); MLIRContext *context);
//===----------------------------------------------------------------------===//
// Internal only template definitions
//===----------------------------------------------------------------------===//
template <typename ListTy, typename ItemWrapperTy>
void PyIpListWrapper<ListTy, ItemWrapperTy>::bind(py::module m,
const char *className) {
struct PyItemIterator : public llvm::iterator_adaptor_base<
PyItemIterator, typename ListTy::iterator,
typename std::iterator_traits<
typename ListTy::iterator>::iterator_category,
typename ListTy::value_type> {
PyItemIterator() = default;
PyItemIterator(typename ListTy::iterator &&other)
: PyItemIterator::iterator_adaptor_base(std::move(other)) {}
ItemWrapperTy operator*() const { return ItemWrapperTy(*this->I); }
};
py::class_<ThisTy>(m, className)
.def("__len__", [](ThisTy &self) { return self.list.size(); })
.def("__iter__",
[](ThisTy &self) {
PyItemIterator begin(self.list.begin());
PyItemIterator end(self.list.end());
return py::make_iterator(begin, end);
},
py::keep_alive<0, 1>());
}
//===----------------------------------------------------------------------===//
// Explicit template instantiations
//===----------------------------------------------------------------------===//
template class PyIpListWrapper<Region::BlockListType, PyBlockRef>;
using PyBlockList = PyIpListWrapper<Region::BlockListType, PyBlockRef>;
template class PyIpListWrapper<Block::OpListType, PyOperationRef>;
using PyOperationList = PyIpListWrapper<Block::OpListType, PyOperationRef>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Diagnostics // Diagnostics
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -70,12 +109,17 @@ private:
void defineMlirIrModule(py::module m) { void defineMlirIrModule(py::module m) {
m.doc() = "Python bindings for constructs in the mlir/IR library"; m.doc() = "Python bindings for constructs in the mlir/IR library";
PyContext::bind(m); PyBlockList::bind(m, "BlockList");
PyOperationList::bind(m, "OperationList");
PyBaseOperation::bind(m); PyBaseOperation::bind(m);
PyModuleOp::bind(m);
PyRegionRef::bind(m);
PyBaseOpBuilder::bind(m); PyBaseOpBuilder::bind(m);
PyBlockRef::bind(m);
PyContext::bind(m);
PyModuleOp::bind(m);
PyOperationRef::bind(m);
PyOpBuilder::bind(m); PyOpBuilder::bind(m);
PyRegionRef::bind(m);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -122,19 +166,19 @@ void PyBaseOperation::bind(py::module m) {
py::class_<PyBaseOperation>(m, "BaseOperation") py::class_<PyBaseOperation>(m, "BaseOperation")
.def_property_readonly( .def_property_readonly(
"name", "name",
[](PyBaseOperation *self) { [](PyBaseOperation &self) {
return std::string(self->getOperation()->getName().getStringRef()); return std::string(self.getOperation()->getName().getStringRef());
}) })
.def_property_readonly("is_registered", .def_property_readonly("is_registered",
[](PyBaseOperation *self) { [](PyBaseOperation &self) {
return self->getOperation()->isRegistered(); return self.getOperation()->isRegistered();
}) })
.def_property_readonly("num_regions", .def_property_readonly("num_regions",
[](PyBaseOperation *self) { [](PyBaseOperation &self) {
return self->getOperation()->getNumRegions(); return self.getOperation()->getNumRegions();
}) })
.def("region", [](PyBaseOperation *self, int index) { .def("region", [](PyBaseOperation &self, int index) {
auto *op = self->getOperation(); auto *op = self.getOperation();
if (index < 0 || index >= op->getNumRegions()) { if (index < 0 || index >= op->getNumRegions()) {
throw py::raisePyError(PyExc_IndexError, throw py::raisePyError(PyExc_IndexError,
"Region index out of bounds"); "Region index out of bounds");
@ -143,10 +187,22 @@ void PyBaseOperation::bind(py::module m) {
}); });
} }
//===----------------------------------------------------------------------===//
// PyOperationRef
//===----------------------------------------------------------------------===//
PyOperationRef::~PyOperationRef() = default;
void PyOperationRef::bind(py::module m) {
py::class_<PyOperationRef, PyBaseOperation>(m, "OperationRef");
}
Operation *PyOperationRef::getOperation() { return operation; }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// PyModuleOp // PyModuleOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
PyModuleOp::~PyModuleOp() = default;
void PyModuleOp::bind(py::module m) { void PyModuleOp::bind(py::module m) {
py::class_<PyModuleOp, PyBaseOperation>(m, "ModuleOp") py::class_<PyModuleOp, PyBaseOperation>(m, "ModuleOp")
.def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false, .def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false,
@ -296,12 +352,26 @@ DiagnosticCapture::consumeDiagnosticsAsString(const char *error_message) {
return sout.str(); return sout.str();
} }
//===----------------------------------------------------------------------===//
// PyBlockRef
//===----------------------------------------------------------------------===//
void PyBlockRef::bind(py::module m) {
py::class_<PyBlockRef>(m, "BlockRef")
.def_property_readonly("operations", [](PyBlockRef &self) {
return PyOperationList(self.block.getOperations());
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// PyRegionRef // PyRegionRef
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void PyRegionRef::bind(py::module m) { void PyRegionRef::bind(py::module m) {
py::class_<PyRegionRef>(m, "RegionRef"); py::class_<PyRegionRef>(m, "RegionRef")
.def_property_readonly("blocks", [](PyRegionRef &self) {
return PyBlockList(self.region.getBlocks());
});
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -22,6 +22,24 @@ namespace mlir {
struct PyContext; struct PyContext;
//===----------------------------------------------------------------------===//
// Utility types
//===----------------------------------------------------------------------===//
template <typename ListTy, typename ItemWrapperTy> class PyIpListWrapper {
public:
using ThisTy = PyIpListWrapper<ListTy, ItemWrapperTy>;
static void bind(py::module m, const char *className);
PyIpListWrapper(ListTy &list) : list(list) {}
private:
ListTy &list;
};
//===----------------------------------------------------------------------===//
// Wrapper types
//===----------------------------------------------------------------------===//
/// Wrapper around an Operation*. /// Wrapper around an Operation*.
struct PyBaseOperation { struct PyBaseOperation {
virtual ~PyBaseOperation(); virtual ~PyBaseOperation();
@ -33,6 +51,7 @@ struct PyBaseOperation {
struct PyModuleOp : PyBaseOperation { struct PyModuleOp : PyBaseOperation {
PyModuleOp(std::shared_ptr<PyContext> context, ModuleOp moduleOp) PyModuleOp(std::shared_ptr<PyContext> context, ModuleOp moduleOp)
: context(context), moduleOp(moduleOp) {} : context(context), moduleOp(moduleOp) {}
~PyModuleOp();
static void bind(py::module m); static void bind(py::module m);
Operation *getOperation() override; Operation *getOperation() override;
std::string toAsm(bool enableDebugInfo, bool prettyForm, std::string toAsm(bool enableDebugInfo, bool prettyForm,
@ -42,6 +61,17 @@ struct PyModuleOp : PyBaseOperation {
ModuleOp moduleOp; ModuleOp moduleOp;
}; };
/// Wrapper around an Operation*.
struct PyOperationRef : PyBaseOperation {
PyOperationRef(Operation *operation) : operation(operation) {}
PyOperationRef(Operation &operation) : operation(&operation) {}
~PyOperationRef();
static void bind(py::module m);
Operation *getOperation() override;
Operation *operation;
};
/// Wrapper around MLIRContext. /// Wrapper around MLIRContext.
struct PyContext : std::enable_shared_from_this<PyContext> { struct PyContext : std::enable_shared_from_this<PyContext> {
static void bind(py::module m); static void bind(py::module m);
@ -49,13 +79,6 @@ struct PyContext : std::enable_shared_from_this<PyContext> {
MLIRContext context; MLIRContext context;
}; };
/// Wrapper around a Region&.
struct PyRegionRef {
PyRegionRef(Region &region) : region(region) {}
static void bind(py::module m);
Region &region;
};
/// Wrapper around a Block&. /// Wrapper around a Block&.
struct PyBlockRef { struct PyBlockRef {
PyBlockRef(Block &block) : block(block) {} PyBlockRef(Block &block) : block(block) {}
@ -63,6 +86,13 @@ struct PyBlockRef {
Block &block; Block &block;
}; };
/// Wrapper around a Region&.
struct PyRegionRef {
PyRegionRef(Region &region) : region(region) {}
static void bind(py::module m);
Region &region;
};
/// Wrapper around an OpBuilder reference. /// Wrapper around an OpBuilder reference.
/// This class is inherently dangerous because it does not track ownership /// This class is inherently dangerous because it does not track ownership
/// of IR objects that it may be operating on and incorrect usage can cause /// of IR objects that it may be operating on and incorrect usage can cause

View File

@ -25,6 +25,11 @@ print("OP NAME:", m.name)
# CHECK: NUM_REGIONS: 1 # CHECK: NUM_REGIONS: 1
print("NUM_REGIONS:", m.num_regions) print("NUM_REGIONS:", m.num_regions)
region = m.region(0) region = m.region(0)
# CHECK: CONTAINED OP: func
# CHECK: CONTAINED OP: module_terminator
for block in region.blocks:
for op in block.operations:
print("CONTAINED OP:", op.name)
# CHECK-LABEL: PARSE_FAILURE # CHECK-LABEL: PARSE_FAILURE
print("PARSE_FAILURE") print("PARSE_FAILURE")