//===- NpcompDialect.cpp - Custom dialect classes -------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "npcomp/Python/MlirIr.h" #include "npcomp/Python/NpcompModule.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h" #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" #include "npcomp/Dialect/Numpy/IR/NumpyOps.h" namespace mlir { namespace NPCOMP { class BasicpyDialectHelper : public PyDialectHelper { public: using PyDialectHelper::PyDialectHelper; static void bind(py::module m) { py::class_(m, "BasicpyDialectHelper") .def(py::init(), py::keep_alive<1, 2>(), py::keep_alive<1, 3>()) // --------------------------------------------------------------------- // Basicpy dialect // --------------------------------------------------------------------- .def_property_readonly("basicpy_BoolType", [](BasicpyDialectHelper &self) -> PyType { return Basicpy::BoolType::get( self.getContext()); }) .def_property_readonly("basicpy_BytesType", [](BasicpyDialectHelper &self) -> PyType { return Basicpy::BytesType::get( self.getContext()); }) .def_property_readonly("basicpy_EllipsisType", [](BasicpyDialectHelper &self) -> PyType { return Basicpy::EllipsisType::get( self.getContext()); }) .def_property_readonly("basicpy_NoneType", [](BasicpyDialectHelper &self) -> PyType { return Basicpy::NoneType::get( self.getContext()); }) .def( "basicpy_SlotObject_type", [](BasicpyDialectHelper &self, std::string className, py::args pySlotTypes) -> PyType { SmallVector slotTypes; for (auto pySlotType : pySlotTypes) { slotTypes.push_back(pySlotType.cast()); } auto classNameAttr = StringAttr::get(className, self.getContext()); return Basicpy::SlotObjectType::get(classNameAttr, slotTypes); }, py::arg("className")) .def_property_readonly("basicpy_StrType", [](BasicpyDialectHelper &self) -> PyType { return Basicpy::StrType::get( self.getContext()); }) .def_property_readonly("basicpy_UnknownType", [](BasicpyDialectHelper &self) -> PyType { return Basicpy::UnknownType::get( self.getContext()); }) .def("basicpy_exec_op", [](BasicpyDialectHelper &self) { OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); Location loc = self.pyOpBuilder.getCurrentLoc(); auto op = opBuilder.create(loc); return py::make_tuple(PyOperationRef(op), op.getBodyBuilder().saveInsertionPoint()); }) .def("basicpy_exec_discard_op", [](BasicpyDialectHelper &self, std::vector pyOperands) { OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); Location loc = self.pyOpBuilder.getCurrentLoc(); llvm::SmallVector operands(pyOperands.begin(), pyOperands.end()); auto op = opBuilder.create(loc, operands); return PyOperationRef(op); }) .def("basicpy_slot_object_get_op", [](BasicpyDialectHelper &self, PyValue slotObject, unsigned index) -> PyOperationRef { auto slotObjectType = slotObject.value.getType() .dyn_cast(); if (!slotObjectType) { throw py::raiseValueError("Operand must be a SlotObject"); } if (index >= slotObjectType.getSlotCount()) { throw py::raiseValueError("Out of range slot index"); } auto resultType = slotObjectType.getSlotTypes()[index]; auto indexAttr = IntegerAttr::get(IndexType::get(self.getContext()), index); OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); Location loc = self.pyOpBuilder.getCurrentLoc(); auto op = opBuilder.create( loc, resultType, slotObject, indexAttr); return op.getOperation(); }) // --------------------------------------------------------------------- // Numpy dialect // --------------------------------------------------------------------- .def("numpy_copy_to_tensor_op", [](BasicpyDialectHelper &self, PyValue source) -> PyOperationRef { auto sourceType = source.value.getType().dyn_cast(); if (!sourceType) { source.value.dump(); throw py::raiseValueError("expected ndarray type for " "numpy_copy_to_tensor_op"); } auto dtype = sourceType.getDtype(); auto optionalShape = sourceType.getOptionalShape(); TensorType tensorType; if (optionalShape) { tensorType = RankedTensorType::get(*optionalShape, dtype); } else { tensorType = UnrankedTensorType::get(dtype); } OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); Location loc = self.pyOpBuilder.getCurrentLoc(); auto op = opBuilder.create( loc, tensorType, source.value); return op.getOperation(); }) .def("numpy_create_array_from_tensor_op", [](BasicpyDialectHelper &self, PyValue source) -> PyOperationRef { auto sourceType = source.value.getType().dyn_cast(); if (!sourceType) { throw py::raiseValueError("expected tensor type for " "numpy_create_array_from_tensor_op"); } auto dtype = sourceType.getElementType(); llvm::Optional> optionalShape; if (auto rankedTensorType = sourceType.dyn_cast()) { optionalShape = rankedTensorType.getShape(); } auto ndarrayType = Numpy::NdArrayType::get(dtype, optionalShape); OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true); Location loc = self.pyOpBuilder.getCurrentLoc(); auto op = opBuilder.create( loc, ndarrayType, source.value); return op.getOperation(); }) .def("numpy_NdArrayType", [](BasicpyDialectHelper &self, PyType dtype) -> PyType { return Numpy::NdArrayType::get(dtype.type); }); } }; } // namespace NPCOMP } // namespace mlir using namespace ::mlir::NPCOMP; void mlir::npcomp::python::defineNpcompDialect(py::module m) { BasicpyDialectHelper::bind(m); }