Repurpose numpy-compiler compiler/runtime flow for PyTorch.

* A bit gross because I took the chance to upgrade all of the backend bits to the new MLIR Python bindings and we still co-mingle the old and new for now.
* Since the Python created PassManagers are configured for explicit nesting, I had to upgrade some of the pass pipelines to be explicit.
* The demo in mul_maximum_e2e.py now compiles, runs through PyTorch and through the JIT, prints and asserts the same results.
* I am not claiming that this is the prettiest API in this patch: consider that this is just directly using low-level APIs and there should be an intervening high level API.
pull/113/head
Stella Laurenzo 2020-11-10 21:38:13 -08:00
parent d1488c8572
commit b4c7ae1e0c
15 changed files with 277 additions and 113 deletions

@ -1 +1 @@
Subproject commit 53a0d45db6d0f33dfbb724c99ce2560ae25473c2
Subproject commit 5fef6ce0cce05a9fb05f47c9d62f3724377ea076

View File

@ -3,9 +3,16 @@
# See frontends/pytorch/LICENSE for license information.
import sys
import numpy as np
import torch
import torch_mlir
import npcomp
from npcomp.compiler.pytorch.backend.refjit import *
from npcomp.compiler.utils import logging
logging.enable()
lhs = torch.ones((4, 6, 1))
rhs = torch.ones((1, 1, 3)) * 0.6
bias = torch.ones((1, 1, 3)) * 0.2
@ -17,8 +24,13 @@ with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f:
result = result + bias
f.returns([result])
print(f"Result(f{result.size()}) = {result}", file=sys.stderr)
# TODO: Currently need to route through:
# npcomp-opt -aten-recognize-kernels -convert-aten-to-tcf \
# -numpy-public-functions-to-tensor -canonicalize
mb.module.operation.print()
backend = CompilerBackend()
jit_module = backend.load(backend.compile(mb.module))
jit_result = jit_module.mul_maximum(lhs.numpy(), rhs.numpy(), threshold.numpy(),
bias.numpy())
print(f"PyTorch Result = {result.numpy()}", file=sys.stderr)
print(f"JIT Result = {jit_result}", file=sys.stderr)
np.testing.assert_allclose(result.numpy(), jit_result)

View File

@ -15,7 +15,7 @@ namespace npcomp {
namespace python {
/// Defines an "refjit" module with backend support definitions.
void defineBackendRefJitModule(py::module m);
void defineBackendRefJitModule(py::module &m);
} // namespace python
} // namespace npcomp

View File

@ -15,6 +15,9 @@
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
#include "llvm/ADT/Optional.h"
namespace py = pybind11;
@ -25,6 +28,45 @@ namespace detail {
template <typename T>
struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
/// Casts object -> MlirContext.
template <> struct type_caster<MlirContext> {
PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"));
bool load(handle src, bool) {
auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
value = mlirPythonCapsuleToContext(capsule.ptr());
if (mlirContextIsNull(value)) {
return false;
}
return true;
}
};
/// Casts object -> MlirModule.
template <> struct type_caster<MlirModule> {
PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"));
bool load(handle src, bool) {
auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
value = mlirPythonCapsuleToModule(capsule.ptr());
if (mlirModuleIsNull(value)) {
return false;
}
return true;
}
};
/// Casts object -> MlirPassManager.
template <> struct type_caster<MlirPassManager> {
PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"));
bool load(handle src, bool) {
auto capsule = src.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
value = mlirPythonCapsuleToPassManager(capsule.ptr());
if (mlirPassManagerIsNull(value)) {
return false;
}
return true;
}
};
} // namespace detail
} // namespace pybind11

View File

@ -10,6 +10,8 @@
#include "pybind11/numpy.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Pass.h"
#include "npcomp/Python/MlirIr.h"
#include "npcomp/Python/MlirPass.h"
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
@ -84,20 +86,22 @@ py::array wrapTensorAsArray(Ref<Tensor> tensor) {
/*base=*/std::move(pyTensor));
}
void npcomp::python::defineBackendRefJitModule(py::module m) {
m.def("build_backend_compilation_pipeline", [](PyPassManager &pm) {
JITModule::buildBackendCompilationPipeline(pm.passManager);
void npcomp::python::defineBackendRefJitModule(py::module &m) {
m.def("build_backend_compilation_pipeline", [](MlirPassManager capiPm) {
mlir::PassManager *pm = unwrap(capiPm);
JITModule::buildBackendCompilationPipeline(*pm);
});
py::class_<JITModule>(m, "JITModule")
.def_static(
"from_compiled_module",
[](PyModuleOp module, std::vector<std::string> pySharedLibs)
[](MlirModule capiModule, std::vector<std::string> pySharedLibs)
-> std::unique_ptr<JITModule> {
SmallVector<StringRef, 4> sharedLibs(pySharedLibs.begin(),
pySharedLibs.end());
auto jitModule = checkError(
JITModule::fromCompiledModule(module.moduleOp, sharedLibs),
"error creating JITModule: ");
auto module = unwrap(capiModule);
auto jitModule =
checkError(JITModule::fromCompiledModule(module, sharedLibs),
"error creating JITModule: ");
return jitModule;
},
py::arg("module"), py::arg("shared_libs"))

View File

@ -69,8 +69,8 @@ matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) {
loc, resultType, rhs, broadcastedShape);
Value binaryOpResult;
if (isa<tcf::AddOp>(op)) {
binaryOpResult = rewriter.create<AddFOp>(
loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
binaryOpResult = rewriter.create<AddFOp>(loc, result.getType(),
lhsBroadcasted, rhsBroadcasted);
} else if (isa<tcf::MaxOp>(op)) {
// XXX: remove TCP dep
// XXX: remove TCP ops from TCP
@ -79,8 +79,8 @@ matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) {
binaryOpResult =
rewriter.create<SelectOp>(loc, pred, lhsBroadcasted, rhsBroadcasted);
} else if (isa<tcf::MulOp>(op)) {
binaryOpResult = rewriter.create<MulFOp>(
loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
binaryOpResult = rewriter.create<MulFOp>(loc, result.getType(),
lhsBroadcasted, rhsBroadcasted);
} else {
op->dump();
llvm::report_fatal_error(

View File

@ -194,17 +194,17 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
//
// Also, converting to linalg herevopens up a lot of optimization
// opportunities.
pm.addPass(createConvertElementwiseToLinalgPass());
pm.addNestedPass<FuncOp>(createConvertElementwiseToLinalgPass());
if (options.optimize) {
pm.addPass(createLinalgFusionOfTensorOpsPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
pm.addNestedPass<FuncOp>(createLinalgFusionOfTensorOpsPass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(createCSEPass());
}
// Lower shape constraints before we enter tensor->memref conversion.
// That is, we expand shape.cstr_* ops to eager error handling code.
pm.addPass(createConvertShapeConstraintsPass());
pm.addNestedPass<FuncOp>(createConvertShapeConstraintsPass());
// Run shape canonicalizations. In particular, this erases shape.assuming,
// now that we have converted shape constraints.
// TODO: This is kind of ugly. Either we use pass options or a constructor
@ -227,12 +227,12 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
// (tensor_load / tensor_to_memref) in the IR.
// Bufferize the TCP dialect.
pm.addPass(createTCPBufferizePass());
pm.addNestedPass<FuncOp>(createTCPBufferizePass());
// Lower tensor-valued constants to refback.global.
pm.addPass(createLowerConstantTensorsToMemrefPass());
// refback::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument.
// We need to resolve this to std.alloc which takes individual extents.
pm.addPass(createLowerAllocMemRefOpsPass());
pm.addNestedPass<FuncOp>(createLowerAllocMemRefOpsPass());
// Lower shape ops to std.
// TODO: This should in principle be moved before tensor->memref conversion.
// But some of the tensor->memref lowerings above use shape.get_extent. For
@ -241,8 +241,8 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
pm.addPass(createConvertShapeToStandardPass());
// Run some upstream bufferization passes to finish bufferization.
pm.addPass(createStdBufferizePass());
pm.addPass(createSCFBufferizePass());
pm.addNestedPass<FuncOp>(createStdBufferizePass());
pm.addNestedPass<FuncOp>(createSCFBufferizePass());
pm.addPass(createLinalgBufferizePass());
pm.addPass(createFuncBufferizePass());
@ -252,8 +252,8 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
// At this point, we have lots of loose stuff floating around from lowering,
// so it's a good time to do some general cleanups.
if (options.optimize) {
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(createCSEPass());
}
// --------------------------------------------------------------------------
@ -264,12 +264,12 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
// Lower linalg ops to loops.
// TODO: Do some linalg optimizations like tiling here.
pm.addPass(createConvertLinalgToLoopsPass());
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
// Run a some cleanups.
if (options.optimize) {
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(createCSEPass());
}
// --------------------------------------------------------------------------
@ -277,7 +277,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
// --------------------------------------------------------------------------
// Convert scf to std control flow in preparation for going to LLVM.
pm.addPass(createLowerToCFGPass());
pm.addNestedPass<FuncOp>(createLowerToCFGPass());
// Convert functions signatures and other constructs that interface with the
// runtime to the `refbackrt` dialect.
@ -291,8 +291,8 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
// Although LLVM will clean everything up eventually, for the sake of IR
// clarity while still in MLIR, run some cleanups.
if (options.optimize) {
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(createCSEPass());
}
}
@ -305,6 +305,7 @@ void mlir::NPCOMP::createTCFRefBackendLoweringPipeline(
// case of invalid broadcasts.
//
// TCP does not. So we need to reify the broadcasting and error checking.
// These all run at the module level.
pm.addPass(createConvertTCFToStdPass());
pm.addPass(createConvertTCFToLinalgPass());
pm.addPass(createConvertTCFToTCPPass());

View File

@ -9,7 +9,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy
# TODO: Make the runtime library work for windows.
${CMAKE_BINARY_DIR}/lib/libNPCOMPCompilerRuntimeShlib.so
${CMAKE_CURRENT_BINARY_DIR}/npcomp/compiler/_refjit_resources/libNPCOMPCompilerRuntimeShlib.so
${CMAKE_CURRENT_BINARY_DIR}/npcomp/compiler/generic/backend/libNPCOMPCompilerRuntimeShlib.so
)
add_dependencies(NPCOMPPythonResources
NPCOMPCompilerRuntimeShlib

View File

@ -12,6 +12,7 @@
#include "npcomp/Python/MlirInit.h"
#include "npcomp/Python/NpcompModule.h"
#include "npcomp/Python/PybindUtils.h"
#include "npcomp-c/Registration.h"
#include "llvm/Support/CommandLine.h"
#ifdef NPCOMP_ENABLE_REFJIT
@ -26,7 +27,7 @@ namespace mlir {
namespace npcomp {
namespace python {
void defineLLVMModule(pybind11::module m) {
static void defineLLVMModule(pybind11::module m) {
m.def("print_help_message", []() { llvm::cl::PrintHelpMessage(); });
m.def("add_option",
[](std::string name, llvm::Optional<std::string> value) {
@ -56,6 +57,12 @@ void defineLLVMModule(pybind11::module m) {
py::arg("name"));
}
static void defineGlobals(py::module &m) {
m.def("register_dialects", [](MlirContext context) {
npcompRegisterAllDialects(context);
});
}
PYBIND11_MODULE(_npcomp, m) {
// Guard the once init to happen once per process (vs module, which in
// mondo builds can happen multiple times).
@ -64,6 +71,8 @@ PYBIND11_MODULE(_npcomp, m) {
m.doc() = "Npcomp native python bindings";
// TODO: Retire the llvm, mlir, passes, and dialect modules in favor of
// upstream Python bindings.
auto llvm_m = m.def_submodule("llvm", "LLVM interop");
defineLLVMModule(llvm_m);
@ -81,6 +90,9 @@ PYBIND11_MODULE(_npcomp, m) {
auto npcomp_dialect = m.def_submodule("dialect", "NPComp custom dialects");
defineNpcompDialect(npcomp_dialect);
// Globals.
defineGlobals(m);
// Optional backend modules.
auto backend_m = m.def_submodule("backend", "Backend support");
(void)backend_m;

View File

@ -0,0 +1,66 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import os
_refjit = None
BACKEND_PASSES = (
"func(convert-scf-to-std)",
"func(canonicalize)",
"func(tcf-shape-refinement)",
)
def get_refjit():
"""Dynamically resolves the refjit backend native module."""
global _refjit
if _refjit is not None:
return _refjit
try:
from _npcomp.backend import refjit as imported_refjit
except ImportError:
raise ImportError(
"The npcomp native module was not compiled with refjit support")
_refjit = imported_refjit
return _refjit
def is_enabled() -> bool:
"""Returns whether the backend is enabled for the current build."""
try:
_get_refjit()
return True
except ImportError:
return False
def get_runtime_libs():
# The _refjit_resources directory is at the npcomp.compiler level.
resources_dir = os.path.join(os.path.dirname(__file__))
return [os.path.join(resources_dir, "libNPCOMPCompilerRuntimeShlib.so")]
class JitModuleInvoker:
"""Wrapper around a native JitModule for calling functions."""
def __init__(self, jit_module):
super().__init__()
self._jit_module = jit_module
def __getattr__(self, function_name):
return self.__getitem__(function_name)
def __getitem__(self, function_name):
def invoke(*args):
results = self._jit_module.invoke(function_name, args)
if len(results) == 1:
# De-tuple.
return results[0]
else:
return tuple(results)
invoke.__isnpcomp__ = True
return invoke

View File

@ -4,7 +4,11 @@
import os
from _npcomp import mlir
from mlir.ir import *
from mlir.passmanager import *
from _npcomp import register_dialects
from _npcomp import mlir as legacy_mlir
from npcomp.compiler.generic.backend import refjit as refjit_backend
from npcomp.compiler.utils import logging
__all__ = [
@ -13,45 +17,16 @@ __all__ = [
]
FRONTEND_PASSES = (
"npcomp-cpa-type-inference",
"func(npcomp-cpa-type-inference)",
"numpy-public-functions-to-tensor",
"convert-numpy-to-tcf",
"convert-scf-to-std",
"canonicalize",
"tcf-shape-refinement",
"func(convert-numpy-to-tcf)",
"func(convert-scf-to-std)",
"func(canonicalize)",
"func(tcf-shape-refinement)",
)
_refjit = None
def _get_refjit():
"""Dynamically resolves the refjit backend native module."""
global _refjit
if _refjit is not None:
return _refjit
try:
from _npcomp.backend import refjit as imported_refjit
except ImportError:
raise ImportError(
"The npcomp native module was not compiled with refjit support")
_refjit = imported_refjit
return _refjit
def is_enabled() -> bool:
"""Returns whether the backend is enabled for the current build."""
try:
_get_refjit()
return True
except ImportError:
return False
def get_runtime_libs():
# The _refjit_resources directory is at the npcomp.compiler level.
resources_dir = os.path.join(os.path.dirname(__file__), "..", "..",
"_refjit_resources")
return [os.path.join(resources_dir, "libNPCOMPCompilerRuntimeShlib.so")]
# Re-export.
is_enabled = refjit_backend.is_enabled
class CompilerBackend:
@ -59,37 +34,41 @@ class CompilerBackend:
def __init__(self):
super().__init__()
self._refjit = _get_refjit()
self._refjit = refjit_backend.get_refjit()
self._debug = logging.debug_enabled()
def compile(self, imported_ir_module: mlir.ir.ModuleOp):
def compile(self, legacy_imported_ir_module: legacy_mlir.ir.ModuleOp):
"""Compiles an imported module.
Args:
imported_ir_module: The MLIR module as imported from the ImportFrontend.
legacy_imported_ir_module: The MLIR module as imported from the
ImportFrontend.
Returns:
An opaque, backend specific module object that can be passed to load.
The object may actually be something more specific to the backend (i.e.
for IREE, it is a serialized VM flatbuffer) but the contract is that
it is operated on by methods on this class.
"""
# Frontend.
pm = mlir.passes.PassManager(imported_ir_module.context)
pm.addPassPipelines(*FRONTEND_PASSES)
pm.run(imported_ir_module)
if self._debug:
logging.debug("Frontend IR:{}", imported_ir_module.to_asm())
# TODO: Once transitioned to new Python API, don't reparse the module.
with Context() as context:
register_dialects(context)
imported_module = Module.parse(legacy_imported_ir_module.to_asm())
# Frontend.
pm = PassManager.parse(",".join(FRONTEND_PASSES))
pm.run(imported_module)
if self._debug:
logging.debug("Frontend IR:{}", imported_module)
# Backend.
# Note that this is a separate pass manager purely to aid in debugging.
pm = mlir.passes.PassManager(imported_ir_module.context)
self._refjit.build_backend_compilation_pipeline(pm)
pm.run(imported_ir_module)
if self._debug:
logging.debug("Backend IR:{}", imported_ir_module.to_asm())
# Backend.
# Note that this is a separate pass manager purely to aid in debugging.
pm = PassManager()
self._refjit.build_backend_compilation_pipeline(pm)
pm.run(imported_module)
if self._debug:
logging.debug("Backend IR:{}", imported_module)
jit_module = self._refjit.JITModule.from_compiled_module(
imported_ir_module, get_runtime_libs())
imported_module, refjit_backend.get_runtime_libs())
return jit_module
def load(self, jit_module):
@ -97,25 +76,4 @@ class CompilerBackend:
Since this is a JIT instead of an AOT compiler,
"""
return JitModuleInvoker(jit_module)
class JitModuleInvoker:
"""Wrapper around a native JitModule for calling functions."""
def __init__(self, jit_module):
super().__init__()
self._jit_module = jit_module
def __getitem__(self, function_name):
def invoke(*args):
results = self._jit_module.invoke(function_name, args)
if len(results) == 1:
# De-tuple.
return results[0]
else:
return tuple(results)
invoke.__isnpcomp__ = True
return invoke
return refjit_backend.JitModuleInvoker(jit_module)

View File

@ -0,0 +1,69 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import os
from mlir.ir import *
from mlir.passmanager import *
from npcomp.compiler.generic.backend import refjit as refjit_backend
from npcomp.compiler.utils import logging
__all__ = [
"is_enabled",
"CompilerBackend",
]
FRONTEND_PASSES = ("func(aten-recognize-kernels)", "func(convert-aten-to-tcf)",
"numpy-public-functions-to-tensor", "canonicalize")
# Re-export.
is_enabled = refjit_backend.is_enabled
class CompilerBackend:
"""Main entry-point for the backend."""
def __init__(self):
super().__init__()
self._refjit = refjit_backend.get_refjit()
self._debug = logging.debug_enabled()
def compile(self, imported_module: Module):
"""Compiles an imported module.
Args:
imported_module: The MLIR module consisting of funcs in the torch
dialect.
Returns:
An opaque, backend specific module object that can be passed to load.
The object may actually be something more specific to the backend (i.e.
for IREE, it is a serialized VM flatbuffer) but the contract is that
it is operated on by methods on this class.
"""
# TODO: Once transitioned to new Python API, don't reparse the module.
with Context() as context:
# Frontend.
pm = PassManager.parse(",".join(FRONTEND_PASSES))
pm.run(imported_module)
if self._debug:
logging.debug("Frontend IR:{}", imported_module)
# Backend.
# Note that this is a separate pass manager purely to aid in debugging.
pm = PassManager()
self._refjit.build_backend_compilation_pipeline(pm)
pm.run(imported_module)
if self._debug:
logging.debug("Backend IR:{}", imported_module)
jit_module = self._refjit.JITModule.from_compiled_module(
imported_module, refjit_backend.get_runtime_libs())
return jit_module
def load(self, jit_module):
"""Loads a compiled artifact into the runtime.
Since this is a JIT instead of an AOT compiler,
"""
return refjit_backend.JitModuleInvoker(jit_module)