//===- PythonModule.cpp - RefJIT python bindings --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
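
// This module exposes the RefJIT reference backend to Python: building the
// backend compilation pipeline on an MLIR PassManager, constructing a
// JITModule from a compiled MLIR module, and invoking JITed functions with
// NumPy-compatible buffers.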
#include "npcomp/Backend/RefJIT/PythonModule.h"
|
|
|
|
|
|
|
|
#include "pybind11/numpy.h"
|
|
|
|
|
2020-11-11 13:38:13 +08:00
|
|
|
#include "mlir/CAPI/IR.h"
|
|
|
|
#include "mlir/CAPI/Pass.h"
|
2020-10-08 07:11:41 +08:00
|
|
|
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
|
2020-07-11 08:36:32 +08:00
|
|
|
|
|
|
|
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;

// Make namespaces consistent.
using refback::JITModule;
using refbackrt::Ref;
using refbackrt::Tensor;
using refbackrt::RtValue;
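
// Unwraps an llvm::Expected<T>, returning the contained value on success.
// On failure, renders all errors (prefixed with `banner`) into a message
// and raises a Python RuntimeError rather than propagating the llvm::Error.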
template <typename T>
static T checkError(llvm::Expected<T> &&expected, Twine banner = {}) {
  if (LLVM_LIKELY(expected))
    return std::move(*expected);

  std::string errorMessage;
  llvm::raw_string_ostream os(errorMessage);
  llvm::logAllUnhandledErrors(expected.takeError(), os, banner);
  os.flush();
  throw py::raisePyError(PyExc_RuntimeError, errorMessage.c_str());
}
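
// Maps a Python buffer-protocol format string (as documented for the
// `struct` module) to a refbackrt element type. Only "f" (float32) is
// supported for now; `itemSize` is currently unused.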
static refbackrt::ElementType
mapBufferFormatToElementType(const std::string &format, py::ssize_t itemSize) {
  if (format == "f")
    return refbackrt::ElementType::F32;

  std::string message("unsupported buffer format: ");
  message.append(format);
  throw py::raiseValueError(message);
}
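
// Copies the contents of a Python buffer into a newly allocated refbackrt
// Tensor. A C-contiguous view is required, and the data is copied, so the
// returned Tensor does not alias the Python buffer's memory.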
static Ref<Tensor> copyBufferToTensor(py::buffer buffer) {
  // Request a C contiguous view as that is what Tensor accepts now (no
  // strides or non-row-major layout).
  int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
  std::unique_ptr<Py_buffer> view(new Py_buffer());
  if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
    throw py::error_already_set();
  }
  py::buffer_info info(view.release());
  auto elementType = mapBufferFormatToElementType(info.format, info.itemsize);

  // TODO: Switch Tensor extents to ssize_t for efficiency.
  SmallVector<std::int32_t, 4> extents(info.shape.begin(), info.shape.end());
  return Tensor::create(
      refbackrt::ArrayRef<std::int32_t>(extents.data(), extents.size()),
      elementType, info.ptr);
}
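
// Wraps a Tensor as a NumPy array without copying the underlying data. The
// Python-side Tensor reference is installed as the array's base object so
// the storage stays alive for as long as the array does.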
py::array wrapTensorAsArray(Ref<Tensor> tensor) {
  auto pyTensor = py::cast(tensor);
  auto extents = tensor->getExtents();
  // TODO: Switch Tensor extents to ssize_t for efficiency.
  std::vector<ssize_t> shape(extents.data(), extents.data() + extents.size());

  const char *format;
  switch (tensor->getElementType()) {
  case refbackrt::ElementType::F32:
    format = "f";
    break;
  default:
    throw py::raiseValueError("unsupported tensor element type");
  }

  return py::array(py::dtype(format), shape, tensor->getData(),
                   /*base=*/std::move(pyTensor));
}
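
// Defines the RefJIT backend bindings on `m`. A hedged sketch of the
// intended Python-side flow (the PassManager/module plumbing shown here is
// illustrative and depends on the embedding application's MLIR bindings):
//
//   pm = PassManager(context)                  # hypothetical setup
//   build_backend_compilation_pipeline(pm)
//   pm.run(module)                             # lower to the RefJIT backend
//   jit = JITModule.from_compiled_module(module, shared_libs)
//   outputs = jit.invoke("main", [np.ones((2, 3), dtype=np.float32)])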
void npcomp::python::defineBackendRefJitModule(py::module &m) {
  m.def("build_backend_compilation_pipeline", [](MlirPassManager capiPm) {
    mlir::PassManager *pm = unwrap(capiPm);
    JITModule::buildBackendCompilationPipeline(*pm);
  });
  py::class_<JITModule>(m, "JITModule")
      .def_static(
          "from_compiled_module",
          [](MlirModule capiModule, std::vector<std::string> pySharedLibs)
              -> std::unique_ptr<JITModule> {
            SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(),
                                                 pySharedLibs.end());
            auto module = unwrap(capiModule);
            auto jitModule =
                checkError(JITModule::fromCompiledModule(module, sharedLibs),
                           "error creating JITModule: ");
            return jitModule;
          },
          py::arg("module"), py::arg("shared_libs"))
      .def(
          "invoke",
          [](JITModule &self, std::string functionName,
             std::vector<py::buffer> inputs) {
            // Prepare inputs: copy each Python buffer into a runtime Tensor.
            llvm::SmallVector<RtValue, 4> inputValues;
            inputValues.reserve(inputs.size());
            for (py::buffer &inputBuffer : inputs) {
              inputValues.push_back(copyBufferToTensor(inputBuffer));
            }

            // Invoke the function and wrap each output Tensor as an ndarray.
            auto outputs = checkError(self.invoke(functionName, inputValues),
                                      "error invoking JIT function: ");
            std::vector<py::array> outputArrays;
            outputArrays.reserve(outputs.size());
            for (RtValue &outputTensor : outputs) {
              outputArrays.push_back(wrapTensorAsArray(outputTensor.toTensor()));
            }
            return outputArrays;
          },
          py::arg("function_name"), py::arg("inputs"));

  // A Ref<Tensor> needs to be bound because we use it as the base for the
  // ndarray (the array retains a reference to it). Users should not encounter
  // this unless they go mucking through the array internals.
  py::class_<Ref<Tensor>>(m, "TensorRef");
}