mirror of https://github.com/llvm/torch-mlir
Add DialectHelper for Basicpy dialect.
* Involved native code for the types and slot_object_get ops.pull/1/head
parent
4ebf972503
commit
6b7c913e0b
|
@ -28,6 +28,7 @@ set(extension_pybind_sources
|
||||||
native.cpp
|
native.cpp
|
||||||
mlir_init.cpp
|
mlir_init.cpp
|
||||||
mlir_ir.cpp
|
mlir_ir.cpp
|
||||||
|
npcomp_dialect.cpp
|
||||||
pybind_utils.cpp
|
pybind_utils.cpp
|
||||||
)
|
)
|
||||||
set_source_files_properties(
|
set_source_files_properties(
|
||||||
|
@ -66,6 +67,7 @@ target_link_libraries(${extension_target}
|
||||||
pybind11::module
|
pybind11::module
|
||||||
|
|
||||||
# Local depends
|
# Local depends
|
||||||
|
NPCOMPBasicpyDialect
|
||||||
NPCOMPNumpyDialect
|
NPCOMPNumpyDialect
|
||||||
|
|
||||||
# Upstream depends
|
# Upstream depends
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
from ..native.dialect import BasicpyDialectHelper as _BaseDialectHelper
|
||||||
|
from ..native.mlir import ir
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DialectHelper",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class DialectHelper(_BaseDialectHelper):
|
||||||
|
r"""Dialect helper for the Basicpy dialect.
|
||||||
|
|
||||||
|
>>> c = ir.MLIRContext()
|
||||||
|
>>> h = DialectHelper(c)
|
||||||
|
|
||||||
|
Dialect Types:
|
||||||
|
>>> h.basicpy_None_type
|
||||||
|
!basicpy.NoneType
|
||||||
|
>>> h.basicpy_Ellipsis_type
|
||||||
|
!basicpy.EllipsisType
|
||||||
|
>>> h.basicpy_SlotObject_type(
|
||||||
|
... "foobar", h.basicpy_None_type, h.basicpy_None_type)
|
||||||
|
!basicpy.SlotObject<foobar, !basicpy.NoneType, !basicpy.NoneType>
|
||||||
|
|
||||||
|
singleton op:
|
||||||
|
>>> m = c.new_module()
|
||||||
|
>>> h.builder.insert_block_start(m.first_block)
|
||||||
|
>>> _ = h.basicpy_singleton_op(h.basicpy_None_type)
|
||||||
|
>>> m.to_asm().strip()
|
||||||
|
'module {\n %0 = basicpy.singleton : !basicpy.NoneType\n}'
|
||||||
|
|
||||||
|
slot_object ops:
|
||||||
|
>>> m = c.new_module()
|
||||||
|
>>> h.builder.insert_block_start(m.first_block)
|
||||||
|
>>> v0 = h.basicpy_singleton_op(h.basicpy_None_type).result
|
||||||
|
>>> slot_object = h.basicpy_slot_object_make_op("foobar", v0, v0).result
|
||||||
|
>>> _ = h.basicpy_slot_object_get_op(slot_object, 0)
|
||||||
|
>>> print(m.to_asm().strip())
|
||||||
|
module {
|
||||||
|
%0 = basicpy.singleton : !basicpy.NoneType
|
||||||
|
%1 = basicpy.slot_object_make(%0, %0) -> !basicpy.SlotObject<foobar, !basicpy.NoneType, !basicpy.NoneType>
|
||||||
|
%2 = basicpy.slot_object_get %1[0] : !basicpy.SlotObject<foobar, !basicpy.NoneType, !basicpy.NoneType>
|
||||||
|
}
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def basicpy_singleton_op(self, singleton_type):
|
||||||
|
return self.op("basicpy.singleton", [singleton_type], [])
|
||||||
|
|
||||||
|
def basicpy_slot_object_make_op(self, class_name, *slot_values):
|
||||||
|
c = self.context
|
||||||
|
class_name_attr = c.string_attr(class_name)
|
||||||
|
object_type = self.basicpy_SlotObject_type(class_name,
|
||||||
|
*[v.type for v in slot_values])
|
||||||
|
attrs = c.dictionary_attr({"className": class_name_attr})
|
||||||
|
return self.op("basicpy.slot_object_make", [object_type], slot_values,
|
||||||
|
attrs)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import doctest
|
||||||
|
doctest.testmod()
|
|
@ -2,6 +2,7 @@
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
from npcomp.dialect import Basicpy
|
||||||
from npcomp.native.mlir import ir
|
from npcomp.native.mlir import ir
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -10,7 +11,7 @@ __all__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class DialectHelper(ir.DialectHelper):
|
class DialectHelper(Basicpy.DialectHelper):
|
||||||
r"""Dialect helper.
|
r"""Dialect helper.
|
||||||
|
|
||||||
>>> c = ir.MLIRContext()
|
>>> c = ir.MLIRContext()
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir_ir.h"
|
#include "mlir_ir.h"
|
||||||
|
#include "native.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
|
@ -117,19 +118,10 @@ private:
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Python only classes
|
// PyDialectHelper
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Helper for creating (possibly dialect specific) IR objects. This class
|
void PyDialectHelper::bind(py::module m) {
|
||||||
/// is intended to be subclassed on the Python side (possibly with multiple
|
|
||||||
/// inheritance) to provide Python level APIs for custom dialects. The base
|
|
||||||
/// class contains helpers for std types and ops.
|
|
||||||
class PyDialectHelper {
|
|
||||||
public:
|
|
||||||
PyDialectHelper(std::shared_ptr<PyContext> context)
|
|
||||||
: pyOpBuilder(*context), context(std::move(context)) {
|
|
||||||
}
|
|
||||||
static void bind(py::module m) {
|
|
||||||
py::class_<PyDialectHelper>(m, "DialectHelper")
|
py::class_<PyDialectHelper>(m, "DialectHelper")
|
||||||
.def(py::init<std::shared_ptr<PyContext>>())
|
.def(py::init<std::shared_ptr<PyContext>>())
|
||||||
.def_property_readonly("builder",
|
.def_property_readonly("builder",
|
||||||
|
@ -179,8 +171,8 @@ public:
|
||||||
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||||
Location loc = UnknownLoc::get(opBuilder.getContext());
|
Location loc = UnknownLoc::get(opBuilder.getContext());
|
||||||
// TODO: Add function and arg/result attributes.
|
// TODO: Add function and arg/result attributes.
|
||||||
FuncOp op = opBuilder.create<FuncOp>(
|
FuncOp op =
|
||||||
loc, StringRef(name), functionType,
|
opBuilder.create<FuncOp>(loc, StringRef(name), functionType,
|
||||||
/*attrs=*/ArrayRef<NamedAttribute>());
|
/*attrs=*/ArrayRef<NamedAttribute>());
|
||||||
if (createEntryBlock) {
|
if (createEntryBlock) {
|
||||||
Block *entryBlock = new Block();
|
Block *entryBlock = new Block();
|
||||||
|
@ -210,10 +202,10 @@ public:
|
||||||
return PyType(IntegerType::get(width, &self.context->context));
|
return PyType(IntegerType::get(width, &self.context->context));
|
||||||
},
|
},
|
||||||
py::arg("width") = 32)
|
py::arg("width") = 32)
|
||||||
.def_property_readonly(
|
.def_property_readonly("i1_type",
|
||||||
"i1_type",
|
|
||||||
[](PyDialectHelper &self) {
|
[](PyDialectHelper &self) {
|
||||||
return PyType(IntegerType::get(1, &self.context->context));
|
return PyType(
|
||||||
|
IntegerType::get(1, &self.context->context));
|
||||||
})
|
})
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"i16_type",
|
"i16_type",
|
||||||
|
@ -230,17 +222,15 @@ public:
|
||||||
[](PyDialectHelper &self) {
|
[](PyDialectHelper &self) {
|
||||||
return PyType(IntegerType::get(64, &self.context->context));
|
return PyType(IntegerType::get(64, &self.context->context));
|
||||||
})
|
})
|
||||||
.def_property_readonly(
|
.def_property_readonly("f32_type",
|
||||||
"f32_type",
|
|
||||||
[](PyDialectHelper &self) {
|
[](PyDialectHelper &self) {
|
||||||
return PyType(
|
return PyType(FloatType::get(
|
||||||
FloatType::get(StandardTypes::F32, &self.context->context));
|
StandardTypes::F32, &self.context->context));
|
||||||
})
|
})
|
||||||
.def_property_readonly(
|
.def_property_readonly("f64_type",
|
||||||
"f64_type",
|
|
||||||
[](PyDialectHelper &self) {
|
[](PyDialectHelper &self) {
|
||||||
return PyType(
|
return PyType(FloatType::get(
|
||||||
FloatType::get(StandardTypes::F64, &self.context->context));
|
StandardTypes::F64, &self.context->context));
|
||||||
})
|
})
|
||||||
.def("tensor_type",
|
.def("tensor_type",
|
||||||
[](PyDialectHelper &self, PyType elementType,
|
[](PyDialectHelper &self, PyType elementType,
|
||||||
|
@ -270,10 +260,7 @@ public:
|
||||||
return PyType(FunctionType::get(inputTypes, resultTypes,
|
return PyType(FunctionType::get(inputTypes, resultTypes,
|
||||||
&self.context->context));
|
&self.context->context));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
PyOpBuilder pyOpBuilder;
|
|
||||||
std::shared_ptr<PyContext> context;
|
|
||||||
};
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Module initialization
|
// Module initialization
|
||||||
|
@ -414,6 +401,15 @@ void PyBaseOperation::bind(py::module m) {
|
||||||
op->result_end());
|
op->result_end());
|
||||||
return results;
|
return results;
|
||||||
})
|
})
|
||||||
|
.def_property_readonly("result",
|
||||||
|
[](PyBaseOperation &self) -> PyValue {
|
||||||
|
auto *op = self.getOperation();
|
||||||
|
if (op->getNumResults() != 1) {
|
||||||
|
throw py::raiseValueError(
|
||||||
|
"Operation does not have 1 result");
|
||||||
|
}
|
||||||
|
return op->getOpResult(0);
|
||||||
|
})
|
||||||
.def("region",
|
.def("region",
|
||||||
[](PyBaseOperation &self, int index) {
|
[](PyBaseOperation &self, int index) {
|
||||||
auto *op = self.getOperation();
|
auto *op = self.getOperation();
|
||||||
|
@ -672,7 +668,10 @@ void PyType::bind(py::module m) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void PyValue::bind(py::module m) {
|
void PyValue::bind(py::module m) {
|
||||||
py::class_<PyValue>(m, "Value").def("__repr__", [](PyValue &self) {
|
py::class_<PyValue>(m, "Value")
|
||||||
|
.def_property_readonly(
|
||||||
|
"type", [](PyValue &self) -> PyType { return self.value.getType(); })
|
||||||
|
.def("__repr__", [](PyValue &self) {
|
||||||
std::string res;
|
std::string res;
|
||||||
llvm::raw_string_ostream os(res);
|
llvm::raw_string_ostream os(res);
|
||||||
os << self.value;
|
os << self.value;
|
||||||
|
|
|
@ -153,6 +153,25 @@ private:
|
||||||
OpBuilder builder;
|
OpBuilder builder;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Custom types
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Helper for creating (possibly dialect specific) IR objects. This class
|
||||||
|
/// is intended to be subclassed on the Python side (possibly with multiple
|
||||||
|
/// inheritance) to provide Python level APIs for custom dialects. The base
|
||||||
|
/// class contains helpers for std types and ops.
|
||||||
|
class PyDialectHelper {
|
||||||
|
public:
|
||||||
|
PyDialectHelper(std::shared_ptr<PyContext> context)
|
||||||
|
: pyOpBuilder(*context), context(std::move(context)) {}
|
||||||
|
static void bind(py::module m);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
PyOpBuilder pyOpBuilder;
|
||||||
|
std::shared_ptr<PyContext> context;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // NPCOMP_PYTHON_MLIR_IR_H
|
#endif // NPCOMP_PYTHON_MLIR_IR_H
|
||||||
|
|
|
@ -9,19 +9,15 @@
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "native.h"
|
||||||
#include "pybind_utils.h"
|
#include "pybind_utils.h"
|
||||||
|
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
void defineMlirIrModule(py::module m);
|
|
||||||
|
|
||||||
namespace npcomp {
|
namespace npcomp {
|
||||||
namespace python {
|
namespace python {
|
||||||
|
|
||||||
// Externs
|
|
||||||
bool npcompMlirInitialize();
|
|
||||||
|
|
||||||
void defineLLVMModule(pybind11::module m) {
|
void defineLLVMModule(pybind11::module m) {
|
||||||
m.def("print_help_message", []() { llvm::cl::PrintHelpMessage(); });
|
m.def("print_help_message", []() { llvm::cl::PrintHelpMessage(); });
|
||||||
m.def("add_option",
|
m.def("add_option",
|
||||||
|
@ -66,6 +62,9 @@ PYBIND11_MODULE(native, m) {
|
||||||
auto mlir_m = m.def_submodule("mlir", "MLIR interop");
|
auto mlir_m = m.def_submodule("mlir", "MLIR interop");
|
||||||
auto mlir_ir_m = mlir_m.def_submodule("ir");
|
auto mlir_ir_m = mlir_m.def_submodule("ir");
|
||||||
defineMlirIrModule(mlir_ir_m);
|
defineMlirIrModule(mlir_ir_m);
|
||||||
|
|
||||||
|
auto npcomp_dialect = m.def_submodule("dialect", "NPComp custom dialects");
|
||||||
|
defineNpcompDialect(npcomp_dialect);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace python
|
} // namespace python
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
//===- dialect.h - Module registrations -----------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_PYTHON_NATIVE_H
|
||||||
|
#define NPCOMP_PYTHON_NATIVE_H
|
||||||
|
|
||||||
|
#include "pybind_utils.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
void defineMlirIrModule(py::module m);
|
||||||
|
|
||||||
|
namespace npcomp {
|
||||||
|
namespace python {
|
||||||
|
|
||||||
|
bool npcompMlirInitialize();
|
||||||
|
void defineNpcompDialect(py::module m);
|
||||||
|
|
||||||
|
} // namespace python
|
||||||
|
} // namespace npcomp
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_PYTHON_NATIVE_H
|
|
@ -0,0 +1,76 @@
|
||||||
|
//===- npcomp_dialect.cpp - Custom dialect classes ------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir_ir.h"
|
||||||
|
#include "native.h"
|
||||||
|
|
||||||
|
#include "npcomp/Dialect/Basicpy/BasicpyDialect.h"
|
||||||
|
#include "npcomp/Dialect/Basicpy/BasicpyOps.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace NPCOMP {
|
||||||
|
|
||||||
|
class BasicpyDialectHelper : public PyDialectHelper {
|
||||||
|
public:
|
||||||
|
using PyDialectHelper::PyDialectHelper;
|
||||||
|
static void bind(py::module m) {
|
||||||
|
py::class_<BasicpyDialectHelper, PyDialectHelper>(m, "BasicpyDialectHelper")
|
||||||
|
.def(py::init<std::shared_ptr<PyContext>>())
|
||||||
|
.def_property_readonly("basicpy_None_type",
|
||||||
|
[](BasicpyDialectHelper &self) -> PyType {
|
||||||
|
return Basicpy::NoneType::get(
|
||||||
|
&self.context->context);
|
||||||
|
})
|
||||||
|
.def_property_readonly("basicpy_Ellipsis_type",
|
||||||
|
[](BasicpyDialectHelper &self) -> PyType {
|
||||||
|
return Basicpy::EllipsisType::get(
|
||||||
|
&self.context->context);
|
||||||
|
})
|
||||||
|
.def("basicpy_SlotObject_type",
|
||||||
|
[](BasicpyDialectHelper &self, std::string className,
|
||||||
|
py::args pySlotTypes) -> PyType {
|
||||||
|
SmallVector<Type, 4> slotTypes;
|
||||||
|
for (auto pySlotType : pySlotTypes) {
|
||||||
|
slotTypes.push_back(pySlotType.cast<PyType>());
|
||||||
|
}
|
||||||
|
auto classNameAttr =
|
||||||
|
StringAttr::get(className, &self.context->context);
|
||||||
|
return Basicpy::SlotObjectType::get(classNameAttr, slotTypes);
|
||||||
|
},
|
||||||
|
py::arg("className"))
|
||||||
|
.def("basicpy_slot_object_get_op",
|
||||||
|
[](BasicpyDialectHelper &self, PyValue slotObject,
|
||||||
|
unsigned index) -> PyOperationRef {
|
||||||
|
auto slotObjectType = slotObject.value.getType()
|
||||||
|
.dyn_cast<Basicpy::SlotObjectType>();
|
||||||
|
if (!slotObjectType) {
|
||||||
|
throw py::raiseValueError("Operand must be a SlotObject");
|
||||||
|
}
|
||||||
|
if (index >= slotObjectType.getSlotCount()) {
|
||||||
|
throw py::raiseValueError("Out of range slot index");
|
||||||
|
}
|
||||||
|
auto resultType = slotObjectType.getSlotTypes()[index];
|
||||||
|
auto indexAttr = IntegerAttr::get(
|
||||||
|
IndexType::get(&self.context->context), index);
|
||||||
|
OpBuilder &opBuilder = self.pyOpBuilder.getBuilder(true);
|
||||||
|
Location loc = UnknownLoc::get(opBuilder.getContext());
|
||||||
|
auto op = opBuilder.create<Basicpy::SlotObjectGetOp>(
|
||||||
|
loc, resultType, slotObject, indexAttr);
|
||||||
|
return op.getOperation();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace NPCOMP
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
using namespace ::mlir::NPCOMP;
|
||||||
|
|
||||||
|
void mlir::npcomp::python::defineNpcompDialect(py::module m) {
|
||||||
|
BasicpyDialectHelper::bind(m);
|
||||||
|
}
|
|
@ -212,10 +212,8 @@ class GenericArrayFuncEmitter(FuncEmitter):
|
||||||
|
|
||||||
def emit(self, request: EmissionRequest):
|
def emit(self, request: EmissionRequest):
|
||||||
h = request.dialect_helper
|
h = request.dialect_helper
|
||||||
op_result_types = [h.tensor_type(h.numpy_any_dtype)
|
op_result_types = [h.tensor_type(h.numpy_any_dtype)] * self._nresults
|
||||||
] * self._nresults
|
op = h.op(self._op_name, op_result_types, request.input_ssa_values)
|
||||||
op = h.op(self._op_name, op_result_types,
|
|
||||||
request.input_ssa_values)
|
|
||||||
return op.results
|
return op.results
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue