mirror of https://github.com/llvm/torch-mlir
Initial python plumbing to interface with the refjit backend.
parent
df0d3fcaff
commit
aea05d68d7
|
@ -27,6 +27,7 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
|
|||
#-------------------------------------------------------------------------------
|
||||
|
||||
option(NPCOMP_ENABLE_IREE "Enables the IREE backend (must configure location via IREE_DIR)." OFF)
|
||||
option(NPCOMP_ENABLE_REFJIT "Enables the reference JIT backend." ON)
|
||||
set(NPCOMP_IREE_SRCDIR "" CACHE STRING "If building IREE, then setting this elects to build from a source directory (versus installed package)")
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
@ -62,6 +63,15 @@ link_directories(${LLVM_BUILD_LIBRARY_DIR})
|
|||
add_definitions(${LLVM_DEFINITIONS})
|
||||
set(NPCOMP_TABLEGEN_ARGS "")
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Optional feature selection
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
if(NPCOMP_ENABLE_REFJIT)
|
||||
add_compile_definitions(NPCOMP_ENABLE_REFJIT)
|
||||
message(STATUS "Reference JIT backend enabled")
|
||||
endif()
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# IREE configuration
|
||||
#-------------------------------------------------------------------------------
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
//===- PythonModule.h - IREE python bindings ------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_BACKEND_IREE_PYTHON_MODULE_H
|
||||
#define NPCOMP_BACKEND_IREE_PYTHON_MODULE_H
|
||||
|
||||
#include "npcomp/Python/PybindUtils.h"
|
||||
|
||||
namespace npcomp {
|
||||
namespace python {
|
||||
|
||||
/// Defines an "refjit" module with backend support definitions.
|
||||
void defineBackendRefJitModule(py::module m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace npcomp
|
||||
|
||||
#endif // NPCOMP_BACKEND_IREE_PYTHON_MODULE_H
|
|
@ -0,0 +1,22 @@
|
|||
################################################################################
|
||||
# NPCOMPBackendRefJITPythonModule
|
||||
################################################################################
|
||||
|
||||
include(NpcompPython)
|
||||
|
||||
set(PYBIND_SOURCES
|
||||
PythonModule.cpp
|
||||
)
|
||||
add_library(NPCOMPBackendRefJITPythonModule
|
||||
${PYBIND_SOURCES}
|
||||
)
|
||||
|
||||
target_link_libraries(NPCOMPBackendRefJITPythonModule
|
||||
MLIRExecutionEngine
|
||||
MLIRTargetLLVMIR
|
||||
|
||||
NPCOMPPythonCommon
|
||||
NPCOMPJITRuntime
|
||||
)
|
||||
|
||||
npcomp_python_target_compile_options(NPCOMPBackendRefJITPythonModule)
|
|
@ -0,0 +1,123 @@
|
|||
//===- PythonModule.cpp - RefJIT python bindings --------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Backend/RefJIT/PythonModule.h"
|
||||
|
||||
#include "pybind11/numpy.h"
|
||||
|
||||
#include "npcomp/JITRuntime/JITModule.h"
|
||||
#include "npcomp/Python/MlirIr.h"
|
||||
|
||||
using llvm::SmallVector;
|
||||
using llvm::StringRef;
|
||||
|
||||
// Make namespaces consistent.
|
||||
using mlir::PyModuleOp;
|
||||
using npcomp::JITModule;
|
||||
using npcomprt::Ref;
|
||||
using npcomprt::Tensor;
|
||||
|
||||
template <typename T>
|
||||
static T checkError(llvm::Expected<T> expected, const char *messagePrefix) {
|
||||
if (expected)
|
||||
return std::move(*expected);
|
||||
// TODO: FIXME: Figure out how these errors work. This crashes at runtime.
|
||||
auto error = expected.takeError();
|
||||
std::string errorMessage;
|
||||
llvm::raw_string_ostream os(errorMessage);
|
||||
os << messagePrefix << error;
|
||||
os.flush();
|
||||
throw py::raisePyError(PyExc_RuntimeError, errorMessage.c_str());
|
||||
}
|
||||
|
||||
static npcomprt::ElementType
|
||||
mapBufferFormatToElementType(const std::string &format, py::ssize_t itemSize) {
|
||||
if (format == "f")
|
||||
return npcomprt::ElementType::F32;
|
||||
|
||||
std::string message("unsupported buffer format: ");
|
||||
message.append(format);
|
||||
throw py::raiseValueError(message);
|
||||
}
|
||||
|
||||
static Ref<Tensor> copyBufferToTensor(py::buffer buffer) {
|
||||
// Request a C contiguous view as that is what Tensor accepts now (no strides
|
||||
// or non row-major layout).
|
||||
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
|
||||
std::unique_ptr<Py_buffer> view(new Py_buffer());
|
||||
if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
py::buffer_info info(view.release());
|
||||
auto elementType = mapBufferFormatToElementType(info.format, info.itemsize);
|
||||
|
||||
// TODO: Switch Tensor extents to ssize_t for efficiency.
|
||||
SmallVector<std::int32_t, 4> extents(info.shape.begin(), info.shape.end());
|
||||
return Tensor::create(
|
||||
npcomprt::ArrayRef<std::int32_t>(extents.data(), extents.size()),
|
||||
elementType, info.ptr);
|
||||
}
|
||||
|
||||
py::array wrapTensorAsArray(Ref<Tensor> tensor) {
|
||||
auto pyTensor = py::cast(tensor);
|
||||
auto extents = tensor->getExtents();
|
||||
// TODO: Switch Tensor extents to ssize_t for efficiency.
|
||||
std::vector<ssize_t> shape(extents.data(), extents.data() + extents.size());
|
||||
|
||||
const char *format;
|
||||
switch (tensor->getElementType()) {
|
||||
case npcomprt::ElementType::F32:
|
||||
format = "f";
|
||||
break;
|
||||
default:
|
||||
throw py::raiseValueError("unsupported tensor element type");
|
||||
}
|
||||
|
||||
return py::array(py::dtype(format), shape, tensor->getData(),
|
||||
/*base=*/std::move(pyTensor));
|
||||
}
|
||||
|
||||
void npcomp::python::defineBackendRefJitModule(py::module m) {
|
||||
py::class_<JITModule>(m, "JITModule")
|
||||
.def_static("from_mlir",
|
||||
[](PyModuleOp module, std::vector<std::string> pySharedLibs)
|
||||
-> std::unique_ptr<JITModule> {
|
||||
SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(),
|
||||
pySharedLibs.end());
|
||||
auto jitModule = checkError(
|
||||
JITModule::fromMLIR(module.moduleOp, sharedLibs),
|
||||
"error creating JITModule");
|
||||
return jitModule;
|
||||
},
|
||||
py::arg("module"), py::arg("shared_libs"))
|
||||
.def("invoke",
|
||||
[](JITModule &self, std::string functionName,
|
||||
std::vector<py::buffer> inputs) {
|
||||
// Prepare inputs.
|
||||
llvm::SmallVector<Ref<Tensor>, 4> inputTensors;
|
||||
inputTensors.reserve(inputs.size());
|
||||
for (py::buffer &inputBuffer : inputs) {
|
||||
inputTensors.push_back(copyBufferToTensor(inputBuffer));
|
||||
}
|
||||
|
||||
auto outputs = checkError(self.invoke(functionName, inputTensors),
|
||||
"error invoking JIT function: ");
|
||||
std::vector<py::array> outputArrays;
|
||||
outputArrays.reserve(outputs.size());
|
||||
for (Ref<Tensor> &outputTensor : outputs) {
|
||||
outputArrays.push_back(wrapTensorAsArray(outputTensor));
|
||||
}
|
||||
return outputArrays;
|
||||
},
|
||||
py::arg("function_name"), py::arg("inputs"));
|
||||
|
||||
// A Ref<Tensor> needs to be bound because we use it as a base for the
|
||||
// ndarray (the array retains a reference to it). Users should not encounter
|
||||
// this unless if they go mucking through the array internals.
|
||||
py::class_<Ref<Tensor>>(m, "TensorRef");
|
||||
}
|
|
@ -6,6 +6,9 @@ add_subdirectory(Python)
|
|||
add_subdirectory(Typing)
|
||||
add_subdirectory(runtime)
|
||||
|
||||
if(NPCOMP_ENABLE_REFJIT)
|
||||
add_subdirectory(Backend/RefJIT)
|
||||
endif()
|
||||
if(NPCOMP_ENABLE_IREE)
|
||||
add_subdirectory(Backend/IREE)
|
||||
endif()
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
# RUN: %PYTHON %s
|
||||
|
||||
import numpy as np
|
||||
|
||||
from npcomp.compiler.backend import refjit
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler import logging
|
||||
from npcomp.compiler import test_config
|
||||
from npcomp.compiler.target import *
|
||||
|
||||
# TODO: This should all exist in a high level API somewhere.
|
||||
from _npcomp import mlir
|
||||
|
||||
logging.enable()
|
||||
|
||||
|
||||
def compile_function(f):
|
||||
fe = ImportFrontend(config=test_config.create_test_config(
|
||||
target_factory=GenericTarget32))
|
||||
fe.import_global_function(f)
|
||||
compiler = refjit.CompilerBackend()
|
||||
vm_blob = compiler.compile(fe.ir_module)
|
||||
loaded_m = compiler.load(vm_blob)
|
||||
return loaded_m[f.__name__]
|
||||
|
||||
|
||||
global_data = (np.zeros((2, 3)) + [1.0, 2.0, 3.0] * np.reshape([1.0, 2.0],
|
||||
(2, 1)))
|
||||
|
||||
a = np.asarray([1.0, 2.0], dtype=np.float32)
|
||||
b = np.asarray([3.0, 4.0], dtype=np.float32)
|
||||
|
||||
|
||||
@compile_function
|
||||
def global_add():
|
||||
return np.add(a, np.add(b, a))
|
|
@ -0,0 +1,79 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from _npcomp import mlir
|
||||
from npcomp.compiler import logging
|
||||
|
||||
__all__ = [
|
||||
"is_enabled",
|
||||
"CompilerBackend",
|
||||
]
|
||||
|
||||
FRONTEND_PASSES = (
|
||||
"npcomp-cpa-type-inference",
|
||||
"numpy-public-functions-to-tensor",
|
||||
"convert-numpy-to-tcf",
|
||||
"canonicalize",
|
||||
"convert-scf-to-std",
|
||||
)
|
||||
|
||||
_refjit = None
|
||||
|
||||
|
||||
def _get_refjit():
|
||||
"""Dynamically resolves the refjit backend native module."""
|
||||
global _refjit
|
||||
if _refjit is not None:
|
||||
return _refjit
|
||||
try:
|
||||
from _npcomp.backend import refjit as imported_refjit
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The npcomp native module was not compiled with refjit support")
|
||||
_refjit = imported_refjit
|
||||
return _refjit
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
"""Returns whether the backend is enabled for the current build."""
|
||||
try:
|
||||
_get_refjit()
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
class CompilerBackend:
|
||||
"""Main entry-point for the backend."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._refjit = _get_refjit()
|
||||
self._debug = logging.debug_enabled()
|
||||
|
||||
def compile(self, imported_ir_module: mlir.ir.ModuleOp):
|
||||
"""Compiles an imported module.
|
||||
|
||||
Args:
|
||||
imported_ir_module: The MLIR module as imported from the ImportFrontend.
|
||||
Returns:
|
||||
An opaque, backend specific module object that can be passed to load.
|
||||
The object may actually be something more specific to the backend (i.e.
|
||||
for IREE, it is a serialized VM flatbuffer) but the contract is that
|
||||
it is operated on by methods on this class.
|
||||
"""
|
||||
pm = mlir.passes.PassManager(imported_ir_module.context)
|
||||
pm.addPassPipelines(*FRONTEND_PASSES)
|
||||
pm.run(imported_ir_module)
|
||||
if self._debug:
|
||||
logging.debug("Frontend IR:{}", imported_ir_module.to_asm())
|
||||
|
||||
jit_module = self._refjit.JITModule.from_mlir(imported_ir_module, [])
|
||||
return jit_module
|
||||
|
||||
def load(self, jit_module):
|
||||
"""Loads a compiled artifact into the runtime.
|
||||
|
||||
Since this is a JIT instead of an AOT compiler,
|
||||
"""
|
|
@ -21,6 +21,10 @@ if(NPCOMP_ENABLE_IREE)
|
|||
list(APPEND NPCOMP_PYEXT_LIBADD NPCOMPBackendIREEPythonModule)
|
||||
endif()
|
||||
|
||||
if(NPCOMP_ENABLE_REFJIT)
|
||||
list(APPEND NPCOMP_PYEXT_LIBADD NPCOMPBackendRefJITPythonModule)
|
||||
endif()
|
||||
|
||||
# TODO(laurenzo): Add a config setting to control this.
|
||||
# set(NPCOMP_PYEXT_LINK_MODE MODULE)
|
||||
# set(NPCOMP_PYEXT_LIBADD "")
|
||||
|
|
|
@ -14,6 +14,10 @@
|
|||
#include "npcomp/Python/PybindUtils.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
|
||||
#ifdef NPCOMP_ENABLE_REFJIT
|
||||
#include "npcomp/Backend/RefJIT/PythonModule.h"
|
||||
#endif
|
||||
|
||||
#ifdef NPCOMP_ENABLE_IREE
|
||||
#include "npcomp/Backend/IREE/PythonModule.h"
|
||||
#endif
|
||||
|
@ -81,6 +85,12 @@ PYBIND11_MODULE(_npcomp, m) {
|
|||
auto backend_m = m.def_submodule("backend", "Backend support");
|
||||
(void)backend_m;
|
||||
|
||||
#ifdef NPCOMP_ENABLE_REFJIT
|
||||
auto refjit_m =
|
||||
backend_m.def_submodule("refjit", "Reference CPU Jit Backend");
|
||||
::npcomp::python::defineBackendRefJitModule(refjit_m);
|
||||
#endif
|
||||
|
||||
#ifdef NPCOMP_ENABLE_IREE
|
||||
auto iree_m = backend_m.def_submodule("iree", "IREE backend support");
|
||||
defineBackendIREEModule(iree_m);
|
||||
|
|
Loading…
Reference in New Issue