From 23a9ffaabe07c5409ba5b54fea3961eb90989dd6 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 1 May 2020 10:16:19 -0700 Subject: [PATCH] Add wrappers for block and operation iteration. I don't technically need this now but adding while the train of thought is fresh. --- python/npcomp/mlir_ir.cpp | 94 ++++++++++++++++++++++++++++++----- python/npcomp/mlir_ir.h | 44 +++++++++++++--- python/npcomp/mlir_ir_test.py | 5 ++ 3 files changed, 124 insertions(+), 19 deletions(-) diff --git a/python/npcomp/mlir_ir.cpp b/python/npcomp/mlir_ir.cpp index 348de24aa..96f72f56f 100644 --- a/python/npcomp/mlir_ir.cpp +++ b/python/npcomp/mlir_ir.cpp @@ -26,6 +26,45 @@ struct PyContext; static OwningModuleRef parseMLIRModuleFromString(StringRef contents, MLIRContext *context); +//===----------------------------------------------------------------------===// +// Internal only template definitions +//===----------------------------------------------------------------------===// + +template +void PyIpListWrapper::bind(py::module m, + const char *className) { + struct PyItemIterator : public llvm::iterator_adaptor_base< + PyItemIterator, typename ListTy::iterator, + typename std::iterator_traits< + typename ListTy::iterator>::iterator_category, + typename ListTy::value_type> { + PyItemIterator() = default; + PyItemIterator(typename ListTy::iterator &&other) + : PyItemIterator::iterator_adaptor_base(std::move(other)) {} + ItemWrapperTy operator*() const { return ItemWrapperTy(*this->I); } + }; + + py::class_(m, className) + .def("__len__", [](ThisTy &self) { return self.list.size(); }) + .def("__iter__", + [](ThisTy &self) { + PyItemIterator begin(self.list.begin()); + PyItemIterator end(self.list.end()); + return py::make_iterator(begin, end); + }, + py::keep_alive<0, 1>()); +} + +//===----------------------------------------------------------------------===// +// Explicit template instantiations +//===----------------------------------------------------------------------===// + +template class PyIpListWrapper; +using PyBlockList = PyIpListWrapper; + +template class PyIpListWrapper; +using PyOperationList = PyIpListWrapper; + //===----------------------------------------------------------------------===// // Diagnostics //===----------------------------------------------------------------------===// @@ -70,12 +109,17 @@ private: void defineMlirIrModule(py::module m) { m.doc() = "Python bindings for constructs in the mlir/IR library"; - PyContext::bind(m); + PyBlockList::bind(m, "BlockList"); + PyOperationList::bind(m, "OperationList"); + PyBaseOperation::bind(m); - PyModuleOp::bind(m); - PyRegionRef::bind(m); PyBaseOpBuilder::bind(m); + PyBlockRef::bind(m); + PyContext::bind(m); + PyModuleOp::bind(m); + PyOperationRef::bind(m); PyOpBuilder::bind(m); + PyRegionRef::bind(m); } //===----------------------------------------------------------------------===// @@ -122,19 +166,19 @@ void PyBaseOperation::bind(py::module m) { py::class_(m, "BaseOperation") .def_property_readonly( "name", - [](PyBaseOperation *self) { - return std::string(self->getOperation()->getName().getStringRef()); + [](PyBaseOperation &self) { + return std::string(self.getOperation()->getName().getStringRef()); }) .def_property_readonly("is_registered", - [](PyBaseOperation *self) { - return self->getOperation()->isRegistered(); + [](PyBaseOperation &self) { + return self.getOperation()->isRegistered(); }) .def_property_readonly("num_regions", - [](PyBaseOperation *self) { - return self->getOperation()->getNumRegions(); + [](PyBaseOperation &self) { + return self.getOperation()->getNumRegions(); }) - .def("region", [](PyBaseOperation *self, int index) { - auto *op = self->getOperation(); + .def("region", [](PyBaseOperation &self, int index) { + auto *op = self.getOperation(); if (index < 0 || index >= op->getNumRegions()) { throw py::raisePyError(PyExc_IndexError, "Region index out of bounds"); @@ -143,10 +187,22 @@ void PyBaseOperation::bind(py::module m) { }); } +//===----------------------------------------------------------------------===// +// PyOperationRef +//===----------------------------------------------------------------------===// + +PyOperationRef::~PyOperationRef() = default; +void PyOperationRef::bind(py::module m) { + py::class_(m, "OperationRef"); +} + +Operation *PyOperationRef::getOperation() { return operation; } + //===----------------------------------------------------------------------===// // PyModuleOp //===----------------------------------------------------------------------===// +PyModuleOp::~PyModuleOp() = default; void PyModuleOp::bind(py::module m) { py::class_(m, "ModuleOp") .def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false, @@ -296,12 +352,26 @@ DiagnosticCapture::consumeDiagnosticsAsString(const char *error_message) { return sout.str(); } +//===----------------------------------------------------------------------===// +// PyBlockRef +//===----------------------------------------------------------------------===// + +void PyBlockRef::bind(py::module m) { + py::class_(m, "BlockRef") + .def_property_readonly("operations", [](PyBlockRef &self) { + return PyOperationList(self.block.getOperations()); + }); +} + //===----------------------------------------------------------------------===// // PyRegionRef //===----------------------------------------------------------------------===// void PyRegionRef::bind(py::module m) { - py::class_(m, "RegionRef"); + py::class_(m, "RegionRef") + .def_property_readonly("blocks", [](PyRegionRef &self) { + return PyBlockList(self.region.getBlocks()); + }); } //===----------------------------------------------------------------------===// diff --git a/python/npcomp/mlir_ir.h b/python/npcomp/mlir_ir.h index abc2f48fd..478e037bb 100644 --- a/python/npcomp/mlir_ir.h +++ b/python/npcomp/mlir_ir.h @@ -22,6 +22,24 @@ namespace mlir { struct PyContext; +//===----------------------------------------------------------------------===// +// Utility types +//===----------------------------------------------------------------------===// + +template class PyIpListWrapper { +public: + using ThisTy = PyIpListWrapper; + static void bind(py::module m, const char *className); + PyIpListWrapper(ListTy &list) : list(list) {} + +private: + ListTy &list; +}; + +//===----------------------------------------------------------------------===// +// Wrapper types +//===----------------------------------------------------------------------===// + /// Wrapper around an Operation*. struct PyBaseOperation { virtual ~PyBaseOperation(); @@ -33,6 +51,7 @@ struct PyBaseOperation { struct PyModuleOp : PyBaseOperation { PyModuleOp(std::shared_ptr context, ModuleOp moduleOp) : context(context), moduleOp(moduleOp) {} + ~PyModuleOp(); static void bind(py::module m); Operation *getOperation() override; std::string toAsm(bool enableDebugInfo, bool prettyForm, @@ -42,6 +61,17 @@ struct PyModuleOp : PyBaseOperation { ModuleOp moduleOp; }; +/// Wrapper around an Operation*. +struct PyOperationRef : PyBaseOperation { + PyOperationRef(Operation *operation) : operation(operation) {} + PyOperationRef(Operation &operation) : operation(&operation) {} + ~PyOperationRef(); + static void bind(py::module m); + Operation *getOperation() override; + + Operation *operation; +}; + /// Wrapper around MLIRContext. struct PyContext : std::enable_shared_from_this { static void bind(py::module m); @@ -49,13 +79,6 @@ struct PyContext : std::enable_shared_from_this { MLIRContext context; }; -/// Wrapper around a Region&. -struct PyRegionRef { - PyRegionRef(Region ®ion) : region(region) {} - static void bind(py::module m); - Region ®ion; -}; - /// Wrapper around a Block&. struct PyBlockRef { PyBlockRef(Block &block) : block(block) {} @@ -63,6 +86,13 @@ struct PyBlockRef { Block █ }; +/// Wrapper around a Region&. +struct PyRegionRef { + PyRegionRef(Region ®ion) : region(region) {} + static void bind(py::module m); + Region ®ion; +}; + /// Wrapper around an OpBuilder reference. /// This class is inherently dangerous because it does not track ownership /// of IR objects that it may be operating on and incorrect usage can cause diff --git a/python/npcomp/mlir_ir_test.py b/python/npcomp/mlir_ir_test.py index 3337d0986..f29dc8426 100644 --- a/python/npcomp/mlir_ir_test.py +++ b/python/npcomp/mlir_ir_test.py @@ -25,6 +25,11 @@ print("OP NAME:", m.name) # CHECK: NUM_REGIONS: 1 print("NUM_REGIONS:", m.num_regions) region = m.region(0) +# CHECK: CONTAINED OP: func +# CHECK: CONTAINED OP: module_terminator +for block in region.blocks: + for op in block.operations: + print("CONTAINED OP:", op.name) # CHECK-LABEL: PARSE_FAILURE print("PARSE_FAILURE")