mirror of https://github.com/llvm/torch-mlir
Repurpose numpy-compiler compiler/runtime flow for PyTorch.
* A bit gross because I took the chance to upgrade all of the backend bits to the new MLIR Python bindings and we still co-mingle the old and new for now. * Since the Python created PassManagers are configured for explicit nesting, I had to upgrade some of the pass pipelines to be explicit. * The demo in mul_maximum_e2e.py now compiles, runs through PyTorch and through the JIT, prints and asserts the same results. * I am not claiming that this is the prettiest API in this patch: consider that this is just directly using low-level APIs and there should be an intervening high level API.pull/113/head
parent
d1488c8572
commit
b4c7ae1e0c
|
@ -1 +1 @@
|
|||
Subproject commit 53a0d45db6d0f33dfbb724c99ce2560ae25473c2
|
||||
Subproject commit 5fef6ce0cce05a9fb05f47c9d62f3724377ea076
|
|
@ -3,9 +3,16 @@
|
|||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
import npcomp
|
||||
from npcomp.compiler.pytorch.backend.refjit import *
|
||||
from npcomp.compiler.utils import logging
|
||||
|
||||
logging.enable()
|
||||
|
||||
lhs = torch.ones((4, 6, 1))
|
||||
rhs = torch.ones((1, 1, 3)) * 0.6
|
||||
bias = torch.ones((1, 1, 3)) * 0.2
|
||||
|
@ -17,8 +24,13 @@ with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f:
|
|||
result = result + bias
|
||||
f.returns([result])
|
||||
|
||||
print(f"Result(f{result.size()}) = {result}", file=sys.stderr)
|
||||
# TODO: Currently need to route through:
|
||||
# npcomp-opt -aten-recognize-kernels -convert-aten-to-tcf \
|
||||
# -numpy-public-functions-to-tensor -canonicalize
|
||||
mb.module.operation.print()
|
||||
backend = CompilerBackend()
|
||||
jit_module = backend.load(backend.compile(mb.module))
|
||||
|
||||
jit_result = jit_module.mul_maximum(lhs.numpy(), rhs.numpy(), threshold.numpy(),
|
||||
bias.numpy())
|
||||
|
||||
print(f"PyTorch Result = {result.numpy()}", file=sys.stderr)
|
||||
print(f"JIT Result = {jit_result}", file=sys.stderr)
|
||||
|
||||
np.testing.assert_allclose(result.numpy(), jit_result)
|
||||
|
|
|
@ -15,7 +15,7 @@ namespace npcomp {
|
|||
namespace python {
|
||||
|
||||
/// Defines an "refjit" module with backend support definitions.
|
||||
void defineBackendRefJitModule(py::module m);
|
||||
void defineBackendRefJitModule(py::module &m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace npcomp
|
||||
|
|
|
@ -15,6 +15,9 @@
|
|||
#include <pybind11/pytypes.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Pass.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
@ -25,6 +28,45 @@ namespace detail {
|
|||
template <typename T>
|
||||
struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
|
||||
|
||||
/// Casts object -> MlirContext.
|
||||
template <> struct type_caster<MlirContext> {
|
||||
PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"));
|
||||
bool load(handle src, bool) {
|
||||
auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
||||
value = mlirPythonCapsuleToContext(capsule.ptr());
|
||||
if (mlirContextIsNull(value)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
/// Casts object -> MlirModule.
|
||||
template <> struct type_caster<MlirModule> {
|
||||
PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"));
|
||||
bool load(handle src, bool) {
|
||||
auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
||||
value = mlirPythonCapsuleToModule(capsule.ptr());
|
||||
if (mlirModuleIsNull(value)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
/// Casts object -> MlirPassManager.
|
||||
template <> struct type_caster<MlirPassManager> {
|
||||
PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"));
|
||||
bool load(handle src, bool) {
|
||||
auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
||||
value = mlirPythonCapsuleToPassManager(capsule.ptr());
|
||||
if (mlirPassManagerIsNull(value)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace pybind11
|
||||
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
|
||||
#include "pybind11/numpy.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Pass.h"
|
||||
#include "npcomp/Python/MlirIr.h"
|
||||
#include "npcomp/Python/MlirPass.h"
|
||||
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
|
||||
|
@ -84,19 +86,21 @@ py::array wrapTensorAsArray(Ref<Tensor> tensor) {
|
|||
/*base=*/std::move(pyTensor));
|
||||
}
|
||||
|
||||
void npcomp::python::defineBackendRefJitModule(py::module m) {
|
||||
m.def("build_backend_compilation_pipeline", [](PyPassManager &pm) {
|
||||
JITModule::buildBackendCompilationPipeline(pm.passManager);
|
||||
void npcomp::python::defineBackendRefJitModule(py::module &m) {
|
||||
m.def("build_backend_compilation_pipeline", [](MlirPassManager capiPm) {
|
||||
mlir::PassManager *pm = unwrap(capiPm);
|
||||
JITModule::buildBackendCompilationPipeline(*pm);
|
||||
});
|
||||
py::class_<JITModule>(m, "JITModule")
|
||||
.def_static(
|
||||
"from_compiled_module",
|
||||
[](PyModuleOp module, std::vector<std::string> pySharedLibs)
|
||||
[](MlirModule capiModule, std::vector<std::string> pySharedLibs)
|
||||
-> std::unique_ptr<JITModule> {
|
||||
SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(),
|
||||
pySharedLibs.end());
|
||||
auto jitModule = checkError(
|
||||
JITModule::fromCompiledModule(module.moduleOp, sharedLibs),
|
||||
auto module = unwrap(capiModule);
|
||||
auto jitModule =
|
||||
checkError(JITModule::fromCompiledModule(module, sharedLibs),
|
||||
"error creating JITModule: ");
|
||||
return jitModule;
|
||||
},
|
||||
|
|
|
@ -69,8 +69,8 @@ matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) {
|
|||
loc, resultType, rhs, broadcastedShape);
|
||||
Value binaryOpResult;
|
||||
if (isa<tcf::AddOp>(op)) {
|
||||
binaryOpResult = rewriter.create<AddFOp>(
|
||||
loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
|
||||
binaryOpResult = rewriter.create<AddFOp>(loc, result.getType(),
|
||||
lhsBroadcasted, rhsBroadcasted);
|
||||
} else if (isa<tcf::MaxOp>(op)) {
|
||||
// XXX: remove TCP dep
|
||||
// XXX: remove TCP ops from TCP
|
||||
|
@ -79,8 +79,8 @@ matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) {
|
|||
binaryOpResult =
|
||||
rewriter.create<SelectOp>(loc, pred, lhsBroadcasted, rhsBroadcasted);
|
||||
} else if (isa<tcf::MulOp>(op)) {
|
||||
binaryOpResult = rewriter.create<MulFOp>(
|
||||
loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
|
||||
binaryOpResult = rewriter.create<MulFOp>(loc, result.getType(),
|
||||
lhsBroadcasted, rhsBroadcasted);
|
||||
} else {
|
||||
op->dump();
|
||||
llvm::report_fatal_error(
|
||||
|
|
|
@ -194,17 +194,17 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
|
|||
//
|
||||
// Also, converting to linalg herevopens up a lot of optimization
|
||||
// opportunities.
|
||||
pm.addPass(createConvertElementwiseToLinalgPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertElementwiseToLinalgPass());
|
||||
|
||||
if (options.optimize) {
|
||||
pm.addPass(createLinalgFusionOfTensorOpsPass());
|
||||
pm.addPass(createCanonicalizerPass());
|
||||
pm.addPass(createCSEPass());
|
||||
pm.addNestedPass<FuncOp>(createLinalgFusionOfTensorOpsPass());
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
}
|
||||
|
||||
// Lower shape constraints before we enter tensor->memref conversion.
|
||||
// That is, we expand shape.cstr_* ops to eager error handling code.
|
||||
pm.addPass(createConvertShapeConstraintsPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertShapeConstraintsPass());
|
||||
// Run shape canonicalizations. In particular, this erases shape.assuming,
|
||||
// now that we have converted shape constraints.
|
||||
// TODO: This is kind of ugly. Either we use pass options or a constructor
|
||||
|
@ -227,12 +227,12 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
|
|||
// (tensor_load / tensor_to_memref) in the IR.
|
||||
|
||||
// Bufferize the TCP dialect.
|
||||
pm.addPass(createTCPBufferizePass());
|
||||
pm.addNestedPass<FuncOp>(createTCPBufferizePass());
|
||||
// Lower tensor-valued constants to refback.global.
|
||||
pm.addPass(createLowerConstantTensorsToMemrefPass());
|
||||
// refback::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument.
|
||||
// We need to resolve this to std.alloc which takes individual extents.
|
||||
pm.addPass(createLowerAllocMemRefOpsPass());
|
||||
pm.addNestedPass<FuncOp>(createLowerAllocMemRefOpsPass());
|
||||
// Lower shape ops to std.
|
||||
// TODO: This should in principle be moved before tensor->memref conversion.
|
||||
// But some of the tensor->memref lowerings above use shape.get_extent. For
|
||||
|
@ -241,8 +241,8 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
|
|||
pm.addPass(createConvertShapeToStandardPass());
|
||||
|
||||
// Run some upstream bufferization passes to finish bufferization.
|
||||
pm.addPass(createStdBufferizePass());
|
||||
pm.addPass(createSCFBufferizePass());
|
||||
pm.addNestedPass<FuncOp>(createStdBufferizePass());
|
||||
pm.addNestedPass<FuncOp>(createSCFBufferizePass());
|
||||
pm.addPass(createLinalgBufferizePass());
|
||||
pm.addPass(createFuncBufferizePass());
|
||||
|
||||
|
@ -252,8 +252,8 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
|
|||
// At this point, we have lots of loose stuff floating around from lowering,
|
||||
// so it's a good time to do some general cleanups.
|
||||
if (options.optimize) {
|
||||
pm.addPass(createCanonicalizerPass());
|
||||
pm.addPass(createCSEPass());
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
@ -264,12 +264,12 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
|
|||
|
||||
// Lower linalg ops to loops.
|
||||
// TODO: Do some linalg optimizations like tiling here.
|
||||
pm.addPass(createConvertLinalgToLoopsPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
|
||||
|
||||
// Run a some cleanups.
|
||||
if (options.optimize) {
|
||||
pm.addPass(createCanonicalizerPass());
|
||||
pm.addPass(createCSEPass());
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
@ -277,7 +277,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
|
|||
// --------------------------------------------------------------------------
|
||||
|
||||
// Convert scf to std control flow in preparation for going to LLVM.
|
||||
pm.addPass(createLowerToCFGPass());
|
||||
pm.addNestedPass<FuncOp>(createLowerToCFGPass());
|
||||
|
||||
// Convert functions signatures and other constructs that interface with the
|
||||
// runtime to the `refbackrt` dialect.
|
||||
|
@ -291,8 +291,8 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
|
|||
// Although LLVM will clean everything up eventually, for the sake of IR
|
||||
// clarity while still in MLIR, run some cleanups.
|
||||
if (options.optimize) {
|
||||
pm.addPass(createCanonicalizerPass());
|
||||
pm.addPass(createCSEPass());
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -305,6 +305,7 @@ void mlir::NPCOMP::createTCFRefBackendLoweringPipeline(
|
|||
// case of invalid broadcasts.
|
||||
//
|
||||
// TCP does not. So we need to reify the broadcasting and error checking.
|
||||
// These all run at the module level.
|
||||
pm.addPass(createConvertTCFToStdPass());
|
||||
pm.addPass(createConvertTCFToLinalgPass());
|
||||
pm.addPass(createConvertTCFToTCPPass());
|
||||
|
|
|
@ -9,7 +9,7 @@ add_custom_command(
|
|||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
# TODO: Make the runtime library work for windows.
|
||||
${CMAKE_BINARY_DIR}/lib/libNPCOMPCompilerRuntimeShlib.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/npcomp/compiler/_refjit_resources/libNPCOMPCompilerRuntimeShlib.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/npcomp/compiler/generic/backend/libNPCOMPCompilerRuntimeShlib.so
|
||||
)
|
||||
add_dependencies(NPCOMPPythonResources
|
||||
NPCOMPCompilerRuntimeShlib
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "npcomp/Python/MlirInit.h"
|
||||
#include "npcomp/Python/NpcompModule.h"
|
||||
#include "npcomp/Python/PybindUtils.h"
|
||||
#include "npcomp-c/Registration.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
|
||||
#ifdef NPCOMP_ENABLE_REFJIT
|
||||
|
@ -26,7 +27,7 @@ namespace mlir {
|
|||
namespace npcomp {
|
||||
namespace python {
|
||||
|
||||
void defineLLVMModule(pybind11::module m) {
|
||||
static void defineLLVMModule(pybind11::module m) {
|
||||
m.def("print_help_message", []() { llvm::cl::PrintHelpMessage(); });
|
||||
m.def("add_option",
|
||||
[](std::string name, llvm::Optional<std::string> value) {
|
||||
|
@ -56,6 +57,12 @@ void defineLLVMModule(pybind11::module m) {
|
|||
py::arg("name"));
|
||||
}
|
||||
|
||||
static void defineGlobals(py::module &m) {
|
||||
m.def("register_dialects", [](MlirContext context) {
|
||||
npcompRegisterAllDialects(context);
|
||||
});
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_npcomp, m) {
|
||||
// Guard the once init to happen once per process (vs module, which in
|
||||
// mondo builds can happen multiple times).
|
||||
|
@ -64,6 +71,8 @@ PYBIND11_MODULE(_npcomp, m) {
|
|||
|
||||
m.doc() = "Npcomp native python bindings";
|
||||
|
||||
// TODO: Retire the llvm, mlir, passes, and dialect modules in favor of
|
||||
// upstream Python bindings.
|
||||
auto llvm_m = m.def_submodule("llvm", "LLVM interop");
|
||||
defineLLVMModule(llvm_m);
|
||||
|
||||
|
@ -81,6 +90,9 @@ PYBIND11_MODULE(_npcomp, m) {
|
|||
auto npcomp_dialect = m.def_submodule("dialect", "NPComp custom dialects");
|
||||
defineNpcompDialect(npcomp_dialect);
|
||||
|
||||
// Globals.
|
||||
defineGlobals(m);
|
||||
|
||||
// Optional backend modules.
|
||||
auto backend_m = m.def_submodule("backend", "Backend support");
|
||||
(void)backend_m;
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import os
|
||||
|
||||
_refjit = None
|
||||
|
||||
BACKEND_PASSES = (
|
||||
"func(convert-scf-to-std)",
|
||||
"func(canonicalize)",
|
||||
"func(tcf-shape-refinement)",
|
||||
)
|
||||
|
||||
|
||||
def get_refjit():
|
||||
"""Dynamically resolves the refjit backend native module."""
|
||||
global _refjit
|
||||
if _refjit is not None:
|
||||
return _refjit
|
||||
try:
|
||||
from _npcomp.backend import refjit as imported_refjit
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The npcomp native module was not compiled with refjit support")
|
||||
_refjit = imported_refjit
|
||||
return _refjit
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
"""Returns whether the backend is enabled for the current build."""
|
||||
try:
|
||||
_get_refjit()
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def get_runtime_libs():
|
||||
# The _refjit_resources directory is at the npcomp.compiler level.
|
||||
resources_dir = os.path.join(os.path.dirname(__file__))
|
||||
return [os.path.join(resources_dir, "libNPCOMPCompilerRuntimeShlib.so")]
|
||||
|
||||
|
||||
class JitModuleInvoker:
|
||||
"""Wrapper around a native JitModule for calling functions."""
|
||||
|
||||
def __init__(self, jit_module):
|
||||
super().__init__()
|
||||
self._jit_module = jit_module
|
||||
|
||||
def __getattr__(self, function_name):
|
||||
return self.__getitem__(function_name)
|
||||
|
||||
def __getitem__(self, function_name):
|
||||
|
||||
def invoke(*args):
|
||||
results = self._jit_module.invoke(function_name, args)
|
||||
if len(results) == 1:
|
||||
# De-tuple.
|
||||
return results[0]
|
||||
else:
|
||||
return tuple(results)
|
||||
|
||||
invoke.__isnpcomp__ = True
|
||||
return invoke
|
|
@ -4,7 +4,11 @@
|
|||
|
||||
import os
|
||||
|
||||
from _npcomp import mlir
|
||||
from mlir.ir import *
|
||||
from mlir.passmanager import *
|
||||
from _npcomp import register_dialects
|
||||
from _npcomp import mlir as legacy_mlir
|
||||
from npcomp.compiler.generic.backend import refjit as refjit_backend
|
||||
from npcomp.compiler.utils import logging
|
||||
|
||||
__all__ = [
|
||||
|
@ -13,45 +17,16 @@ __all__ = [
|
|||
]
|
||||
|
||||
FRONTEND_PASSES = (
|
||||
"npcomp-cpa-type-inference",
|
||||
"func(npcomp-cpa-type-inference)",
|
||||
"numpy-public-functions-to-tensor",
|
||||
"convert-numpy-to-tcf",
|
||||
"convert-scf-to-std",
|
||||
"canonicalize",
|
||||
"tcf-shape-refinement",
|
||||
"func(convert-numpy-to-tcf)",
|
||||
"func(convert-scf-to-std)",
|
||||
"func(canonicalize)",
|
||||
"func(tcf-shape-refinement)",
|
||||
)
|
||||
|
||||
_refjit = None
|
||||
|
||||
|
||||
def _get_refjit():
|
||||
"""Dynamically resolves the refjit backend native module."""
|
||||
global _refjit
|
||||
if _refjit is not None:
|
||||
return _refjit
|
||||
try:
|
||||
from _npcomp.backend import refjit as imported_refjit
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The npcomp native module was not compiled with refjit support")
|
||||
_refjit = imported_refjit
|
||||
return _refjit
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
"""Returns whether the backend is enabled for the current build."""
|
||||
try:
|
||||
_get_refjit()
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def get_runtime_libs():
|
||||
# The _refjit_resources directory is at the npcomp.compiler level.
|
||||
resources_dir = os.path.join(os.path.dirname(__file__), "..", "..",
|
||||
"_refjit_resources")
|
||||
return [os.path.join(resources_dir, "libNPCOMPCompilerRuntimeShlib.so")]
|
||||
# Re-export.
|
||||
is_enabled = refjit_backend.is_enabled
|
||||
|
||||
|
||||
class CompilerBackend:
|
||||
|
@ -59,37 +34,41 @@ class CompilerBackend:
|
|||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._refjit = _get_refjit()
|
||||
self._refjit = refjit_backend.get_refjit()
|
||||
self._debug = logging.debug_enabled()
|
||||
|
||||
def compile(self, imported_ir_module: mlir.ir.ModuleOp):
|
||||
def compile(self, legacy_imported_ir_module: legacy_mlir.ir.ModuleOp):
|
||||
"""Compiles an imported module.
|
||||
|
||||
Args:
|
||||
imported_ir_module: The MLIR module as imported from the ImportFrontend.
|
||||
legacy_imported_ir_module: The MLIR module as imported from the
|
||||
ImportFrontend.
|
||||
Returns:
|
||||
An opaque, backend specific module object that can be passed to load.
|
||||
The object may actually be something more specific to the backend (i.e.
|
||||
for IREE, it is a serialized VM flatbuffer) but the contract is that
|
||||
it is operated on by methods on this class.
|
||||
"""
|
||||
# TODO: Once transitioned to new Python API, don't reparse the module.
|
||||
with Context() as context:
|
||||
register_dialects(context)
|
||||
imported_module = Module.parse(legacy_imported_ir_module.to_asm())
|
||||
# Frontend.
|
||||
pm = mlir.passes.PassManager(imported_ir_module.context)
|
||||
pm.addPassPipelines(*FRONTEND_PASSES)
|
||||
pm.run(imported_ir_module)
|
||||
pm = PassManager.parse(",".join(FRONTEND_PASSES))
|
||||
pm.run(imported_module)
|
||||
if self._debug:
|
||||
logging.debug("Frontend IR:{}", imported_ir_module.to_asm())
|
||||
logging.debug("Frontend IR:{}", imported_module)
|
||||
|
||||
# Backend.
|
||||
# Note that this is a separate pass manager purely to aid in debugging.
|
||||
pm = mlir.passes.PassManager(imported_ir_module.context)
|
||||
pm = PassManager()
|
||||
self._refjit.build_backend_compilation_pipeline(pm)
|
||||
pm.run(imported_ir_module)
|
||||
pm.run(imported_module)
|
||||
if self._debug:
|
||||
logging.debug("Backend IR:{}", imported_ir_module.to_asm())
|
||||
logging.debug("Backend IR:{}", imported_module)
|
||||
|
||||
jit_module = self._refjit.JITModule.from_compiled_module(
|
||||
imported_ir_module, get_runtime_libs())
|
||||
imported_module, refjit_backend.get_runtime_libs())
|
||||
return jit_module
|
||||
|
||||
def load(self, jit_module):
|
||||
|
@ -97,25 +76,4 @@ class CompilerBackend:
|
|||
|
||||
Since this is a JIT instead of an AOT compiler,
|
||||
"""
|
||||
return JitModuleInvoker(jit_module)
|
||||
|
||||
|
||||
class JitModuleInvoker:
|
||||
"""Wrapper around a native JitModule for calling functions."""
|
||||
|
||||
def __init__(self, jit_module):
|
||||
super().__init__()
|
||||
self._jit_module = jit_module
|
||||
|
||||
def __getitem__(self, function_name):
|
||||
|
||||
def invoke(*args):
|
||||
results = self._jit_module.invoke(function_name, args)
|
||||
if len(results) == 1:
|
||||
# De-tuple.
|
||||
return results[0]
|
||||
else:
|
||||
return tuple(results)
|
||||
|
||||
invoke.__isnpcomp__ = True
|
||||
return invoke
|
||||
return refjit_backend.JitModuleInvoker(jit_module)
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import os
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.passmanager import *
|
||||
from npcomp.compiler.generic.backend import refjit as refjit_backend
|
||||
from npcomp.compiler.utils import logging
|
||||
|
||||
__all__ = [
|
||||
"is_enabled",
|
||||
"CompilerBackend",
|
||||
]
|
||||
|
||||
FRONTEND_PASSES = ("func(aten-recognize-kernels)", "func(convert-aten-to-tcf)",
|
||||
"numpy-public-functions-to-tensor", "canonicalize")
|
||||
|
||||
# Re-export.
|
||||
is_enabled = refjit_backend.is_enabled
|
||||
|
||||
|
||||
class CompilerBackend:
|
||||
"""Main entry-point for the backend."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._refjit = refjit_backend.get_refjit()
|
||||
self._debug = logging.debug_enabled()
|
||||
|
||||
def compile(self, imported_module: Module):
|
||||
"""Compiles an imported module.
|
||||
|
||||
Args:
|
||||
imported_module: The MLIR module consisting of funcs in the torch
|
||||
dialect.
|
||||
Returns:
|
||||
An opaque, backend specific module object that can be passed to load.
|
||||
The object may actually be something more specific to the backend (i.e.
|
||||
for IREE, it is a serialized VM flatbuffer) but the contract is that
|
||||
it is operated on by methods on this class.
|
||||
"""
|
||||
# TODO: Once transitioned to new Python API, don't reparse the module.
|
||||
with Context() as context:
|
||||
# Frontend.
|
||||
pm = PassManager.parse(",".join(FRONTEND_PASSES))
|
||||
pm.run(imported_module)
|
||||
if self._debug:
|
||||
logging.debug("Frontend IR:{}", imported_module)
|
||||
|
||||
# Backend.
|
||||
# Note that this is a separate pass manager purely to aid in debugging.
|
||||
pm = PassManager()
|
||||
self._refjit.build_backend_compilation_pipeline(pm)
|
||||
pm.run(imported_module)
|
||||
if self._debug:
|
||||
logging.debug("Backend IR:{}", imported_module)
|
||||
|
||||
jit_module = self._refjit.JITModule.from_compiled_module(
|
||||
imported_module, refjit_backend.get_runtime_libs())
|
||||
return jit_module
|
||||
|
||||
def load(self, jit_module):
|
||||
"""Loads a compiled artifact into the runtime.
|
||||
|
||||
Since this is a JIT instead of an AOT compiler,
|
||||
"""
|
||||
return refjit_backend.JitModuleInvoker(jit_module)
|
Loading…
Reference in New Issue