diff --git a/external/llvm-project b/external/llvm-project index 53a0d45db..5fef6ce0c 160000 --- a/external/llvm-project +++ b/external/llvm-project @@ -1 +1 @@ -Subproject commit 53a0d45db6d0f33dfbb724c99ce2560ae25473c2 +Subproject commit 5fef6ce0cce05a9fb05f47c9d62f3724377ea076 diff --git a/frontends/pytorch/examples/mul_maximum_e2e.py b/frontends/pytorch/examples/mul_maximum_e2e.py index 49bafa5f2..2cb0d8c21 100644 --- a/frontends/pytorch/examples/mul_maximum_e2e.py +++ b/frontends/pytorch/examples/mul_maximum_e2e.py @@ -3,9 +3,16 @@ # See frontends/pytorch/LICENSE for license information. import sys +import numpy as np import torch import torch_mlir +import npcomp +from npcomp.compiler.pytorch.backend.refjit import * +from npcomp.compiler.utils import logging + +logging.enable() + lhs = torch.ones((4, 6, 1)) rhs = torch.ones((1, 1, 3)) * 0.6 bias = torch.ones((1, 1, 3)) * 0.2 @@ -17,8 +24,13 @@ with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f: result = result + bias f.returns([result]) -print(f"Result(f{result.size()}) = {result}", file=sys.stderr) -# TODO: Currently need to route through: -# npcomp-opt -aten-recognize-kernels -convert-aten-to-tcf \ -# -numpy-public-functions-to-tensor -canonicalize -mb.module.operation.print() +backend = CompilerBackend() +jit_module = backend.load(backend.compile(mb.module)) + +jit_result = jit_module.mul_maximum(lhs.numpy(), rhs.numpy(), threshold.numpy(), + bias.numpy()) + +print(f"PyTorch Result = {result.numpy()}", file=sys.stderr) +print(f"JIT Result = {jit_result}", file=sys.stderr) + +np.testing.assert_allclose(result.numpy(), jit_result) diff --git a/include/npcomp/Backend/RefJIT/PythonModule.h b/include/npcomp/Backend/RefJIT/PythonModule.h index b2e62737b..737499455 100644 --- a/include/npcomp/Backend/RefJIT/PythonModule.h +++ b/include/npcomp/Backend/RefJIT/PythonModule.h @@ -15,7 +15,7 @@ namespace npcomp { namespace python { /// Defines an "refjit" module with backend support definitions. -void defineBackendRefJitModule(py::module m); +void defineBackendRefJitModule(py::module &m); } // namespace python } // namespace npcomp diff --git a/include/npcomp/Python/PybindUtils.h b/include/npcomp/Python/PybindUtils.h index ec7568f8d..06efaa851 100644 --- a/include/npcomp/Python/PybindUtils.h +++ b/include/npcomp/Python/PybindUtils.h @@ -15,6 +15,9 @@ #include #include +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/IR.h" +#include "mlir-c/Pass.h" #include "llvm/ADT/Optional.h" namespace py = pybind11; @@ -25,6 +28,45 @@ namespace detail { template struct type_caster> : optional_caster> {}; +/// Casts object -> MlirContext. +template <> struct type_caster { + PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext")); + bool load(handle src, bool) { + auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR); + value = mlirPythonCapsuleToContext(capsule.ptr()); + if (mlirContextIsNull(value)) { + return false; + } + return true; + } +}; + +/// Casts object -> MlirModule. +template <> struct type_caster { + PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule")); + bool load(handle src, bool) { + auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR); + value = mlirPythonCapsuleToModule(capsule.ptr()); + if (mlirModuleIsNull(value)) { + return false; + } + return true; + } +}; + +/// Casts object -> MlirPassManager. +template <> struct type_caster { + PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager")); + bool load(handle src, bool) { + auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR); + value = mlirPythonCapsuleToPassManager(capsule.ptr()); + if (mlirPassManagerIsNull(value)) { + return false; + } + return true; + } +}; + } // namespace detail } // namespace pybind11 diff --git a/lib/Backend/RefJIT/PythonModule.cpp b/lib/Backend/RefJIT/PythonModule.cpp index 7fb33f80f..ad96ac2a4 100644 --- a/lib/Backend/RefJIT/PythonModule.cpp +++ b/lib/Backend/RefJIT/PythonModule.cpp @@ -10,6 +10,8 @@ #include "pybind11/numpy.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Pass.h" #include "npcomp/Python/MlirIr.h" #include "npcomp/Python/MlirPass.h" #include "npcomp/RefBackend/JITHelpers/JITModule.h" @@ -84,20 +86,22 @@ py::array wrapTensorAsArray(Ref tensor) { /*base=*/std::move(pyTensor)); } -void npcomp::python::defineBackendRefJitModule(py::module m) { - m.def("build_backend_compilation_pipeline", [](PyPassManager &pm) { - JITModule::buildBackendCompilationPipeline(pm.passManager); +void npcomp::python::defineBackendRefJitModule(py::module &m) { + m.def("build_backend_compilation_pipeline", [](MlirPassManager capiPm) { + mlir::PassManager *pm = unwrap(capiPm); + JITModule::buildBackendCompilationPipeline(*pm); }); py::class_(m, "JITModule") .def_static( "from_compiled_module", - [](PyModuleOp module, std::vector pySharedLibs) + [](MlirModule capiModule, std::vector pySharedLibs) -> std::unique_ptr { SmallVector sharedLibs(pySharedLibs.begin(), pySharedLibs.end()); - auto jitModule = checkError( - JITModule::fromCompiledModule(module.moduleOp, sharedLibs), - "error creating JITModule: "); + auto module = unwrap(capiModule); + auto jitModule = + checkError(JITModule::fromCompiledModule(module, sharedLibs), + "error creating JITModule: "); return jitModule; }, py::arg("module"), py::arg("shared_libs")) diff --git a/lib/Conversion/TCFToStd/TCFToStd.cpp b/lib/Conversion/TCFToStd/TCFToStd.cpp index 86e54d5b6..9ddb89743 100644 --- a/lib/Conversion/TCFToStd/TCFToStd.cpp +++ b/lib/Conversion/TCFToStd/TCFToStd.cpp @@ -69,8 +69,8 @@ matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) { loc, resultType, rhs, broadcastedShape); Value binaryOpResult; if (isa(op)) { - binaryOpResult = rewriter.create( - loc, result.getType(), lhsBroadcasted, rhsBroadcasted); + binaryOpResult = rewriter.create(loc, result.getType(), + lhsBroadcasted, rhsBroadcasted); } else if (isa(op)) { // XXX: remove TCP dep // XXX: remove TCP ops from TCP @@ -79,8 +79,8 @@ matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) { binaryOpResult = rewriter.create(loc, pred, lhsBroadcasted, rhsBroadcasted); } else if (isa(op)) { - binaryOpResult = rewriter.create( - loc, result.getType(), lhsBroadcasted, rhsBroadcasted); + binaryOpResult = rewriter.create(loc, result.getType(), + lhsBroadcasted, rhsBroadcasted); } else { op->dump(); llvm::report_fatal_error( diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 6b998e417..67498e1f7 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -194,17 +194,17 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline( // // Also, converting to linalg herevopens up a lot of optimization // opportunities. - pm.addPass(createConvertElementwiseToLinalgPass()); + pm.addNestedPass(createConvertElementwiseToLinalgPass()); if (options.optimize) { - pm.addPass(createLinalgFusionOfTensorOpsPass()); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); + pm.addNestedPass(createLinalgFusionOfTensorOpsPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCSEPass()); } // Lower shape constraints before we enter tensor->memref conversion. // That is, we expand shape.cstr_* ops to eager error handling code. - pm.addPass(createConvertShapeConstraintsPass()); + pm.addNestedPass(createConvertShapeConstraintsPass()); // Run shape canonicalizations. In particular, this erases shape.assuming, // now that we have converted shape constraints. // TODO: This is kind of ugly. Either we use pass options or a constructor @@ -227,12 +227,12 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline( // (tensor_load / tensor_to_memref) in the IR. // Bufferize the TCP dialect. - pm.addPass(createTCPBufferizePass()); + pm.addNestedPass(createTCPBufferizePass()); // Lower tensor-valued constants to refback.global. pm.addPass(createLowerConstantTensorsToMemrefPass()); // refback::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument. // We need to resolve this to std.alloc which takes individual extents. - pm.addPass(createLowerAllocMemRefOpsPass()); + pm.addNestedPass(createLowerAllocMemRefOpsPass()); // Lower shape ops to std. // TODO: This should in principle be moved before tensor->memref conversion. // But some of the tensor->memref lowerings above use shape.get_extent. For @@ -241,8 +241,8 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline( pm.addPass(createConvertShapeToStandardPass()); // Run some upstream bufferization passes to finish bufferization. - pm.addPass(createStdBufferizePass()); - pm.addPass(createSCFBufferizePass()); + pm.addNestedPass(createStdBufferizePass()); + pm.addNestedPass(createSCFBufferizePass()); pm.addPass(createLinalgBufferizePass()); pm.addPass(createFuncBufferizePass()); @@ -252,8 +252,8 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline( // At this point, we have lots of loose stuff floating around from lowering, // so it's a good time to do some general cleanups. if (options.optimize) { - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCSEPass()); } // -------------------------------------------------------------------------- @@ -264,12 +264,12 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline( // Lower linalg ops to loops. // TODO: Do some linalg optimizations like tiling here. - pm.addPass(createConvertLinalgToLoopsPass()); + pm.addNestedPass(createConvertLinalgToLoopsPass()); // Run a some cleanups. if (options.optimize) { - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCSEPass()); } // -------------------------------------------------------------------------- @@ -277,7 +277,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline( // -------------------------------------------------------------------------- // Convert scf to std control flow in preparation for going to LLVM. - pm.addPass(createLowerToCFGPass()); + pm.addNestedPass(createLowerToCFGPass()); // Convert functions signatures and other constructs that interface with the // runtime to the `refbackrt` dialect. @@ -291,8 +291,8 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline( // Although LLVM will clean everything up eventually, for the sake of IR // clarity while still in MLIR, run some cleanups. if (options.optimize) { - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCSEPass()); } } @@ -305,6 +305,7 @@ void mlir::NPCOMP::createTCFRefBackendLoweringPipeline( // case of invalid broadcasts. // // TCP does not. So we need to reify the broadcasting and error checking. + // These all run at the module level. pm.addPass(createConvertTCFToStdPass()); pm.addPass(createConvertTCFToLinalgPass()); pm.addPass(createConvertTCFToTCPPass()); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index a155db114..c6011b373 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -9,7 +9,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy # TODO: Make the runtime library work for windows. ${CMAKE_BINARY_DIR}/lib/libNPCOMPCompilerRuntimeShlib.so - ${CMAKE_CURRENT_BINARY_DIR}/npcomp/compiler/_refjit_resources/libNPCOMPCompilerRuntimeShlib.so + ${CMAKE_CURRENT_BINARY_DIR}/npcomp/compiler/generic/backend/libNPCOMPCompilerRuntimeShlib.so ) add_dependencies(NPCOMPPythonResources NPCOMPCompilerRuntimeShlib diff --git a/python/NpcompModule.cpp b/python/NpcompModule.cpp index c460c9154..e120ec365 100644 --- a/python/NpcompModule.cpp +++ b/python/NpcompModule.cpp @@ -12,6 +12,7 @@ #include "npcomp/Python/MlirInit.h" #include "npcomp/Python/NpcompModule.h" #include "npcomp/Python/PybindUtils.h" +#include "npcomp-c/Registration.h" #include "llvm/Support/CommandLine.h" #ifdef NPCOMP_ENABLE_REFJIT @@ -26,7 +27,7 @@ namespace mlir { namespace npcomp { namespace python { -void defineLLVMModule(pybind11::module m) { +static void defineLLVMModule(pybind11::module m) { m.def("print_help_message", []() { llvm::cl::PrintHelpMessage(); }); m.def("add_option", [](std::string name, llvm::Optional value) { @@ -56,6 +57,12 @@ void defineLLVMModule(pybind11::module m) { py::arg("name")); } +static void defineGlobals(py::module &m) { + m.def("register_dialects", [](MlirContext context) { + npcompRegisterAllDialects(context); + }); +} + PYBIND11_MODULE(_npcomp, m) { // Guard the once init to happen once per process (vs module, which in // mondo builds can happen multiple times). @@ -64,6 +71,8 @@ PYBIND11_MODULE(_npcomp, m) { m.doc() = "Npcomp native python bindings"; + // TODO: Retire the llvm, mlir, passes, and dialect modules in favor of + // upstream Python bindings. auto llvm_m = m.def_submodule("llvm", "LLVM interop"); defineLLVMModule(llvm_m); @@ -81,6 +90,9 @@ PYBIND11_MODULE(_npcomp, m) { auto npcomp_dialect = m.def_submodule("dialect", "NPComp custom dialects"); defineNpcompDialect(npcomp_dialect); + // Globals. + defineGlobals(m); + // Optional backend modules. auto backend_m = m.def_submodule("backend", "Backend support"); (void)backend_m; diff --git a/python/npcomp/compiler/_refjit_resources/__init__.py b/python/npcomp/compiler/generic/backend/__init__.py similarity index 100% rename from python/npcomp/compiler/_refjit_resources/__init__.py rename to python/npcomp/compiler/generic/backend/__init__.py diff --git a/python/npcomp/compiler/generic/backend/refjit.py b/python/npcomp/compiler/generic/backend/refjit.py new file mode 100644 index 000000000..05d5ff2ee --- /dev/null +++ b/python/npcomp/compiler/generic/backend/refjit.py @@ -0,0 +1,66 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os + +_refjit = None + +BACKEND_PASSES = ( + "func(convert-scf-to-std)", + "func(canonicalize)", + "func(tcf-shape-refinement)", +) + + +def get_refjit(): + """Dynamically resolves the refjit backend native module.""" + global _refjit + if _refjit is not None: + return _refjit + try: + from _npcomp.backend import refjit as imported_refjit + except ImportError: + raise ImportError( + "The npcomp native module was not compiled with refjit support") + _refjit = imported_refjit + return _refjit + + +def is_enabled() -> bool: + """Returns whether the backend is enabled for the current build.""" + try: + _get_refjit() + return True + except ImportError: + return False + + +def get_runtime_libs(): + # The _refjit_resources directory is at the npcomp.compiler level. + resources_dir = os.path.join(os.path.dirname(__file__)) + return [os.path.join(resources_dir, "libNPCOMPCompilerRuntimeShlib.so")] + + +class JitModuleInvoker: + """Wrapper around a native JitModule for calling functions.""" + + def __init__(self, jit_module): + super().__init__() + self._jit_module = jit_module + + def __getattr__(self, function_name): + return self.__getitem__(function_name) + + def __getitem__(self, function_name): + + def invoke(*args): + results = self._jit_module.invoke(function_name, args) + if len(results) == 1: + # De-tuple. + return results[0] + else: + return tuple(results) + + invoke.__isnpcomp__ = True + return invoke diff --git a/python/npcomp/compiler/numpy/backend/refjit.py b/python/npcomp/compiler/numpy/backend/refjit.py index 5994267b7..9c8bbf590 100644 --- a/python/npcomp/compiler/numpy/backend/refjit.py +++ b/python/npcomp/compiler/numpy/backend/refjit.py @@ -4,7 +4,11 @@ import os -from _npcomp import mlir +from mlir.ir import * +from mlir.passmanager import * +from _npcomp import register_dialects +from _npcomp import mlir as legacy_mlir +from npcomp.compiler.generic.backend import refjit as refjit_backend from npcomp.compiler.utils import logging __all__ = [ @@ -13,45 +17,16 @@ __all__ = [ ] FRONTEND_PASSES = ( - "npcomp-cpa-type-inference", + "func(npcomp-cpa-type-inference)", "numpy-public-functions-to-tensor", - "convert-numpy-to-tcf", - "convert-scf-to-std", - "canonicalize", - "tcf-shape-refinement", + "func(convert-numpy-to-tcf)", + "func(convert-scf-to-std)", + "func(canonicalize)", + "func(tcf-shape-refinement)", ) -_refjit = None - - -def _get_refjit(): - """Dynamically resolves the refjit backend native module.""" - global _refjit - if _refjit is not None: - return _refjit - try: - from _npcomp.backend import refjit as imported_refjit - except ImportError: - raise ImportError( - "The npcomp native module was not compiled with refjit support") - _refjit = imported_refjit - return _refjit - - -def is_enabled() -> bool: - """Returns whether the backend is enabled for the current build.""" - try: - _get_refjit() - return True - except ImportError: - return False - - -def get_runtime_libs(): - # The _refjit_resources directory is at the npcomp.compiler level. - resources_dir = os.path.join(os.path.dirname(__file__), "..", "..", - "_refjit_resources") - return [os.path.join(resources_dir, "libNPCOMPCompilerRuntimeShlib.so")] +# Re-export. +is_enabled = refjit_backend.is_enabled class CompilerBackend: @@ -59,37 +34,41 @@ class CompilerBackend: def __init__(self): super().__init__() - self._refjit = _get_refjit() + self._refjit = refjit_backend.get_refjit() self._debug = logging.debug_enabled() - def compile(self, imported_ir_module: mlir.ir.ModuleOp): + def compile(self, legacy_imported_ir_module: legacy_mlir.ir.ModuleOp): """Compiles an imported module. Args: - imported_ir_module: The MLIR module as imported from the ImportFrontend. + legacy_imported_ir_module: The MLIR module as imported from the + ImportFrontend. Returns: An opaque, backend specific module object that can be passed to load. The object may actually be something more specific to the backend (i.e. for IREE, it is a serialized VM flatbuffer) but the contract is that it is operated on by methods on this class. """ - # Frontend. - pm = mlir.passes.PassManager(imported_ir_module.context) - pm.addPassPipelines(*FRONTEND_PASSES) - pm.run(imported_ir_module) - if self._debug: - logging.debug("Frontend IR:{}", imported_ir_module.to_asm()) + # TODO: Once transitioned to new Python API, don't reparse the module. + with Context() as context: + register_dialects(context) + imported_module = Module.parse(legacy_imported_ir_module.to_asm()) + # Frontend. + pm = PassManager.parse(",".join(FRONTEND_PASSES)) + pm.run(imported_module) + if self._debug: + logging.debug("Frontend IR:{}", imported_module) - # Backend. - # Note that this is a separate pass manager purely to aid in debugging. - pm = mlir.passes.PassManager(imported_ir_module.context) - self._refjit.build_backend_compilation_pipeline(pm) - pm.run(imported_ir_module) - if self._debug: - logging.debug("Backend IR:{}", imported_ir_module.to_asm()) + # Backend. + # Note that this is a separate pass manager purely to aid in debugging. + pm = PassManager() + self._refjit.build_backend_compilation_pipeline(pm) + pm.run(imported_module) + if self._debug: + logging.debug("Backend IR:{}", imported_module) jit_module = self._refjit.JITModule.from_compiled_module( - imported_ir_module, get_runtime_libs()) + imported_module, refjit_backend.get_runtime_libs()) return jit_module def load(self, jit_module): @@ -97,25 +76,4 @@ class CompilerBackend: Since this is a JIT instead of an AOT compiler, """ - return JitModuleInvoker(jit_module) - - -class JitModuleInvoker: - """Wrapper around a native JitModule for calling functions.""" - - def __init__(self, jit_module): - super().__init__() - self._jit_module = jit_module - - def __getitem__(self, function_name): - - def invoke(*args): - results = self._jit_module.invoke(function_name, args) - if len(results) == 1: - # De-tuple. - return results[0] - else: - return tuple(results) - - invoke.__isnpcomp__ = True - return invoke + return refjit_backend.JitModuleInvoker(jit_module) diff --git a/python/npcomp/compiler/pytorch/__init__.py b/python/npcomp/compiler/pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/npcomp/compiler/pytorch/backend/__init__.py b/python/npcomp/compiler/pytorch/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/npcomp/compiler/pytorch/backend/refjit.py b/python/npcomp/compiler/pytorch/backend/refjit.py new file mode 100644 index 000000000..17e268a91 --- /dev/null +++ b/python/npcomp/compiler/pytorch/backend/refjit.py @@ -0,0 +1,69 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os + +from mlir.ir import * +from mlir.passmanager import * +from npcomp.compiler.generic.backend import refjit as refjit_backend +from npcomp.compiler.utils import logging + +__all__ = [ + "is_enabled", + "CompilerBackend", +] + +FRONTEND_PASSES = ("func(aten-recognize-kernels)", "func(convert-aten-to-tcf)", + "numpy-public-functions-to-tensor", "canonicalize") + +# Re-export. +is_enabled = refjit_backend.is_enabled + + +class CompilerBackend: + """Main entry-point for the backend.""" + + def __init__(self): + super().__init__() + self._refjit = refjit_backend.get_refjit() + self._debug = logging.debug_enabled() + + def compile(self, imported_module: Module): + """Compiles an imported module. + + Args: + imported_module: The MLIR module consisting of funcs in the torch + dialect. + Returns: + An opaque, backend specific module object that can be passed to load. + The object may actually be something more specific to the backend (i.e. + for IREE, it is a serialized VM flatbuffer) but the contract is that + it is operated on by methods on this class. + """ + # TODO: Once transitioned to new Python API, don't reparse the module. + with Context() as context: + # Frontend. + pm = PassManager.parse(",".join(FRONTEND_PASSES)) + pm.run(imported_module) + if self._debug: + logging.debug("Frontend IR:{}", imported_module) + + # Backend. + # Note that this is a separate pass manager purely to aid in debugging. + pm = PassManager() + self._refjit.build_backend_compilation_pipeline(pm) + pm.run(imported_module) + if self._debug: + logging.debug("Backend IR:{}", imported_module) + + jit_module = self._refjit.JITModule.from_compiled_module( + imported_module, refjit_backend.get_runtime_libs()) + return jit_module + + def load(self, jit_module): + """Loads a compiled artifact into the runtime. + + Since this is a JIT instead of an AOT compiler, + """ + return refjit_backend.JitModuleInvoker(jit_module)