//===- PythonModule.cpp - RefJIT python bindings --------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "npcomp/Backend/RefJIT/PythonModule.h" #include "pybind11/numpy.h" #include "npcomp/Python/MlirIr.h" #include "npcomp/Python/MlirPass.h" #include "npcomp/RefBackend/JITHelpers/JITModule.h" using llvm::SmallVector; using llvm::StringRef; using llvm::Twine; // Make namespaces consistent. using mlir::PyModuleOp; using mlir::PyPassManager; using npcomp::JITModule; using npcomprt::Ref; using npcomprt::Tensor; template static T checkError(llvm::Expected &&expected, Twine banner = {}) { if (LLVM_LIKELY(expected)) return std::move(*expected); std::string errorMessage; llvm::raw_string_ostream os(errorMessage); llvm::logAllUnhandledErrors(expected.takeError(), os, banner); os.flush(); throw py::raisePyError(PyExc_RuntimeError, errorMessage.c_str()); } static npcomprt::ElementType mapBufferFormatToElementType(const std::string &format, py::ssize_t itemSize) { if (format == "f") return npcomprt::ElementType::F32; std::string message("unsupported buffer format: "); message.append(format); throw py::raiseValueError(message); } static Ref copyBufferToTensor(py::buffer buffer) { // Request a C contiguous view as that is what Tensor accepts now (no strides // or non row-major layout). int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; std::unique_ptr view(new Py_buffer()); if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { throw py::error_already_set(); } py::buffer_info info(view.release()); auto elementType = mapBufferFormatToElementType(info.format, info.itemsize); // TODO: Switch Tensor extents to ssize_t for efficiency. SmallVector extents(info.shape.begin(), info.shape.end()); return Tensor::create( npcomprt::ArrayRef(extents.data(), extents.size()), elementType, info.ptr); } py::array wrapTensorAsArray(Ref tensor) { auto pyTensor = py::cast(tensor); auto extents = tensor->getExtents(); // TODO: Switch Tensor extents to ssize_t for efficiency. std::vector shape(extents.data(), extents.data() + extents.size()); const char *format; switch (tensor->getElementType()) { case npcomprt::ElementType::F32: format = "f"; break; default: throw py::raiseValueError("unsupported tensor element type"); } return py::array(py::dtype(format), shape, tensor->getData(), /*base=*/std::move(pyTensor)); } void npcomp::python::defineBackendRefJitModule(py::module m) { m.def("build_backend_compilation_pipeline", [](PyPassManager &pm) { JITModule::buildBackendCompilationPipeline(pm.passManager); }); py::class_(m, "JITModule") .def_static( "from_compiled_module", [](PyModuleOp module, std::vector pySharedLibs) -> std::unique_ptr { SmallVector sharedLibs(pySharedLibs.begin(), pySharedLibs.end()); auto jitModule = checkError( JITModule::fromCompiledModule(module.moduleOp, sharedLibs), "error creating JITModule: "); return jitModule; }, py::arg("module"), py::arg("shared_libs")) .def( "invoke", [](JITModule &self, std::string functionName, std::vector inputs) { // Prepare inputs. llvm::SmallVector, 4> inputTensors; inputTensors.reserve(inputs.size()); for (py::buffer &inputBuffer : inputs) { inputTensors.push_back(copyBufferToTensor(inputBuffer)); } auto outputs = checkError(self.invoke(functionName, inputTensors), "error invoking JIT function: "); std::vector outputArrays; outputArrays.reserve(outputs.size()); for (Ref &outputTensor : outputs) { outputArrays.push_back(wrapTensorAsArray(outputTensor)); } return outputArrays; }, py::arg("function_name"), py::arg("inputs")); // A Ref needs to be bound because we use it as a base for the // ndarray (the array retains a reference to it). Users should not encounter // this unless if they go mucking through the array internals. py::class_>(m, "TensorRef"); }