diff --git a/CMakeLists.txt b/CMakeLists.txt index 6391a7a2d..b77fffb05 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,6 +27,7 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules") #------------------------------------------------------------------------------- option(NPCOMP_ENABLE_IREE "Enables the IREE backend (must configure location via IREE_DIR)." OFF) +option(NPCOMP_ENABLE_REFJIT "Enables the reference JIT backend." ON) set(NPCOMP_IREE_SRCDIR "" CACHE STRING "If building IREE, then setting this elects to build from a source directory (versus installed package)") #------------------------------------------------------------------------------- @@ -62,6 +63,15 @@ link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) set(NPCOMP_TABLEGEN_ARGS "") +#------------------------------------------------------------------------------- +# Optional feature selection +#------------------------------------------------------------------------------- + +if(NPCOMP_ENABLE_REFJIT) + add_compile_definitions(NPCOMP_ENABLE_REFJIT) + message(STATUS "Reference JIT backend enabled") +endif() + #------------------------------------------------------------------------------- # IREE configuration #------------------------------------------------------------------------------- diff --git a/include/npcomp/Backend/RefJIT/PythonModule.h b/include/npcomp/Backend/RefJIT/PythonModule.h new file mode 100644 index 000000000..b2e62737b --- /dev/null +++ b/include/npcomp/Backend/RefJIT/PythonModule.h @@ -0,0 +1,23 @@ +//===- PythonModule.h - IREE python bindings ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_BACKEND_IREE_PYTHON_MODULE_H +#define NPCOMP_BACKEND_IREE_PYTHON_MODULE_H + +#include "npcomp/Python/PybindUtils.h" + +namespace npcomp { +namespace python { + +/// Defines an "refjit" module with backend support definitions. +void defineBackendRefJitModule(py::module m); + +} // namespace python +} // namespace npcomp + +#endif // NPCOMP_BACKEND_IREE_PYTHON_MODULE_H diff --git a/lib/Backend/RefJIT/CMakeLists.txt b/lib/Backend/RefJIT/CMakeLists.txt new file mode 100644 index 000000000..90512d4b7 --- /dev/null +++ b/lib/Backend/RefJIT/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# NPCOMPBackendRefJITPythonModule +################################################################################ + +include(NpcompPython) + +set(PYBIND_SOURCES + PythonModule.cpp +) +add_library(NPCOMPBackendRefJITPythonModule + ${PYBIND_SOURCES} +) + +target_link_libraries(NPCOMPBackendRefJITPythonModule + MLIRExecutionEngine + MLIRTargetLLVMIR + + NPCOMPPythonCommon + NPCOMPJITRuntime +) + +npcomp_python_target_compile_options(NPCOMPBackendRefJITPythonModule) diff --git a/lib/Backend/RefJIT/PythonModule.cpp b/lib/Backend/RefJIT/PythonModule.cpp new file mode 100644 index 000000000..cb925ce1c --- /dev/null +++ b/lib/Backend/RefJIT/PythonModule.cpp @@ -0,0 +1,123 @@ +//===- PythonModule.cpp - RefJIT python bindings --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp/Backend/RefJIT/PythonModule.h" + +#include "pybind11/numpy.h" + +#include "npcomp/JITRuntime/JITModule.h" +#include "npcomp/Python/MlirIr.h" + +using llvm::SmallVector; +using llvm::StringRef; + +// Make namespaces consistent. +using mlir::PyModuleOp; +using npcomp::JITModule; +using npcomprt::Ref; +using npcomprt::Tensor; + +template +static T checkError(llvm::Expected expected, const char *messagePrefix) { + if (expected) + return std::move(*expected); + // TODO: FIXME: Figure out how these errors work. This crashes at runtime. + auto error = expected.takeError(); + std::string errorMessage; + llvm::raw_string_ostream os(errorMessage); + os << messagePrefix << error; + os.flush(); + throw py::raisePyError(PyExc_RuntimeError, errorMessage.c_str()); +} + +static npcomprt::ElementType +mapBufferFormatToElementType(const std::string &format, py::ssize_t itemSize) { + if (format == "f") + return npcomprt::ElementType::F32; + + std::string message("unsupported buffer format: "); + message.append(format); + throw py::raiseValueError(message); +} + +static Ref copyBufferToTensor(py::buffer buffer) { + // Request a C contiguous view as that is what Tensor accepts now (no strides + // or non row-major layout). + int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; + std::unique_ptr view(new Py_buffer()); + if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { + throw py::error_already_set(); + } + py::buffer_info info(view.release()); + auto elementType = mapBufferFormatToElementType(info.format, info.itemsize); + + // TODO: Switch Tensor extents to ssize_t for efficiency. + SmallVector extents(info.shape.begin(), info.shape.end()); + return Tensor::create( + npcomprt::ArrayRef(extents.data(), extents.size()), + elementType, info.ptr); +} + +py::array wrapTensorAsArray(Ref tensor) { + auto pyTensor = py::cast(tensor); + auto extents = tensor->getExtents(); + // TODO: Switch Tensor extents to ssize_t for efficiency. + std::vector shape(extents.data(), extents.data() + extents.size()); + + const char *format; + switch (tensor->getElementType()) { + case npcomprt::ElementType::F32: + format = "f"; + break; + default: + throw py::raiseValueError("unsupported tensor element type"); + } + + return py::array(py::dtype(format), shape, tensor->getData(), + /*base=*/std::move(pyTensor)); +} + +void npcomp::python::defineBackendRefJitModule(py::module m) { + py::class_(m, "JITModule") + .def_static("from_mlir", + [](PyModuleOp module, std::vector pySharedLibs) + -> std::unique_ptr { + SmallVector sharedLibs(pySharedLibs.begin(), + pySharedLibs.end()); + auto jitModule = checkError( + JITModule::fromMLIR(module.moduleOp, sharedLibs), + "error creating JITModule"); + return jitModule; + }, + py::arg("module"), py::arg("shared_libs")) + .def("invoke", + [](JITModule &self, std::string functionName, + std::vector inputs) { + // Prepare inputs. + llvm::SmallVector, 4> inputTensors; + inputTensors.reserve(inputs.size()); + for (py::buffer &inputBuffer : inputs) { + inputTensors.push_back(copyBufferToTensor(inputBuffer)); + } + + auto outputs = checkError(self.invoke(functionName, inputTensors), + "error invoking JIT function: "); + std::vector outputArrays; + outputArrays.reserve(outputs.size()); + for (Ref &outputTensor : outputs) { + outputArrays.push_back(wrapTensorAsArray(outputTensor)); + } + return outputArrays; + }, + py::arg("function_name"), py::arg("inputs")); + + // A Ref needs to be bound because we use it as a base for the + // ndarray (the array retains a reference to it). Users should not encounter + // this unless if they go mucking through the array internals. + py::class_>(m, "TensorRef"); +} diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index da3ffe120..65de4312f 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -6,6 +6,9 @@ add_subdirectory(Python) add_subdirectory(Typing) add_subdirectory(runtime) +if(NPCOMP_ENABLE_REFJIT) + add_subdirectory(Backend/RefJIT) +endif() if(NPCOMP_ENABLE_IREE) add_subdirectory(Backend/IREE) endif() diff --git a/pytest/Backend/RefJIT/simple_invoke.py b/pytest/Backend/RefJIT/simple_invoke.py new file mode 100644 index 000000000..8e536d4b1 --- /dev/null +++ b/pytest/Backend/RefJIT/simple_invoke.py @@ -0,0 +1,36 @@ +# RUN: %PYTHON %s + +import numpy as np + +from npcomp.compiler.backend import refjit +from npcomp.compiler.frontend import * +from npcomp.compiler import logging +from npcomp.compiler import test_config +from npcomp.compiler.target import * + +# TODO: This should all exist in a high level API somewhere. +from _npcomp import mlir + +logging.enable() + + +def compile_function(f): + fe = ImportFrontend(config=test_config.create_test_config( + target_factory=GenericTarget32)) + fe.import_global_function(f) + compiler = refjit.CompilerBackend() + vm_blob = compiler.compile(fe.ir_module) + loaded_m = compiler.load(vm_blob) + return loaded_m[f.__name__] + + +global_data = (np.zeros((2, 3)) + [1.0, 2.0, 3.0] * np.reshape([1.0, 2.0], + (2, 1))) + +a = np.asarray([1.0, 2.0], dtype=np.float32) +b = np.asarray([3.0, 4.0], dtype=np.float32) + + +@compile_function +def global_add(): + return np.add(a, np.add(b, a)) diff --git a/python/npcomp/compiler/backend/refjit.py b/python/npcomp/compiler/backend/refjit.py new file mode 100644 index 000000000..220371d59 --- /dev/null +++ b/python/npcomp/compiler/backend/refjit.py @@ -0,0 +1,79 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from _npcomp import mlir +from npcomp.compiler import logging + +__all__ = [ + "is_enabled", + "CompilerBackend", +] + +FRONTEND_PASSES = ( + "npcomp-cpa-type-inference", + "numpy-public-functions-to-tensor", + "convert-numpy-to-tcf", + "canonicalize", + "convert-scf-to-std", +) + +_refjit = None + + +def _get_refjit(): + """Dynamically resolves the refjit backend native module.""" + global _refjit + if _refjit is not None: + return _refjit + try: + from _npcomp.backend import refjit as imported_refjit + except ImportError: + raise ImportError( + "The npcomp native module was not compiled with refjit support") + _refjit = imported_refjit + return _refjit + + +def is_enabled() -> bool: + """Returns whether the backend is enabled for the current build.""" + try: + _get_refjit() + return True + except ImportError: + return False + + +class CompilerBackend: + """Main entry-point for the backend.""" + + def __init__(self): + super().__init__() + self._refjit = _get_refjit() + self._debug = logging.debug_enabled() + + def compile(self, imported_ir_module: mlir.ir.ModuleOp): + """Compiles an imported module. + + Args: + imported_ir_module: The MLIR module as imported from the ImportFrontend. + Returns: + An opaque, backend specific module object that can be passed to load. + The object may actually be something more specific to the backend (i.e. + for IREE, it is a serialized VM flatbuffer) but the contract is that + it is operated on by methods on this class. + """ + pm = mlir.passes.PassManager(imported_ir_module.context) + pm.addPassPipelines(*FRONTEND_PASSES) + pm.run(imported_ir_module) + if self._debug: + logging.debug("Frontend IR:{}", imported_ir_module.to_asm()) + + jit_module = self._refjit.JITModule.from_mlir(imported_ir_module, []) + return jit_module + + def load(self, jit_module): + """Loads a compiled artifact into the runtime. + + Since this is a JIT instead of an AOT compiler, + """ diff --git a/python_native/CMakeLists.txt b/python_native/CMakeLists.txt index fe4eba725..5b53c1a69 100644 --- a/python_native/CMakeLists.txt +++ b/python_native/CMakeLists.txt @@ -21,6 +21,10 @@ if(NPCOMP_ENABLE_IREE) list(APPEND NPCOMP_PYEXT_LIBADD NPCOMPBackendIREEPythonModule) endif() +if(NPCOMP_ENABLE_REFJIT) + list(APPEND NPCOMP_PYEXT_LIBADD NPCOMPBackendRefJITPythonModule) +endif() + # TODO(laurenzo): Add a config setting to control this. # set(NPCOMP_PYEXT_LINK_MODE MODULE) # set(NPCOMP_PYEXT_LIBADD "") diff --git a/python_native/NpcompModule.cpp b/python_native/NpcompModule.cpp index 21612bfed..c460c9154 100644 --- a/python_native/NpcompModule.cpp +++ b/python_native/NpcompModule.cpp @@ -14,6 +14,10 @@ #include "npcomp/Python/PybindUtils.h" #include "llvm/Support/CommandLine.h" +#ifdef NPCOMP_ENABLE_REFJIT +#include "npcomp/Backend/RefJIT/PythonModule.h" +#endif + #ifdef NPCOMP_ENABLE_IREE #include "npcomp/Backend/IREE/PythonModule.h" #endif @@ -81,6 +85,12 @@ PYBIND11_MODULE(_npcomp, m) { auto backend_m = m.def_submodule("backend", "Backend support"); (void)backend_m; +#ifdef NPCOMP_ENABLE_REFJIT + auto refjit_m = + backend_m.def_submodule("refjit", "Reference CPU Jit Backend"); + ::npcomp::python::defineBackendRefJitModule(refjit_m); +#endif + #ifdef NPCOMP_ENABLE_IREE auto iree_m = backend_m.def_submodule("iree", "IREE backend support"); defineBackendIREEModule(iree_m);