mirror of https://github.com/llvm/torch-mlir
Add DialectHelper for Basicpy dialect.
* Involved native code for the types and slot_object_get ops.pull/1/head
parent
4ebf972503
commit
6b7c913e0b
|
@ -28,6 +28,7 @@ set(extension_pybind_sources
|
|||
native.cpp
|
||||
mlir_init.cpp
|
||||
mlir_ir.cpp
|
||||
npcomp_dialect.cpp
|
||||
pybind_utils.cpp
|
||||
)
|
||||
set_source_files_properties(
|
||||
|
@ -66,6 +67,7 @@ target_link_libraries(${extension_target}
|
|||
pybind11::module
|
||||
|
||||
# Local depends
|
||||
NPCOMPBasicpyDialect
|
||||
NPCOMPNumpyDialect
|
||||
|
||||
# Upstream depends
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ..native.dialect import BasicpyDialectHelper as _BaseDialectHelper
|
||||
from ..native.mlir import ir
|
||||
|
||||
__all__ = [
|
||||
"DialectHelper",
|
||||
]
|
||||
|
||||
|
||||
class DialectHelper(_BaseDialectHelper):
|
||||
r"""Dialect helper for the Basicpy dialect.
|
||||
|
||||
>>> c = ir.MLIRContext()
|
||||
>>> h = DialectHelper(c)
|
||||
|
||||
Dialect Types:
|
||||
>>> h.basicpy_None_type
|
||||
!basicpy.NoneType
|
||||
>>> h.basicpy_Ellipsis_type
|
||||
!basicpy.EllipsisType
|
||||
>>> h.basicpy_SlotObject_type(
|
||||
... "foobar", h.basicpy_None_type, h.basicpy_None_type)
|
||||
!basicpy.SlotObject<foobar, !basicpy.NoneType, !basicpy.NoneType>
|
||||
|
||||
singleton op:
|
||||
>>> m = c.new_module()
|
||||
>>> h.builder.insert_block_start(m.first_block)
|
||||
>>> _ = h.basicpy_singleton_op(h.basicpy_None_type)
|
||||
>>> m.to_asm().strip()
|
||||
'module {\n %0 = basicpy.singleton : !basicpy.NoneType\n}'
|
||||
|
||||
slot_object ops:
|
||||
>>> m = c.new_module()
|
||||
>>> h.builder.insert_block_start(m.first_block)
|
||||
>>> v0 = h.basicpy_singleton_op(h.basicpy_None_type).result
|
||||
>>> slot_object = h.basicpy_slot_object_make_op("foobar", v0, v0).result
|
||||
>>> _ = h.basicpy_slot_object_get_op(slot_object, 0)
|
||||
>>> print(m.to_asm().strip())
|
||||
module {
|
||||
%0 = basicpy.singleton : !basicpy.NoneType
|
||||
%1 = basicpy.slot_object_make(%0, %0) -> !basicpy.SlotObject<foobar, !basicpy.NoneType, !basicpy.NoneType>
|
||||
%2 = basicpy.slot_object_get %1[0] : !basicpy.SlotObject<foobar, !basicpy.NoneType, !basicpy.NoneType>
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
def basicpy_singleton_op(self, singleton_type):
|
||||
return self.op("basicpy.singleton", [singleton_type], [])
|
||||
|
||||
def basicpy_slot_object_make_op(self, class_name, *slot_values):
|
||||
c = self.context
|
||||
class_name_attr = c.string_attr(class_name)
|
||||
object_type = self.basicpy_SlotObject_type(class_name,
|
||||
*[v.type for v in slot_values])
|
||||
attrs = c.dictionary_attr({"className": class_name_attr})
|
||||
return self.op("basicpy.slot_object_make", [object_type], slot_values,
|
||||
attrs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
|
@ -2,6 +2,7 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from npcomp.dialect import Basicpy
|
||||
from npcomp.native.mlir import ir
|
||||
|
||||
__all__ = [
|
||||
|
@ -10,7 +11,7 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
class DialectHelper(ir.DialectHelper):
|
||||
class DialectHelper(Basicpy.DialectHelper):
|
||||
r"""Dialect helper.
|
||||
|
||||
>>> c = ir.MLIRContext()
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir_ir.h"
|
||||
#include "native.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
|
@ -117,19 +118,10 @@ private:
|
|||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Python only classes
|
||||
// PyDialectHelper
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Helper for creating (possibly dialect specific) IR objects. This class
|
||||
/// is intended to be subclassed on the Python side (possibly with multiple
|
||||
/// inheritance) to provide Python level APIs for custom dialects. The base
|
||||
/// class contains helpers for std types and ops.
|
||||
class PyDialectHelper {
|
||||
public:
|
||||
PyDialectHelper(std::shared_ptr<PyContext> context)
|
||||
: pyOpBuilder(*context), context(std::move(context)) {
|
||||
}
|
||||
static void bind(py::module m) {
|
||||
void PyDialectHelper::bind(py::module m) {
|
||||
py::class_<PyDialectHelper>(m, "DialectHelper")
|
||||
.def(py::init<std::shared_ptr<PyContext>>())
|
||||
.def_property_readonly("builder",
|
||||
|
@ -179,8 +171,8 @@ public:
|
|||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = UnknownLoc::get(opBuilder.getContext());
|
||||
// TODO: Add function and arg/result attributes.
|
||||
FuncOp op = opBuilder.create<FuncOp>(
|
||||
loc, StringRef(name), functionType,
|
||||
FuncOp op =
|
||||
opBuilder.create<FuncOp>(loc, StringRef(name), functionType,
|
||||
/*attrs=*/ArrayRef<NamedAttribute>());
|
||||
if (createEntryBlock) {
|
||||
Block *entryBlock = new Block();
|
||||
|
@ -210,10 +202,10 @@ public:
|
|||
return PyType(IntegerType::get(width, &self.context->context));
|
||||
},
|
||||
py::arg("width") = 32)
|
||||
.def_property_readonly(
|
||||
"i1_type",
|
||||
.def_property_readonly("i1_type",
|
||||
[](PyDialectHelper &self) {
|
||||
return PyType(IntegerType::get(1, &self.context->context));
|
||||
return PyType(
|
||||
IntegerType::get(1, &self.context->context));
|
||||
})
|
||||
.def_property_readonly(
|
||||
"i16_type",
|
||||
|
@ -230,17 +222,15 @@ public:
|
|||
[](PyDialectHelper &self) {
|
||||
return PyType(IntegerType::get(64, &self.context->context));
|
||||
})
|
||||
.def_property_readonly(
|
||||
"f32_type",
|
||||
.def_property_readonly("f32_type",
|
||||
[](PyDialectHelper &self) {
|
||||
return PyType(
|
||||
FloatType::get(StandardTypes::F32, &self.context->context));
|
||||
return PyType(FloatType::get(
|
||||
StandardTypes::F32, &self.context->context));
|
||||
})
|
||||
.def_property_readonly(
|
||||
"f64_type",
|
||||
.def_property_readonly("f64_type",
|
||||
[](PyDialectHelper &self) {
|
||||
return PyType(
|
||||
FloatType::get(StandardTypes::F64, &self.context->context));
|
||||
return PyType(FloatType::get(
|
||||
StandardTypes::F64, &self.context->context));
|
||||
})
|
||||
.def("tensor_type",
|
||||
[](PyDialectHelper &self, PyType elementType,
|
||||
|
@ -271,9 +261,6 @@ public:
|
|||
&self.context->context));
|
||||
});
|
||||
}
|
||||
PyOpBuilder pyOpBuilder;
|
||||
std::shared_ptr<PyContext> context;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module initialization
|
||||
|
@ -414,6 +401,15 @@ void PyBaseOperation::bind(py::module m) {
|
|||
op->result_end());
|
||||
return results;
|
||||
})
|
||||
.def_property_readonly("result",
|
||||
[](PyBaseOperation &self) -> PyValue {
|
||||
auto *op = self.getOperation();
|
||||
if (op->getNumResults() != 1) {
|
||||
throw py::raiseValueError(
|
||||
"Operation does not have 1 result");
|
||||
}
|
||||
return op->getOpResult(0);
|
||||
})
|
||||
.def("region",
|
||||
[](PyBaseOperation &self, int index) {
|
||||
auto *op = self.getOperation();
|
||||
|
@ -672,7 +668,10 @@ void PyType::bind(py::module m) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PyValue::bind(py::module m) {
|
||||
py::class_<PyValue>(m, "Value").def("__repr__", [](PyValue &self) {
|
||||
py::class_<PyValue>(m, "Value")
|
||||
.def_property_readonly(
|
||||
"type", [](PyValue &self) -> PyType { return self.value.getType(); })
|
||||
.def("__repr__", [](PyValue &self) {
|
||||
std::string res;
|
||||
llvm::raw_string_ostream os(res);
|
||||
os << self.value;
|
||||
|
|
|
@ -153,6 +153,25 @@ private:
|
|||
OpBuilder builder;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Custom types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Helper for creating (possibly dialect specific) IR objects. This class
|
||||
/// is intended to be subclassed on the Python side (possibly with multiple
|
||||
/// inheritance) to provide Python level APIs for custom dialects. The base
|
||||
/// class contains helpers for std types and ops.
|
||||
class PyDialectHelper {
|
||||
public:
|
||||
PyDialectHelper(std::shared_ptr<PyContext> context)
|
||||
: pyOpBuilder(*context), context(std::move(context)) {}
|
||||
static void bind(py::module m);
|
||||
|
||||
protected:
|
||||
PyOpBuilder pyOpBuilder;
|
||||
std::shared_ptr<PyContext> context;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_PYTHON_MLIR_IR_H
|
||||
|
|
|
@ -9,19 +9,15 @@
|
|||
#include <cstddef>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "native.h"
|
||||
#include "pybind_utils.h"
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
|
||||
namespace mlir {
|
||||
void defineMlirIrModule(py::module m);
|
||||
|
||||
namespace npcomp {
|
||||
namespace python {
|
||||
|
||||
// Externs
|
||||
bool npcompMlirInitialize();
|
||||
|
||||
void defineLLVMModule(pybind11::module m) {
|
||||
m.def("print_help_message", []() { llvm::cl::PrintHelpMessage(); });
|
||||
m.def("add_option",
|
||||
|
@ -66,6 +62,9 @@ PYBIND11_MODULE(native, m) {
|
|||
auto mlir_m = m.def_submodule("mlir", "MLIR interop");
|
||||
auto mlir_ir_m = mlir_m.def_submodule("ir");
|
||||
defineMlirIrModule(mlir_ir_m);
|
||||
|
||||
auto npcomp_dialect = m.def_submodule("dialect", "NPComp custom dialects");
|
||||
defineNpcompDialect(npcomp_dialect);
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
//===- dialect.h - Module registrations -----------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_PYTHON_NATIVE_H
|
||||
#define NPCOMP_PYTHON_NATIVE_H
|
||||
|
||||
#include "pybind_utils.h"
|
||||
|
||||
namespace mlir {
|
||||
void defineMlirIrModule(py::module m);
|
||||
|
||||
namespace npcomp {
|
||||
namespace python {
|
||||
|
||||
bool npcompMlirInitialize();
|
||||
void defineNpcompDialect(py::module m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace npcomp
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_PYTHON_NATIVE_H
|
|
@ -0,0 +1,76 @@
|
|||
//===- npcomp_dialect.cpp - Custom dialect classes ------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir_ir.h"
|
||||
#include "native.h"
|
||||
|
||||
#include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/BasicpyOps.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
|
||||
class BasicpyDialectHelper : public PyDialectHelper {
|
||||
public:
|
||||
using PyDialectHelper::PyDialectHelper;
|
||||
static void bind(py::module m) {
|
||||
py::class_<BasicpyDialectHelper, PyDialectHelper>(m, "BasicpyDialectHelper")
|
||||
.def(py::init<std::shared_ptr<PyContext>>())
|
||||
.def_property_readonly("basicpy_None_type",
|
||||
[](BasicpyDialectHelper &self) -> PyType {
|
||||
return Basicpy::NoneType::get(
|
||||
&self.context->context);
|
||||
})
|
||||
.def_property_readonly("basicpy_Ellipsis_type",
|
||||
[](BasicpyDialectHelper &self) -> PyType {
|
||||
return Basicpy::EllipsisType::get(
|
||||
&self.context->context);
|
||||
})
|
||||
.def("basicpy_SlotObject_type",
|
||||
[](BasicpyDialectHelper &self, std::string className,
|
||||
py::args pySlotTypes) -> PyType {
|
||||
SmallVector<Type, 4> slotTypes;
|
||||
for (auto pySlotType : pySlotTypes) {
|
||||
slotTypes.push_back(pySlotType.cast<PyType>());
|
||||
}
|
||||
auto classNameAttr =
|
||||
StringAttr::get(className, &self.context->context);
|
||||
return Basicpy::SlotObjectType::get(classNameAttr, slotTypes);
|
||||
},
|
||||
py::arg("className"))
|
||||
.def("basicpy_slot_object_get_op",
|
||||
[](BasicpyDialectHelper &self, PyValue slotObject,
|
||||
unsigned index) -> PyOperationRef {
|
||||
auto slotObjectType = slotObject.value.getType()
|
||||
.dyn_cast<Basicpy::SlotObjectType>();
|
||||
if (!slotObjectType) {
|
||||
throw py::raiseValueError("Operand must be a SlotObject");
|
||||
}
|
||||
if (index >= slotObjectType.getSlotCount()) {
|
||||
throw py::raiseValueError("Out of range slot index");
|
||||
}
|
||||
auto resultType = slotObjectType.getSlotTypes()[index];
|
||||
auto indexAttr = IntegerAttr::get(
|
||||
IndexType::get(&self.context->context), index);
|
||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||
Location loc = UnknownLoc::get(opBuilder.getContext());
|
||||
auto op = opBuilder.create<Basicpy::SlotObjectGetOp>(
|
||||
loc, resultType, slotObject, indexAttr);
|
||||
return op.getOperation();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
using namespace ::mlir::NPCOMP;
|
||||
|
||||
void mlir::npcomp::python::defineNpcompDialect(py::module m) {
|
||||
BasicpyDialectHelper::bind(m);
|
||||
}
|
|
@ -212,10 +212,8 @@ class GenericArrayFuncEmitter(FuncEmitter):
|
|||
|
||||
def emit(self, request: EmissionRequest):
|
||||
h = request.dialect_helper
|
||||
op_result_types = [h.tensor_type(h.numpy_any_dtype)
|
||||
] * self._nresults
|
||||
op = h.op(self._op_name, op_result_types,
|
||||
request.input_ssa_values)
|
||||
op_result_types = [h.tensor_type(h.numpy_any_dtype)] * self._nresults
|
||||
op = h.op(self._op_name, op_result_types, request.input_ssa_values)
|
||||
return op.results
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue