Wrap the IREE compiler flow in a one stop API.

pull/1/head
Stella Laurenzo 2020-06-19 17:17:22 -07:00
parent dc87e09b5a
commit b811db4b76
4 changed files with 147 additions and 36 deletions

View File

@ -1,52 +1,29 @@
# RUN: %PYTHON %s
from npcomp.compiler.backend import iree
from npcomp.compiler.frontend import *
from npcomp.compiler import logging
from npcomp.compiler.target import *
# TODO: This should all exist in a high level API somewhere.
from _npcomp import mlir
from _npcomp.backend import iree as ireec
from pyiree import rt
logging.enable()
def compile_function(f):
fe = ImportFrontend(target_factory=GenericTarget32)
ir_f = fe.import_global_function(f)
input_m = fe.ir_module
# For easier debugging, split into to pass manager invocations.
pm = mlir.passes.PassManager(input_m.context)
# TOOD: Have an API for this
pm.addPassPipelines(
"basicpy-type-inference", "convert-basicpy-to-std", "canonicalize")
pm.run(input_m)
print("INPUT MODULE:")
print(input_m.to_asm())
# Main IREE compiler.
pm = mlir.passes.PassManager(input_m.context)
ireec.build_flow_transform_pass_pipeline(pm)
ireec.build_hal_transform_pass_pipeline(pm)
ireec.build_vm_transform_pass_pipeline(pm)
pm.run(input_m)
print("VM MODULE:")
print(input_m.to_asm())
# Translate to VM bytecode flatbuffer.
vm_blob = ireec.translate_to_vm_bytecode(input_m)
print("VM BLOB: len =", len(vm_blob))
return vm_blob
fe.import_global_function(f)
compiler = iree.CompilerBackend()
vm_blob = compiler.compile(fe.ir_module)
loaded_m = compiler.load(vm_blob)
return loaded_m[f.__name__]
@compile_function
def int_add(a: int, b: int):
return a + b
vm_blob = compile_function(int_add)
m = rt.VmModule.from_flatbuffer(vm_blob)
config = rt.Config("vmla")
ctx = rt.SystemContext(config=config)
ctx.add_module(m)
f = ctx.modules.module.int_add
print(f(5, 6))
result = int_add(5, 6)
assert result == 11

View File

@ -0,0 +1,125 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from _npcomp import mlir
from npcomp.compiler import logging
__all__ = [
"is_enabled",
"CompilerBackend",
]
_ireec = None
_ireert = None
_cached_config = None
def _get_iree():
"""Dynamically resolves the iree backend module."""
global _ireec
global _ireert
if _ireec is not None:
return _ireec, _ireert
try:
from _npcomp.backend import iree as imported_ireec
except ImportError:
raise ImportError(
"The npcomp native module was not compiled with IREE support")
try:
from pyiree import rt as imported_rt
except ImportError:
raise ImportError("IREE runtime library not found (pyiree.rt)")
_ireec = imported_ireec
_ireert = imported_rt
return _ireec, _ireert
def is_enabled() -> bool:
"""Returns whether the backend is enabled for the current build."""
try:
_get_iree()
return True
except ImportError:
return False
class CompilerBackend:
"""Main entry-point for the backend."""
def __init__(self):
super().__init__()
self._ireec, self._ireert = _get_iree()
self._debug = logging.debug_enabled()
def compile(self, imported_ir_module: mlir.ir.ModuleOp):
"""Compiles an imported module.
Args:
imported_ir_module: The MLIR module as imported from the ImportFrontend.
Returns:
An opaque, backend specific module object that can be passed to load.
The object may actually be something more specific to the backend (i.e.
for IREE, it is a serialized VM flatbuffer) but the contract is that
it is operated on by methods on this class.
"""
ireec = self._ireec
# For easier debugging, split into to pass manager invocations.
# Frontend.
pm = mlir.passes.PassManager(imported_ir_module.context)
self.add_frontend_passes(pm)
pm.run(imported_ir_module)
if self._debug:
logging.debug("Frontend IR:{}", imported_ir_module.to_asm())
# Backend.
pm = mlir.passes.PassManager(imported_ir_module.context)
self.add_backend_passes(pm)
pm.run(imported_ir_module)
if self._debug:
logging.debug("Backend IR:{}", imported_ir_module.to_asm())
# Translation/serialization.
vm_blob = ireec.translate_to_vm_bytecode(imported_ir_module)
if self._debug:
logging.debug("Compiled VM BLOB size={}", len(vm_blob))
return vm_blob
def load(self, vm_blob):
"""Loads a compiled artifact into the runtime.
This is meant as a simple mechanism for testing and is not optimized or
highly parameterized. It loads a compiled result into a new runtime
instance and returns an object that exposes a python function for each
public function compiled in the imported_ir_module that was compiled.
"""
ireert = self._ireert
m = ireert.VmModule.from_flatbuffer(vm_blob)
global _cached_config
if not _cached_config:
# TODO: Need to make the configuration more flexible.
_cached_config = ireert.Config(driver_name="vmla")
ctx = ireert.SystemContext(config=_cached_config)
ctx.add_module(m)
# TODO: The implicit tying of the 'module' name has got to go.
return ctx.modules.module
def add_frontend_passes(self, pm: mlir.passes.PassManager):
"""Adds passes needed for legalizing from an imported form.
While an arbitrary distinction, the passes added here are more about
legalizing the basicpy and numpy dialects in preparation for performing
backend compilation. They are separated to aid debugging.
"""
# TOOD: Have an API for this
pm.addPassPipelines("basicpy-type-inference", "convert-basicpy-to-std",
"canonicalize")
def add_backend_passes(self, pm: mlir.passes.PassManager):
"""Adds passes for full backend compilation.
These passes are added after the frontend passes.
"""
ireec = self._ireec
ireec.build_flow_transform_pass_pipeline(pm)
ireec.build_hal_transform_pass_pipeline(pm)
ireec.build_vm_transform_pass_pipeline(pm)

View File

@ -6,12 +6,21 @@ import os
import string
import sys
__all__ = ["debug"]
__all__ = ["debug", "debug_enabled", "enable"]
_ENABLED = "NPCOMP_DEBUG" in os.environ
_formatter = string.Formatter()
def enable():
global _ENABLED
_ENABLED = True
def debug_enabled():
return _ENABLED
def debug(format_string, *args, **kwargs):
if not _ENABLED:
return