mirror of https://github.com/llvm/torch-mlir
Add wrappers for block and operation iteration.
I don't technically need this now but adding while the train of thought is fresh.pull/1/head
parent
c8740fd866
commit
23a9ffaabe
|
@ -26,6 +26,45 @@ struct PyContext;
|
||||||
static OwningModuleRef parseMLIRModuleFromString(StringRef contents,
|
static OwningModuleRef parseMLIRModuleFromString(StringRef contents,
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Internal only template definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
template <typename ListTy, typename ItemWrapperTy>
|
||||||
|
void PyIpListWrapper<ListTy, ItemWrapperTy>::bind(py::module m,
|
||||||
|
const char *className) {
|
||||||
|
struct PyItemIterator : public llvm::iterator_adaptor_base<
|
||||||
|
PyItemIterator, typename ListTy::iterator,
|
||||||
|
typename std::iterator_traits<
|
||||||
|
typename ListTy::iterator>::iterator_category,
|
||||||
|
typename ListTy::value_type> {
|
||||||
|
PyItemIterator() = default;
|
||||||
|
PyItemIterator(typename ListTy::iterator &&other)
|
||||||
|
: PyItemIterator::iterator_adaptor_base(std::move(other)) {}
|
||||||
|
ItemWrapperTy operator*() const { return ItemWrapperTy(*this->I); }
|
||||||
|
};
|
||||||
|
|
||||||
|
py::class_<ThisTy>(m, className)
|
||||||
|
.def("__len__", [](ThisTy &self) { return self.list.size(); })
|
||||||
|
.def("__iter__",
|
||||||
|
[](ThisTy &self) {
|
||||||
|
PyItemIterator begin(self.list.begin());
|
||||||
|
PyItemIterator end(self.list.end());
|
||||||
|
return py::make_iterator(begin, end);
|
||||||
|
},
|
||||||
|
py::keep_alive<0, 1>());
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Explicit template instantiations
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
template class PyIpListWrapper<Region::BlockListType, PyBlockRef>;
|
||||||
|
using PyBlockList = PyIpListWrapper<Region::BlockListType, PyBlockRef>;
|
||||||
|
|
||||||
|
template class PyIpListWrapper<Block::OpListType, PyOperationRef>;
|
||||||
|
using PyOperationList = PyIpListWrapper<Block::OpListType, PyOperationRef>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Diagnostics
|
// Diagnostics
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -70,12 +109,17 @@ private:
|
||||||
void defineMlirIrModule(py::module m) {
|
void defineMlirIrModule(py::module m) {
|
||||||
m.doc() = "Python bindings for constructs in the mlir/IR library";
|
m.doc() = "Python bindings for constructs in the mlir/IR library";
|
||||||
|
|
||||||
PyContext::bind(m);
|
PyBlockList::bind(m, "BlockList");
|
||||||
|
PyOperationList::bind(m, "OperationList");
|
||||||
|
|
||||||
PyBaseOperation::bind(m);
|
PyBaseOperation::bind(m);
|
||||||
PyModuleOp::bind(m);
|
|
||||||
PyRegionRef::bind(m);
|
|
||||||
PyBaseOpBuilder::bind(m);
|
PyBaseOpBuilder::bind(m);
|
||||||
|
PyBlockRef::bind(m);
|
||||||
|
PyContext::bind(m);
|
||||||
|
PyModuleOp::bind(m);
|
||||||
|
PyOperationRef::bind(m);
|
||||||
PyOpBuilder::bind(m);
|
PyOpBuilder::bind(m);
|
||||||
|
PyRegionRef::bind(m);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -122,19 +166,19 @@ void PyBaseOperation::bind(py::module m) {
|
||||||
py::class_<PyBaseOperation>(m, "BaseOperation")
|
py::class_<PyBaseOperation>(m, "BaseOperation")
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"name",
|
"name",
|
||||||
[](PyBaseOperation *self) {
|
[](PyBaseOperation &self) {
|
||||||
return std::string(self->getOperation()->getName().getStringRef());
|
return std::string(self.getOperation()->getName().getStringRef());
|
||||||
})
|
})
|
||||||
.def_property_readonly("is_registered",
|
.def_property_readonly("is_registered",
|
||||||
[](PyBaseOperation *self) {
|
[](PyBaseOperation &self) {
|
||||||
return self->getOperation()->isRegistered();
|
return self.getOperation()->isRegistered();
|
||||||
})
|
})
|
||||||
.def_property_readonly("num_regions",
|
.def_property_readonly("num_regions",
|
||||||
[](PyBaseOperation *self) {
|
[](PyBaseOperation &self) {
|
||||||
return self->getOperation()->getNumRegions();
|
return self.getOperation()->getNumRegions();
|
||||||
})
|
})
|
||||||
.def("region", [](PyBaseOperation *self, int index) {
|
.def("region", [](PyBaseOperation &self, int index) {
|
||||||
auto *op = self->getOperation();
|
auto *op = self.getOperation();
|
||||||
if (index < 0 || index >= op->getNumRegions()) {
|
if (index < 0 || index >= op->getNumRegions()) {
|
||||||
throw py::raisePyError(PyExc_IndexError,
|
throw py::raisePyError(PyExc_IndexError,
|
||||||
"Region index out of bounds");
|
"Region index out of bounds");
|
||||||
|
@ -143,10 +187,22 @@ void PyBaseOperation::bind(py::module m) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PyOperationRef
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
PyOperationRef::~PyOperationRef() = default;
|
||||||
|
void PyOperationRef::bind(py::module m) {
|
||||||
|
py::class_<PyOperationRef, PyBaseOperation>(m, "OperationRef");
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation *PyOperationRef::getOperation() { return operation; }
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// PyModuleOp
|
// PyModuleOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
PyModuleOp::~PyModuleOp() = default;
|
||||||
void PyModuleOp::bind(py::module m) {
|
void PyModuleOp::bind(py::module m) {
|
||||||
py::class_<PyModuleOp, PyBaseOperation>(m, "ModuleOp")
|
py::class_<PyModuleOp, PyBaseOperation>(m, "ModuleOp")
|
||||||
.def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false,
|
.def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false,
|
||||||
|
@ -296,12 +352,26 @@ DiagnosticCapture::consumeDiagnosticsAsString(const char *error_message) {
|
||||||
return sout.str();
|
return sout.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PyBlockRef
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void PyBlockRef::bind(py::module m) {
|
||||||
|
py::class_<PyBlockRef>(m, "BlockRef")
|
||||||
|
.def_property_readonly("operations", [](PyBlockRef &self) {
|
||||||
|
return PyOperationList(self.block.getOperations());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// PyRegionRef
|
// PyRegionRef
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void PyRegionRef::bind(py::module m) {
|
void PyRegionRef::bind(py::module m) {
|
||||||
py::class_<PyRegionRef>(m, "RegionRef");
|
py::class_<PyRegionRef>(m, "RegionRef")
|
||||||
|
.def_property_readonly("blocks", [](PyRegionRef &self) {
|
||||||
|
return PyBlockList(self.region.getBlocks());
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -22,6 +22,24 @@ namespace mlir {
|
||||||
|
|
||||||
struct PyContext;
|
struct PyContext;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Utility types
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
template <typename ListTy, typename ItemWrapperTy> class PyIpListWrapper {
|
||||||
|
public:
|
||||||
|
using ThisTy = PyIpListWrapper<ListTy, ItemWrapperTy>;
|
||||||
|
static void bind(py::module m, const char *className);
|
||||||
|
PyIpListWrapper(ListTy &list) : list(list) {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
ListTy &list;
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Wrapper types
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Wrapper around an Operation*.
|
/// Wrapper around an Operation*.
|
||||||
struct PyBaseOperation {
|
struct PyBaseOperation {
|
||||||
virtual ~PyBaseOperation();
|
virtual ~PyBaseOperation();
|
||||||
|
@ -33,6 +51,7 @@ struct PyBaseOperation {
|
||||||
struct PyModuleOp : PyBaseOperation {
|
struct PyModuleOp : PyBaseOperation {
|
||||||
PyModuleOp(std::shared_ptr<PyContext> context, ModuleOp moduleOp)
|
PyModuleOp(std::shared_ptr<PyContext> context, ModuleOp moduleOp)
|
||||||
: context(context), moduleOp(moduleOp) {}
|
: context(context), moduleOp(moduleOp) {}
|
||||||
|
~PyModuleOp();
|
||||||
static void bind(py::module m);
|
static void bind(py::module m);
|
||||||
Operation *getOperation() override;
|
Operation *getOperation() override;
|
||||||
std::string toAsm(bool enableDebugInfo, bool prettyForm,
|
std::string toAsm(bool enableDebugInfo, bool prettyForm,
|
||||||
|
@ -42,6 +61,17 @@ struct PyModuleOp : PyBaseOperation {
|
||||||
ModuleOp moduleOp;
|
ModuleOp moduleOp;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Wrapper around an Operation*.
|
||||||
|
struct PyOperationRef : PyBaseOperation {
|
||||||
|
PyOperationRef(Operation *operation) : operation(operation) {}
|
||||||
|
PyOperationRef(Operation &operation) : operation(&operation) {}
|
||||||
|
~PyOperationRef();
|
||||||
|
static void bind(py::module m);
|
||||||
|
Operation *getOperation() override;
|
||||||
|
|
||||||
|
Operation *operation;
|
||||||
|
};
|
||||||
|
|
||||||
/// Wrapper around MLIRContext.
|
/// Wrapper around MLIRContext.
|
||||||
struct PyContext : std::enable_shared_from_this<PyContext> {
|
struct PyContext : std::enable_shared_from_this<PyContext> {
|
||||||
static void bind(py::module m);
|
static void bind(py::module m);
|
||||||
|
@ -49,13 +79,6 @@ struct PyContext : std::enable_shared_from_this<PyContext> {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Wrapper around a Region&.
|
|
||||||
struct PyRegionRef {
|
|
||||||
PyRegionRef(Region ®ion) : region(region) {}
|
|
||||||
static void bind(py::module m);
|
|
||||||
Region ®ion;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Wrapper around a Block&.
|
/// Wrapper around a Block&.
|
||||||
struct PyBlockRef {
|
struct PyBlockRef {
|
||||||
PyBlockRef(Block &block) : block(block) {}
|
PyBlockRef(Block &block) : block(block) {}
|
||||||
|
@ -63,6 +86,13 @@ struct PyBlockRef {
|
||||||
Block █
|
Block █
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Wrapper around a Region&.
|
||||||
|
struct PyRegionRef {
|
||||||
|
PyRegionRef(Region ®ion) : region(region) {}
|
||||||
|
static void bind(py::module m);
|
||||||
|
Region ®ion;
|
||||||
|
};
|
||||||
|
|
||||||
/// Wrapper around an OpBuilder reference.
|
/// Wrapper around an OpBuilder reference.
|
||||||
/// This class is inherently dangerous because it does not track ownership
|
/// This class is inherently dangerous because it does not track ownership
|
||||||
/// of IR objects that it may be operating on and incorrect usage can cause
|
/// of IR objects that it may be operating on and incorrect usage can cause
|
||||||
|
|
|
@ -25,6 +25,11 @@ print("OP NAME:", m.name)
|
||||||
# CHECK: NUM_REGIONS: 1
|
# CHECK: NUM_REGIONS: 1
|
||||||
print("NUM_REGIONS:", m.num_regions)
|
print("NUM_REGIONS:", m.num_regions)
|
||||||
region = m.region(0)
|
region = m.region(0)
|
||||||
|
# CHECK: CONTAINED OP: func
|
||||||
|
# CHECK: CONTAINED OP: module_terminator
|
||||||
|
for block in region.blocks:
|
||||||
|
for op in block.operations:
|
||||||
|
print("CONTAINED OP:", op.name)
|
||||||
|
|
||||||
# CHECK-LABEL: PARSE_FAILURE
|
# CHECK-LABEL: PARSE_FAILURE
|
||||||
print("PARSE_FAILURE")
|
print("PARSE_FAILURE")
|
||||||
|
|
Loading…
Reference in New Issue