Add MLIRContext.dense_elements_attr to create an attribute from a python buffer/array.

pull/1/head
Stella Laurenzo 2020-05-08 17:36:07 -07:00
parent a91b0bfbe1
commit 8ae71a9551
3 changed files with 90 additions and 10 deletions

2
.gitignore vendored
View File

@ -4,3 +4,5 @@
build
build-mlir
install-mlir
__pycache__

View File

@ -2,6 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
from npcomp.dialect import Basicpy
from _npcomp.mlir import ir
@ -35,6 +36,16 @@ class DialectHelper(Basicpy.DialectHelper):
}
}
DenseElementsAttrs:
>>> c.dense_elements_attr(np.asarray([1, 2, 3, 4]))
dense<[1, 2, 3, 4]> : tensor<4xsi64>
>>> c.dense_elements_attr(np.asarray([[1, 2], [3, 4]]))
dense<[[1, 2], [3, 4]]> : tensor<2x2xsi64>
>>> c.dense_elements_attr(np.asarray([[1., 2.], [3., 4.]]))
dense<[[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf64>
>>> c.dense_elements_attr(np.asarray([[1., 2.], [3., 4.]], dtype=np.float32))
dense<[[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32>
Types:
>>> t = DialectHelper(ir.MLIRContext())
>>> t.numpy_any_dtype

View File

@ -76,6 +76,58 @@ using PyBlockList = PyIpListWrapper<Region::BlockListType, PyBlockRef>;
template class PyIpListWrapper<Block::OpListType, PyOperationRef>;
using PyOperationList = PyIpListWrapper<Block::OpListType, PyOperationRef>;
//===----------------------------------------------------------------------===//
// Conversions
//===----------------------------------------------------------------------===//
Type mapBufferFormatToType(MLIRContext *context, const std::string &format,
py::ssize_t itemSize) {
// Floating point formats.
if (format == "f")
return FloatType::getF32(context);
if (format == "d")
return FloatType::getF64(context);
if (format == "D")
return ComplexType::get(FloatType::getF64(context));
// Signed integer formats.
if (format == "b" || format == "h" || format == "i" || format == "l" ||
format == "L") {
unsigned width = itemSize * 8;
return IntegerType::get(width, IntegerType::SignednessSemantics::Signed,
context);
}
// Unsigned integer format.
if (format == "B" || format == "H" || format == "I" || format == "k" ||
format == "K") {
unsigned width = itemSize * 8;
return IntegerType::get(width, IntegerType::SignednessSemantics::Unsigned,
context);
}
return Type();
}
/// Creates a DenseElementsAttr from a python buffer which must have been
/// requested to be C-Contiguous.
Attribute createDenseElementsAttrFromBuffer(MLIRContext *context,
py::buffer_info &array) {
Type elementType =
mapBufferFormatToType(context, array.format, array.itemsize);
if (!elementType) {
throw py::raiseValueError(
"Unsupported buffer/array type for conversion to DenseElementsAttr");
}
SmallVector<int64_t, 4> shape(array.shape.begin(),
array.shape.begin() + array.ndim);
RankedTensorType type = RankedTensorType::get(shape, elementType);
const char *rawBufferPtr = reinterpret_cast<const char *>(array.ptr);
ArrayRef<char> rawBuffer(rawBufferPtr, array.size * array.itemsize);
return DenseElementsAttr::getFromRawBuffer(type, rawBuffer, false);
}
//===----------------------------------------------------------------------===//
// Diagnostics
//===----------------------------------------------------------------------===//
@ -359,16 +411,31 @@ void PyContext::bind(py::module m) {
[](PyContext &self, const std::string &s) -> PyAttribute {
return FlatSymbolRefAttr::get(s, &self.context);
})
.def("dictionary_attr", [](PyContext &self, py::dict d) -> PyAttribute {
SmallVector<NamedAttribute, 4> attrs;
for (auto &it : d) {
auto key = it.first.cast<std::string>();
auto value = it.second.cast<PyAttribute>();
auto keyIdent = Identifier::get(key, &self.context);
attrs.emplace_back(keyIdent, value.attr);
}
return DictionaryAttr::get(attrs, &self.context);
});
.def("dictionary_attr",
[](PyContext &self, py::dict d) -> PyAttribute {
SmallVector<NamedAttribute, 4> attrs;
for (auto &it : d) {
auto key = it.first.cast<std::string>();
auto value = it.second.cast<PyAttribute>();
auto keyIdent = Identifier::get(key, &self.context);
attrs.emplace_back(keyIdent, value.attr);
}
return DictionaryAttr::get(attrs, &self.context);
})
.def("dense_elements_attr",
[](PyContext &self, py::buffer array) -> PyAttribute {
// Request a contiguous view.
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
Py_buffer *view = new Py_buffer();
if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
delete view;
throw py::error_already_set();
}
py::buffer_info array_info(view);
return createDenseElementsAttrFromBuffer(&self.context,
array_info);
},
py::arg("array"));
}
PyModuleOp PyContext::parseAsm(const std::string &asm_text) {