mirror of https://github.com/llvm/torch-mlir
77 lines
3.2 KiB
C++
77 lines
3.2 KiB
C++
|
//===- NpcompDialect.cpp - Custom dialect classes -------------------------===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "MlirIr.h"
|
||
|
#include "NpcompModule.h"
|
||
|
|
||
|
#include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
|
||
|
#include "npcomp/Dialect/Basicpy/BasicpyOps.h"
|
||
|
|
||
|
namespace mlir {
|
||
|
namespace NPCOMP {
|
||
|
|
||
|
class BasicpyDialectHelper : public PyDialectHelper {
|
||
|
public:
|
||
|
using PyDialectHelper::PyDialectHelper;
|
||
|
static void bind(py::module m) {
|
||
|
py::class_<BasicpyDialectHelper, PyDialectHelper>(m, "BasicpyDialectHelper")
|
||
|
.def(py::init<std::shared_ptr<PyContext>>())
|
||
|
.def_property_readonly("basicpy_NoneType",
|
||
|
[](BasicpyDialectHelper &self) -> PyType {
|
||
|
return Basicpy::NoneType::get(
|
||
|
&self.context->context);
|
||
|
})
|
||
|
.def_property_readonly("basicpy_EllipsisType",
|
||
|
[](BasicpyDialectHelper &self) -> PyType {
|
||
|
return Basicpy::EllipsisType::get(
|
||
|
&self.context->context);
|
||
|
})
|
||
|
.def("basicpy_SlotObject_type",
|
||
|
[](BasicpyDialectHelper &self, std::string className,
|
||
|
py::args pySlotTypes) -> PyType {
|
||
|
SmallVector<Type, 4> slotTypes;
|
||
|
for (auto pySlotType : pySlotTypes) {
|
||
|
slotTypes.push_back(pySlotType.cast<PyType>());
|
||
|
}
|
||
|
auto classNameAttr =
|
||
|
StringAttr::get(className, &self.context->context);
|
||
|
return Basicpy::SlotObjectType::get(classNameAttr, slotTypes);
|
||
|
},
|
||
|
py::arg("className"))
|
||
|
.def("basicpy_slot_object_get_op",
|
||
|
[](BasicpyDialectHelper &self, PyValue slotObject,
|
||
|
unsigned index) -> PyOperationRef {
|
||
|
auto slotObjectType = slotObject.value.getType()
|
||
|
.dyn_cast<Basicpy::SlotObjectType>();
|
||
|
if (!slotObjectType) {
|
||
|
throw py::raiseValueError("Operand must be a SlotObject");
|
||
|
}
|
||
|
if (index >= slotObjectType.getSlotCount()) {
|
||
|
throw py::raiseValueError("Out of range slot index");
|
||
|
}
|
||
|
auto resultType = slotObjectType.getSlotTypes()[index];
|
||
|
auto indexAttr = IntegerAttr::get(
|
||
|
IndexType::get(&self.context->context), index);
|
||
|
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||
|
Location loc = UnknownLoc::get(opBuilder.getContext());
|
||
|
auto op = opBuilder.create<Basicpy::SlotObjectGetOp>(
|
||
|
loc, resultType, slotObject, indexAttr);
|
||
|
return op.getOperation();
|
||
|
});
|
||
|
}
|
||
|
};
|
||
|
|
||
|
} // namespace NPCOMP
|
||
|
} // namespace mlir
|
||
|
|
||
|
using namespace ::mlir::NPCOMP;
|
||
|
|
||
|
void mlir::npcomp::python::defineNpcompDialect(py::module m) {
|
||
|
BasicpyDialectHelper::bind(m);
|
||
|
}
|