//===- mlir_ir.cpp - MLIR IR Bindings -------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir_ir.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Function.h" #include "mlir/IR/Location.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Parser.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" namespace mlir { //===----------------------------------------------------------------------===// // Forward declarations //===----------------------------------------------------------------------===// struct PyContext; /// Parses an MLIR module from a string. /// For maximum efficiency, the `contents` should be zero terminated. static OwningModuleRef parseMLIRModuleFromString(StringRef contents, MLIRContext *context); //===----------------------------------------------------------------------===// // Internal only template definitions // Since it is only legal to use explicit instantiations of templates in // mlir_ir.h, implementations are kept in this module to keep things scoped // well for the compiler. //===----------------------------------------------------------------------===// template void PyIpListWrapper::bind(py::module m, const char *className) { struct PyItemIterator : public llvm::iterator_adaptor_base< PyItemIterator, typename ListTy::iterator, typename std::iterator_traits< typename ListTy::iterator>::iterator_category, typename ListTy::value_type> { PyItemIterator() = default; PyItemIterator(typename ListTy::iterator &&other) : PyItemIterator::iterator_adaptor_base(std::move(other)) {} ItemWrapperTy operator*() const { return ItemWrapperTy(*this->I); } }; py::class_(m, className) .def_property_readonly( "front", [](ThisTy &self) { return ItemWrapperTy(self.list.front()); }) .def("__len__", [](ThisTy &self) { return self.list.size(); }) .def("__iter__", [](ThisTy &self) { PyItemIterator begin(self.list.begin()); PyItemIterator end(self.list.end()); return py::make_iterator(begin, end); }, py::keep_alive<0, 1>()); } //===----------------------------------------------------------------------===// // Explicit template instantiations //===----------------------------------------------------------------------===// template class PyIpListWrapper; using PyBlockList = PyIpListWrapper; template class PyIpListWrapper; using PyOperationList = PyIpListWrapper; //===----------------------------------------------------------------------===// // Diagnostics //===----------------------------------------------------------------------===// /// RAII class to capture diagnostics for later reporting back to the python /// layer. class DiagnosticCapture { public: DiagnosticCapture(mlir::MLIRContext *mlir_context) : mlir_context(mlir_context) { handler_id = mlir_context->getDiagEngine().registerHandler( [&](Diagnostic &d) -> LogicalResult { diagnostics.push_back(std::move(d)); return success(); }); } ~DiagnosticCapture() { if (mlir_context) { mlir_context->getDiagEngine().eraseHandler(handler_id); } } DiagnosticCapture(DiagnosticCapture &&other) { mlir_context = other.mlir_context; diagnostics.swap(other.diagnostics); handler_id = other.handler_id; other.mlir_context = nullptr; } std::vector &getDiagnostics() { return diagnostics; } // Consumes/clears diagnostics. std::string consumeDiagnosticsAsString(const char *error_message); void clearDiagnostics() { diagnostics.clear(); } private: MLIRContext *mlir_context; std::vector diagnostics; mlir::DiagnosticEngine::HandlerID handler_id; }; //===----------------------------------------------------------------------===// // Python only classes //===----------------------------------------------------------------------===// class PyOps { public: PyOps(std::shared_ptr context) : pyOpBuilder(*context), context(std::move(context)) {} static void bind(py::module m) { py::class_(m, "Ops") .def(py::init>()) .def_property_readonly( "builder", [](PyOps &self) -> PyBaseOpBuilder & { return self.pyOpBuilder; }) .def_property_readonly("context", [](PyOps &self) -> std::shared_ptr { return self.context; }) .def("op", [](PyOps &self, const std::string &opNameStr, std::vector pyResultTypes, std::vector pyOperands, llvm::Optional attrs) -> PyOperationRef { OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(false); Location loc = UnknownLoc::get(opBuilder.getContext()); OperationName opName(opNameStr, opBuilder.getContext()); SmallVector types(pyResultTypes.begin(), pyResultTypes.end()); SmallVector operands(pyOperands.begin(), pyOperands.end()); MutableDictionaryAttr attrList; if (attrs) { auto dictAttrs = attrs->attr.dyn_cast(); if (!dictAttrs) { throw py::raiseValueError( "Expected `attrs` to be a DictionaryAttr"); } attrList = MutableDictionaryAttr(dictAttrs); } Operation *op = Operation::create(loc, opName, types, operands, attrList); opBuilder.insert(op); return op; }, py::arg("op_name"), py::arg("result_types"), py::arg("operands"), py::arg("attrs") = llvm::Optional()) .def("func_op", [](PyOps &self, const std::string &name, PyType type, bool createEntryBlock) { auto functionType = type.type.dyn_cast_or_null(); if (!functionType) { throw py::raiseValueError("Illegal function type"); } OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); Location loc = UnknownLoc::get(opBuilder.getContext()); // TODO: Add function and arg/result attributes. FuncOp op = opBuilder.create( loc, StringRef(name), functionType, /*attrs=*/ArrayRef()); if (createEntryBlock) { Block *entryBlock = new Block(); entryBlock->addArguments(functionType.getInputs()); op.getBody().push_back(entryBlock); opBuilder.setInsertionPointToStart(entryBlock); } return PyOperationRef(op); }, py::arg("name"), py::arg("type"), py::arg("create_entry_block") = false, R"(Creates a new `func` op, optionally creating an entry block. If an entry block is created, the builder will be positioned to its start.)") .def("return_op", [](PyOps &self, std::vector pyOperands) { OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); Location loc = UnknownLoc::get(opBuilder.getContext()); SmallVector operands(pyOperands.begin(), pyOperands.end()); return PyOperationRef(opBuilder.create(loc, operands)); }); } PyOpBuilder pyOpBuilder; std::shared_ptr context; }; class PyTypes { public: PyTypes(std::shared_ptr context) : context(std::move(context)) {} static void bind(py::module m) { py::class_(m, "Types") .def(py::init>()) .def_property_readonly("context", [](PyTypes &self) { return self.context; }) .def("integer", [](PyTypes &self, unsigned width) { return PyType(IntegerType::get(width, &self.context->context)); }, py::arg("width") = 32) .def_property_readonly( "i1", [](PyTypes &self) { return PyType(IntegerType::get(1, &self.context->context)); }) .def_property_readonly( "i16", [](PyTypes &self) { return PyType(IntegerType::get(32, &self.context->context)); }) .def_property_readonly( "i32", [](PyTypes &self) { return PyType(IntegerType::get(32, &self.context->context)); }) .def_property_readonly( "i64", [](PyTypes &self) { return PyType(IntegerType::get(64, &self.context->context)); }) .def_property_readonly( "f32", [](PyTypes &self) { return PyType( FloatType::get(StandardTypes::F32, &self.context->context)); }) .def_property_readonly( "f64", [](PyTypes &self) { return PyType( FloatType::get(StandardTypes::F64, &self.context->context)); }) .def("tensor", [](PyTypes &self, PyType elementType, llvm::Optional> shape) { if (!elementType.type) { throw py::raiseValueError("Null element type"); } if (shape) { return PyType(RankedTensorType::get(*shape, elementType.type)); } else { return PyType(UnrankedTensorType::get(elementType.type)); } }, py::arg("element_type"), py::arg("shape") = llvm::Optional>()) .def("function", [](PyTypes &self, std::vector inputs, std::vector results) { llvm::SmallVector inputTypes; llvm::SmallVector resultTypes; for (auto input : inputs) { inputTypes.push_back(input.type); } for (auto result : results) { resultTypes.push_back(result.type); } return PyType(FunctionType::get(inputTypes, resultTypes, &self.context->context)); }); } private: std::shared_ptr context; }; //===----------------------------------------------------------------------===// // Module initialization //===----------------------------------------------------------------------===// void defineMlirIrModule(py::module m) { m.doc() = "Python bindings for constructs in the mlir/IR library"; // Python only types. PyOps::bind(m); PyTypes::bind(m); // Utility types. PyBlockList::bind(m, "BlockList"); PyOperationList::bind(m, "OperationList"); // Wrapper types. PyAttribute::bind(m); PyBaseOperation::bind(m); PyBaseOpBuilder::bind(m); PyBlockRef::bind(m); PyContext::bind(m); PyModuleOp::bind(m); PyOperationRef::bind(m); PyOpBuilder::bind(m); PyRegionRef::bind(m); PySymbolTable::bind(m); PyType::bind(m); PyValue::bind(m); } //===----------------------------------------------------------------------===// // PyContext //===----------------------------------------------------------------------===// void PyContext::bind(py::module m) { py::class_>(m, "MLIRContext") .def(py::init<>([]() { // Need explicit make_shared to avoid UB with enable_shared_from_this. return std::make_shared(); })) .def("new_module", [&](PyContext &self) -> PyModuleOp { Location loc = UnknownLoc::get(&self.context); auto m = ModuleOp::create(loc); return PyModuleOp(self.shared_from_this(), m); }) .def("parse_asm", &PyContext::parseAsm) .def("new_builder", [](PyContext &self) { // Note: we collapse the Builder and OpBuilder into one because // there is little reason to expose the inheritance hierarchy to // Python. return PyOpBuilder(self); }) // Salient functions from Builder. .def("parse_type", [](PyContext &self, const std::string &asmText) { Type t = parseType(asmText, &self.context); if (!t) { std::string message = "Unable to parse MLIR type: "; message.append(asmText); throw py::raiseValueError(message); } return PyType(t); }) .def("string_attr", [](PyContext &self, const std::string &s) -> PyAttribute { return StringAttr::get(s, &self.context); }) .def("bytes_attr", [](PyContext &self, py::bytes bytes) -> PyAttribute { char *buffer; ssize_t length; if (PYBIND11_BYTES_AS_STRING_AND_SIZE(bytes.ptr(), &buffer, &length)) { throw py::raiseValueError("Cannot extract bytes"); } return StringAttr::get(StringRef(buffer, length), &self.context); }) .def("flat_symbol_ref_attr", [](PyContext &self, const std::string &s) -> PyAttribute { return FlatSymbolRefAttr::get(s, &self.context); }) .def("dictionary_attr", [](PyContext &self, py::dict d) -> PyAttribute { SmallVector attrs; for (auto &it : d) { auto key = it.first.cast(); auto value = it.second.cast(); auto keyIdent = Identifier::get(key, &self.context); attrs.emplace_back(keyIdent, value.attr); } return DictionaryAttr::get(attrs, &self.context); }); } PyModuleOp PyContext::parseAsm(const std::string &asm_text) { // Arrange to get a view that includes a terminating null to avoid // additional copy. // TODO: Consider using the buffer protocol to access and avoid more copies. const char *asm_chars = asm_text.c_str(); StringRef asm_sr(asm_chars, asm_text.size() + 1); // TODO: Output non failure diagnostics (somewhere) DiagnosticCapture diag_capture(&context); auto module_ref = parseMLIRModuleFromString(asm_sr, &context); if (!module_ref) { throw py::raiseValueError( diag_capture.consumeDiagnosticsAsString("Error parsing ASM")); } return PyModuleOp{shared_from_this(), module_ref.release()}; } //===----------------------------------------------------------------------===// // PyBaseOperation //===----------------------------------------------------------------------===// PyBaseOperation::~PyBaseOperation() = default; void PyBaseOperation::bind(py::module m) { py::class_(m, "BaseOperation") .def_property_readonly( "name", [](PyBaseOperation &self) { return std::string(self.getOperation()->getName().getStringRef()); }) .def_property_readonly("is_registered", [](PyBaseOperation &self) { return self.getOperation()->isRegistered(); }) .def_property_readonly("num_regions", [](PyBaseOperation &self) { return self.getOperation()->getNumRegions(); }) .def_property_readonly("results", [](PyBaseOperation &self) { auto *op = self.getOperation(); std::vector results(op->result_begin(), op->result_end()); return results; }) .def("region", [](PyBaseOperation &self, int index) { auto *op = self.getOperation(); if (index < 0 || index >= op->getNumRegions()) { throw py::raisePyError(PyExc_IndexError, "Region index out of bounds"); } return PyRegionRef(op->getRegion(index)); }) .def_property_readonly("first_block", [](PyBaseOperation &self) { Operation *op = self.getOperation(); assert(op); if (op->getNumRegions() == 0) { throw py::raiseValueError("Op has no regions"); } auto ®ion = op->getRegion(0); if (region.empty()) { throw py::raiseValueError("Op has no blocks"); } return PyBlockRef(region.front()); }); } //===----------------------------------------------------------------------===// // PyOperationRef //===----------------------------------------------------------------------===// PyOperationRef::~PyOperationRef() = default; void PyOperationRef::bind(py::module m) { py::class_(m, "OperationRef"); } Operation *PyOperationRef::getOperation() { return operation; } //===----------------------------------------------------------------------===// // PyModuleOp //===----------------------------------------------------------------------===// PyModuleOp::~PyModuleOp() = default; void PyModuleOp::bind(py::module m) { py::class_(m, "ModuleOp") .def_property_readonly("context", [](PyModuleOp &self) { return self.context; }) .def("to_asm", &PyModuleOp::toAsm, py::arg("debug_info") = false, py::arg("pretty") = false, py::arg("large_element_limit") = -1); } Operation *PyModuleOp::getOperation() { return moduleOp; } std::string PyModuleOp::toAsm(bool enableDebugInfo, bool prettyForm, int64_t largeElementLimit) { // Print to asm. std::string asmOutput; llvm::raw_string_ostream sout(asmOutput); OpPrintingFlags printFlags; if (enableDebugInfo) { printFlags.enableDebugInfo(prettyForm); } if (largeElementLimit >= 0) { printFlags.elideLargeElementsAttrs(largeElementLimit); } moduleOp.print(sout, printFlags); return sout.str(); } static OwningModuleRef parseMLIRModuleFromString(StringRef contents, MLIRContext *context) { std::unique_ptr contents_buffer; if (contents.back() == 0) { // If it has a nul terminator, just use as-is. contents_buffer = llvm::MemoryBuffer::getMemBuffer(contents.drop_back()); } else { // Otherwise, make a copy. contents_buffer = llvm::MemoryBuffer::getMemBufferCopy(contents, "EMBED"); } llvm::SourceMgr source_mgr; source_mgr.AddNewSourceBuffer(std::move(contents_buffer), llvm::SMLoc()); OwningModuleRef mlir_module = parseSourceFile(source_mgr, context); return mlir_module; } // Custom location printer that prints prettier, multi-line file output // suitable for human readable error messages. The standard printer just prints // a long nested expression not particularly human friendly). Note that there // is a location pretty printer in the MLIR AsmPrinter. It is private and // doesn't do any path shortening, which seems to make long Python stack traces // a bit easier to scan. // TODO: Upstream this. void printLocation(Location loc, raw_ostream &out) { switch (loc->getKind()) { case StandardAttributes::OpaqueLocation: printLocation(loc.cast().getFallbackLocation(), out); break; case StandardAttributes::UnknownLocation: out << " [unknown location]\n"; break; case StandardAttributes::FileLineColLocation: { auto line_col_loc = loc.cast(); StringRef this_filename = line_col_loc.getFilename(); auto slash_pos = this_filename.find_last_of("/\\"); // We print both the basename and extended names with a structure like // `foo.py:35:4`. Even though technically the line/col // information is redundant to include in both names, having it on both // makes it easier to paste the paths into an editor and jump to the exact // location. std::string line_col_suffix = ":" + std::to_string(line_col_loc.getLine()) + ":" + std::to_string(line_col_loc.getColumn()); bool has_basename = false; StringRef basename = this_filename; if (slash_pos != StringRef::npos) { has_basename = true; basename = this_filename.substr(slash_pos + 1); } out << " at: " << basename << line_col_suffix; if (has_basename) { StringRef extended_name = this_filename; // Print out two tabs, as basenames usually vary in length by more than // one tab width. out << "\t\t( " << extended_name << line_col_suffix << " )"; } out << "\n"; break; } case StandardAttributes::NameLocation: { auto nameLoc = loc.cast(); out << " @'" << nameLoc.getName() << "':\n"; auto childLoc = nameLoc.getChildLoc(); if (!childLoc.isa()) { out << "(...\n"; printLocation(childLoc, out); out << ")\n"; } break; } case StandardAttributes::CallSiteLocation: { auto call_site = loc.cast(); printLocation(call_site.getCaller(), out); printLocation(call_site.getCallee(), out); break; } } } //===----------------------------------------------------------------------===// // PySymbolTable //===----------------------------------------------------------------------===// void PySymbolTable::bind(py::module m) { py::class_(m, "SymbolTable") .def_property_readonly_static("symbol_attr_name", [](const py::object &) { auto sr = SymbolTable::getSymbolAttrName(); return py::str(sr.data(), sr.size()); }) .def_property_readonly_static( "visibility_attr_name", [](const py::object &) { auto sr = SymbolTable::getVisibilityAttrName(); return py::str(sr.data(), sr.size()); }); } //===----------------------------------------------------------------------===// // DiagnosticCapture //===----------------------------------------------------------------------===// std::string DiagnosticCapture::consumeDiagnosticsAsString(const char *error_message) { std::string s; llvm::raw_string_ostream sout(s); bool first = true; if (error_message) { sout << error_message; first = false; } for (auto &d : diagnostics) { if (!first) { sout << "\n\n"; } else { first = false; } switch (d.getSeverity()) { case DiagnosticSeverity::Note: sout << "[NOTE]"; break; case DiagnosticSeverity::Warning: sout << "[WARNING]"; break; case DiagnosticSeverity::Error: sout << "[ERROR]"; break; case DiagnosticSeverity::Remark: sout << "[REMARK]"; break; default: sout << "[UNKNOWN]"; } // Message. sout << ": " << d << "\n"; printLocation(d.getLocation(), sout); } diagnostics.clear(); return sout.str(); } //===----------------------------------------------------------------------===// // PyBlockRef //===----------------------------------------------------------------------===// void PyBlockRef::bind(py::module m) { py::class_(m, "BlockRef") .def_property_readonly("operations", [](PyBlockRef &self) { return PyOperationList( self.block.getOperations()); }) .def_property_readonly("args", [](PyBlockRef &self) { return std::vector(self.block.args_begin(), self.block.args_end()); }); } //===----------------------------------------------------------------------===// // PyRegionRef //===----------------------------------------------------------------------===// void PyRegionRef::bind(py::module m) { py::class_(m, "RegionRef") .def_property_readonly("blocks", [](PyRegionRef &self) { return PyBlockList(self.region.getBlocks()); }); } //===----------------------------------------------------------------------===// // PyType //===----------------------------------------------------------------------===// void PyType::bind(py::module m) { py::class_(m, "Type").def("__repr__", [](PyType &self) -> std::string { if (!self.type) return ""; std::string res; llvm::raw_string_ostream os(res); self.type.print(os); return res; }); } //===----------------------------------------------------------------------===// // PyValue //===----------------------------------------------------------------------===// void PyValue::bind(py::module m) { py::class_(m, "Value").def("__repr__", [](PyValue &self) { std::string res; llvm::raw_string_ostream os(res); os << self.value; return res; }); } //===----------------------------------------------------------------------===// // PyAttribute //===----------------------------------------------------------------------===// void PyAttribute::bind(py::module m) { py::class_(m, "Attribute") .def("__repr__", [](PyAttribute &self) { std::string res; llvm::raw_string_ostream os(res); os << self.attr; return res; }); } //===----------------------------------------------------------------------===// // OpBuilder implementations //===----------------------------------------------------------------------===// PyBaseOpBuilder::~PyBaseOpBuilder() = default; PyOpBuilder::~PyOpBuilder() = default; OpBuilder &PyOpBuilder::getBuilder(bool requirePosition) { if (!builder.getBlock()) { throw py::raisePyError(PyExc_IndexError, "Insertion point not set"); } return builder; } void PyBaseOpBuilder::bind(py::module m) { py::class_(m, "BaseOpBuilder"); } void PyOpBuilder::bind(py::module m) { py::class_(m, "OpBuilder") .def(py::init()) .def("clear_insertion_point", [](PyOpBuilder &self) { self.builder.clearInsertionPoint(); }) .def("insert_op_before", [](PyOpBuilder &self, PyBaseOperation &pyOp) { Operation *op = pyOp.getOperation(); self.builder.setInsertionPoint(op); }, "Sets the insertion point to just before the specified op.") .def("insert_op_after", [](PyOpBuilder &self, PyBaseOperation &pyOp) { Operation *op = pyOp.getOperation(); self.builder.setInsertionPointAfter(op); }, "Sets the insertion point to just after the specified op.") .def("insert_block_start", [](PyOpBuilder &self, PyBlockRef block) { self.builder.setInsertionPointToStart(&block.block); }, "Sets the insertion point to the start of the block.") .def("insert_block_end", [](PyOpBuilder &self, PyBlockRef block) { self.builder.setInsertionPointToEnd(&block.block); }, "Sets the insertion point to the end of the block.") .def("insert_before_terminator", [](PyOpBuilder &self, PyBlockRef block) { auto *terminator = block.block.getTerminator(); if (!terminator) { throw py::raiseValueError("Block has no terminator"); } self.builder.setInsertionPoint(terminator); }, "Sets the insertion point to just before the block terminator."); } } // namespace mlir