//===- NpcompDialect.cpp - Custom dialect classes -------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "MlirIr.h" #include "NpcompModule.h" #include "npcomp/Dialect/Basicpy/BasicpyDialect.h" #include "npcomp/Dialect/Basicpy/BasicpyOps.h" namespace mlir { namespace NPCOMP { class BasicpyDialectHelper : public PyDialectHelper { public: using PyDialectHelper::PyDialectHelper; static void bind(py::module m) { py::class_(m, "BasicpyDialectHelper") .def(py::init>()) .def_property_readonly("basicpy_NoneType", [](BasicpyDialectHelper &self) -> PyType { return Basicpy::NoneType::get( &self.context->context); }) .def_property_readonly("basicpy_EllipsisType", [](BasicpyDialectHelper &self) -> PyType { return Basicpy::EllipsisType::get( &self.context->context); }) .def("basicpy_SlotObject_type", [](BasicpyDialectHelper &self, std::string className, py::args pySlotTypes) -> PyType { SmallVector slotTypes; for (auto pySlotType : pySlotTypes) { slotTypes.push_back(pySlotType.cast()); } auto classNameAttr = StringAttr::get(className, &self.context->context); return Basicpy::SlotObjectType::get(classNameAttr, slotTypes); }, py::arg("className")) .def("basicpy_slot_object_get_op", [](BasicpyDialectHelper &self, PyValue slotObject, unsigned index) -> PyOperationRef { auto slotObjectType = slotObject.value.getType() .dyn_cast(); if (!slotObjectType) { throw py::raiseValueError("Operand must be a SlotObject"); } if (index >= slotObjectType.getSlotCount()) { throw py::raiseValueError("Out of range slot index"); } auto resultType = slotObjectType.getSlotTypes()[index]; auto indexAttr = IntegerAttr::get( IndexType::get(&self.context->context), index); OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); Location loc = UnknownLoc::get(opBuilder.getContext()); auto op = opBuilder.create( loc, resultType, slotObject, indexAttr); return op.getOperation(); }); } }; } // namespace NPCOMP } // namespace mlir using namespace ::mlir::NPCOMP; void mlir::npcomp::python::defineNpcompDialect(py::module m) { BasicpyDialectHelper::bind(m); }