mirror of https://github.com/llvm/torch-mlir
Wrap the IREE compiler flow in a one stop API.
parent
dc87e09b5a
commit
b811db4b76
|
@ -1,52 +1,29 @@
|
|||
# RUN: %PYTHON %s
|
||||
|
||||
from npcomp.compiler.backend import iree
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler import logging
|
||||
from npcomp.compiler.target import *
|
||||
|
||||
# TODO: This should all exist in a high level API somewhere.
|
||||
from _npcomp import mlir
|
||||
from _npcomp.backend import iree as ireec
|
||||
|
||||
from pyiree import rt
|
||||
|
||||
logging.enable()
|
||||
|
||||
|
||||
def compile_function(f):
|
||||
fe = ImportFrontend(target_factory=GenericTarget32)
|
||||
ir_f = fe.import_global_function(f)
|
||||
|
||||
input_m = fe.ir_module
|
||||
# For easier debugging, split into to pass manager invocations.
|
||||
pm = mlir.passes.PassManager(input_m.context)
|
||||
# TOOD: Have an API for this
|
||||
pm.addPassPipelines(
|
||||
"basicpy-type-inference", "convert-basicpy-to-std", "canonicalize")
|
||||
pm.run(input_m)
|
||||
print("INPUT MODULE:")
|
||||
print(input_m.to_asm())
|
||||
|
||||
# Main IREE compiler.
|
||||
pm = mlir.passes.PassManager(input_m.context)
|
||||
ireec.build_flow_transform_pass_pipeline(pm)
|
||||
ireec.build_hal_transform_pass_pipeline(pm)
|
||||
ireec.build_vm_transform_pass_pipeline(pm)
|
||||
pm.run(input_m)
|
||||
print("VM MODULE:")
|
||||
print(input_m.to_asm())
|
||||
|
||||
# Translate to VM bytecode flatbuffer.
|
||||
vm_blob = ireec.translate_to_vm_bytecode(input_m)
|
||||
print("VM BLOB: len =", len(vm_blob))
|
||||
return vm_blob
|
||||
fe.import_global_function(f)
|
||||
compiler = iree.CompilerBackend()
|
||||
vm_blob = compiler.compile(fe.ir_module)
|
||||
loaded_m = compiler.load(vm_blob)
|
||||
return loaded_m[f.__name__]
|
||||
|
||||
|
||||
@compile_function
|
||||
def int_add(a: int, b: int):
|
||||
return a + b
|
||||
|
||||
vm_blob = compile_function(int_add)
|
||||
m = rt.VmModule.from_flatbuffer(vm_blob)
|
||||
config = rt.Config("vmla")
|
||||
ctx = rt.SystemContext(config=config)
|
||||
ctx.add_module(m)
|
||||
|
||||
f = ctx.modules.module.int_add
|
||||
print(f(5, 6))
|
||||
result = int_add(5, 6)
|
||||
assert result == 11
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from _npcomp import mlir
|
||||
from npcomp.compiler import logging
|
||||
|
||||
__all__ = [
|
||||
"is_enabled",
|
||||
"CompilerBackend",
|
||||
]
|
||||
|
||||
_ireec = None
|
||||
_ireert = None
|
||||
_cached_config = None
|
||||
|
||||
|
||||
def _get_iree():
|
||||
"""Dynamically resolves the iree backend module."""
|
||||
global _ireec
|
||||
global _ireert
|
||||
if _ireec is not None:
|
||||
return _ireec, _ireert
|
||||
try:
|
||||
from _npcomp.backend import iree as imported_ireec
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The npcomp native module was not compiled with IREE support")
|
||||
try:
|
||||
from pyiree import rt as imported_rt
|
||||
except ImportError:
|
||||
raise ImportError("IREE runtime library not found (pyiree.rt)")
|
||||
|
||||
_ireec = imported_ireec
|
||||
_ireert = imported_rt
|
||||
return _ireec, _ireert
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
"""Returns whether the backend is enabled for the current build."""
|
||||
try:
|
||||
_get_iree()
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
class CompilerBackend:
|
||||
"""Main entry-point for the backend."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._ireec, self._ireert = _get_iree()
|
||||
self._debug = logging.debug_enabled()
|
||||
|
||||
def compile(self, imported_ir_module: mlir.ir.ModuleOp):
|
||||
"""Compiles an imported module.
|
||||
|
||||
Args:
|
||||
imported_ir_module: The MLIR module as imported from the ImportFrontend.
|
||||
Returns:
|
||||
An opaque, backend specific module object that can be passed to load.
|
||||
The object may actually be something more specific to the backend (i.e.
|
||||
for IREE, it is a serialized VM flatbuffer) but the contract is that
|
||||
it is operated on by methods on this class.
|
||||
"""
|
||||
ireec = self._ireec
|
||||
# For easier debugging, split into to pass manager invocations.
|
||||
# Frontend.
|
||||
pm = mlir.passes.PassManager(imported_ir_module.context)
|
||||
self.add_frontend_passes(pm)
|
||||
pm.run(imported_ir_module)
|
||||
if self._debug:
|
||||
logging.debug("Frontend IR:{}", imported_ir_module.to_asm())
|
||||
# Backend.
|
||||
pm = mlir.passes.PassManager(imported_ir_module.context)
|
||||
self.add_backend_passes(pm)
|
||||
pm.run(imported_ir_module)
|
||||
if self._debug:
|
||||
logging.debug("Backend IR:{}", imported_ir_module.to_asm())
|
||||
# Translation/serialization.
|
||||
vm_blob = ireec.translate_to_vm_bytecode(imported_ir_module)
|
||||
if self._debug:
|
||||
logging.debug("Compiled VM BLOB size={}", len(vm_blob))
|
||||
return vm_blob
|
||||
|
||||
def load(self, vm_blob):
|
||||
"""Loads a compiled artifact into the runtime.
|
||||
|
||||
This is meant as a simple mechanism for testing and is not optimized or
|
||||
highly parameterized. It loads a compiled result into a new runtime
|
||||
instance and returns an object that exposes a python function for each
|
||||
public function compiled in the imported_ir_module that was compiled.
|
||||
"""
|
||||
ireert = self._ireert
|
||||
m = ireert.VmModule.from_flatbuffer(vm_blob)
|
||||
global _cached_config
|
||||
if not _cached_config:
|
||||
# TODO: Need to make the configuration more flexible.
|
||||
_cached_config = ireert.Config(driver_name="vmla")
|
||||
ctx = ireert.SystemContext(config=_cached_config)
|
||||
ctx.add_module(m)
|
||||
# TODO: The implicit tying of the 'module' name has got to go.
|
||||
return ctx.modules.module
|
||||
|
||||
def add_frontend_passes(self, pm: mlir.passes.PassManager):
|
||||
"""Adds passes needed for legalizing from an imported form.
|
||||
|
||||
While an arbitrary distinction, the passes added here are more about
|
||||
legalizing the basicpy and numpy dialects in preparation for performing
|
||||
backend compilation. They are separated to aid debugging.
|
||||
"""
|
||||
# TOOD: Have an API for this
|
||||
pm.addPassPipelines("basicpy-type-inference", "convert-basicpy-to-std",
|
||||
"canonicalize")
|
||||
|
||||
def add_backend_passes(self, pm: mlir.passes.PassManager):
|
||||
"""Adds passes for full backend compilation.
|
||||
|
||||
These passes are added after the frontend passes.
|
||||
"""
|
||||
ireec = self._ireec
|
||||
ireec.build_flow_transform_pass_pipeline(pm)
|
||||
ireec.build_hal_transform_pass_pipeline(pm)
|
||||
ireec.build_vm_transform_pass_pipeline(pm)
|
|
@ -6,12 +6,21 @@ import os
|
|||
import string
|
||||
import sys
|
||||
|
||||
__all__ = ["debug"]
|
||||
__all__ = ["debug", "debug_enabled", "enable"]
|
||||
|
||||
_ENABLED = "NPCOMP_DEBUG" in os.environ
|
||||
_formatter = string.Formatter()
|
||||
|
||||
|
||||
def enable():
|
||||
global _ENABLED
|
||||
_ENABLED = True
|
||||
|
||||
|
||||
def debug_enabled():
|
||||
return _ENABLED
|
||||
|
||||
|
||||
def debug(format_string, *args, **kwargs):
|
||||
if not _ENABLED:
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue