mirror of https://github.com/llvm/torch-mlir
Wrap the IREE compiler flow in a one stop API.
parent
dc87e09b5a
commit
b811db4b76
|
@ -1,52 +1,29 @@
|
||||||
# RUN: %PYTHON %s
|
# RUN: %PYTHON %s
|
||||||
|
|
||||||
|
from npcomp.compiler.backend import iree
|
||||||
from npcomp.compiler.frontend import *
|
from npcomp.compiler.frontend import *
|
||||||
|
from npcomp.compiler import logging
|
||||||
from npcomp.compiler.target import *
|
from npcomp.compiler.target import *
|
||||||
|
|
||||||
# TODO: This should all exist in a high level API somewhere.
|
# TODO: This should all exist in a high level API somewhere.
|
||||||
from _npcomp import mlir
|
from _npcomp import mlir
|
||||||
from _npcomp.backend import iree as ireec
|
|
||||||
|
|
||||||
from pyiree import rt
|
|
||||||
|
logging.enable()
|
||||||
|
|
||||||
|
|
||||||
def compile_function(f):
|
def compile_function(f):
|
||||||
fe = ImportFrontend(target_factory=GenericTarget32)
|
fe = ImportFrontend(target_factory=GenericTarget32)
|
||||||
ir_f = fe.import_global_function(f)
|
fe.import_global_function(f)
|
||||||
|
compiler = iree.CompilerBackend()
|
||||||
input_m = fe.ir_module
|
vm_blob = compiler.compile(fe.ir_module)
|
||||||
# For easier debugging, split into to pass manager invocations.
|
loaded_m = compiler.load(vm_blob)
|
||||||
pm = mlir.passes.PassManager(input_m.context)
|
return loaded_m[f.__name__]
|
||||||
# TOOD: Have an API for this
|
|
||||||
pm.addPassPipelines(
|
|
||||||
"basicpy-type-inference", "convert-basicpy-to-std", "canonicalize")
|
|
||||||
pm.run(input_m)
|
|
||||||
print("INPUT MODULE:")
|
|
||||||
print(input_m.to_asm())
|
|
||||||
|
|
||||||
# Main IREE compiler.
|
|
||||||
pm = mlir.passes.PassManager(input_m.context)
|
|
||||||
ireec.build_flow_transform_pass_pipeline(pm)
|
|
||||||
ireec.build_hal_transform_pass_pipeline(pm)
|
|
||||||
ireec.build_vm_transform_pass_pipeline(pm)
|
|
||||||
pm.run(input_m)
|
|
||||||
print("VM MODULE:")
|
|
||||||
print(input_m.to_asm())
|
|
||||||
|
|
||||||
# Translate to VM bytecode flatbuffer.
|
|
||||||
vm_blob = ireec.translate_to_vm_bytecode(input_m)
|
|
||||||
print("VM BLOB: len =", len(vm_blob))
|
|
||||||
return vm_blob
|
|
||||||
|
|
||||||
|
|
||||||
|
@compile_function
|
||||||
def int_add(a: int, b: int):
|
def int_add(a: int, b: int):
|
||||||
return a + b
|
return a + b
|
||||||
|
|
||||||
vm_blob = compile_function(int_add)
|
result = int_add(5, 6)
|
||||||
m = rt.VmModule.from_flatbuffer(vm_blob)
|
assert result == 11
|
||||||
config = rt.Config("vmla")
|
|
||||||
ctx = rt.SystemContext(config=config)
|
|
||||||
ctx.add_module(m)
|
|
||||||
|
|
||||||
f = ctx.modules.module.int_add
|
|
||||||
print(f(5, 6))
|
|
||||||
|
|
|
@ -0,0 +1,125 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
from _npcomp import mlir
|
||||||
|
from npcomp.compiler import logging
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"is_enabled",
|
||||||
|
"CompilerBackend",
|
||||||
|
]
|
||||||
|
|
||||||
|
_ireec = None
|
||||||
|
_ireert = None
|
||||||
|
_cached_config = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_iree():
|
||||||
|
"""Dynamically resolves the iree backend module."""
|
||||||
|
global _ireec
|
||||||
|
global _ireert
|
||||||
|
if _ireec is not None:
|
||||||
|
return _ireec, _ireert
|
||||||
|
try:
|
||||||
|
from _npcomp.backend import iree as imported_ireec
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"The npcomp native module was not compiled with IREE support")
|
||||||
|
try:
|
||||||
|
from pyiree import rt as imported_rt
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("IREE runtime library not found (pyiree.rt)")
|
||||||
|
|
||||||
|
_ireec = imported_ireec
|
||||||
|
_ireert = imported_rt
|
||||||
|
return _ireec, _ireert
|
||||||
|
|
||||||
|
|
||||||
|
def is_enabled() -> bool:
|
||||||
|
"""Returns whether the backend is enabled for the current build."""
|
||||||
|
try:
|
||||||
|
_get_iree()
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class CompilerBackend:
|
||||||
|
"""Main entry-point for the backend."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._ireec, self._ireert = _get_iree()
|
||||||
|
self._debug = logging.debug_enabled()
|
||||||
|
|
||||||
|
def compile(self, imported_ir_module: mlir.ir.ModuleOp):
|
||||||
|
"""Compiles an imported module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
imported_ir_module: The MLIR module as imported from the ImportFrontend.
|
||||||
|
Returns:
|
||||||
|
An opaque, backend specific module object that can be passed to load.
|
||||||
|
The object may actually be something more specific to the backend (i.e.
|
||||||
|
for IREE, it is a serialized VM flatbuffer) but the contract is that
|
||||||
|
it is operated on by methods on this class.
|
||||||
|
"""
|
||||||
|
ireec = self._ireec
|
||||||
|
# For easier debugging, split into to pass manager invocations.
|
||||||
|
# Frontend.
|
||||||
|
pm = mlir.passes.PassManager(imported_ir_module.context)
|
||||||
|
self.add_frontend_passes(pm)
|
||||||
|
pm.run(imported_ir_module)
|
||||||
|
if self._debug:
|
||||||
|
logging.debug("Frontend IR:{}", imported_ir_module.to_asm())
|
||||||
|
# Backend.
|
||||||
|
pm = mlir.passes.PassManager(imported_ir_module.context)
|
||||||
|
self.add_backend_passes(pm)
|
||||||
|
pm.run(imported_ir_module)
|
||||||
|
if self._debug:
|
||||||
|
logging.debug("Backend IR:{}", imported_ir_module.to_asm())
|
||||||
|
# Translation/serialization.
|
||||||
|
vm_blob = ireec.translate_to_vm_bytecode(imported_ir_module)
|
||||||
|
if self._debug:
|
||||||
|
logging.debug("Compiled VM BLOB size={}", len(vm_blob))
|
||||||
|
return vm_blob
|
||||||
|
|
||||||
|
def load(self, vm_blob):
|
||||||
|
"""Loads a compiled artifact into the runtime.
|
||||||
|
|
||||||
|
This is meant as a simple mechanism for testing and is not optimized or
|
||||||
|
highly parameterized. It loads a compiled result into a new runtime
|
||||||
|
instance and returns an object that exposes a python function for each
|
||||||
|
public function compiled in the imported_ir_module that was compiled.
|
||||||
|
"""
|
||||||
|
ireert = self._ireert
|
||||||
|
m = ireert.VmModule.from_flatbuffer(vm_blob)
|
||||||
|
global _cached_config
|
||||||
|
if not _cached_config:
|
||||||
|
# TODO: Need to make the configuration more flexible.
|
||||||
|
_cached_config = ireert.Config(driver_name="vmla")
|
||||||
|
ctx = ireert.SystemContext(config=_cached_config)
|
||||||
|
ctx.add_module(m)
|
||||||
|
# TODO: The implicit tying of the 'module' name has got to go.
|
||||||
|
return ctx.modules.module
|
||||||
|
|
||||||
|
def add_frontend_passes(self, pm: mlir.passes.PassManager):
|
||||||
|
"""Adds passes needed for legalizing from an imported form.
|
||||||
|
|
||||||
|
While an arbitrary distinction, the passes added here are more about
|
||||||
|
legalizing the basicpy and numpy dialects in preparation for performing
|
||||||
|
backend compilation. They are separated to aid debugging.
|
||||||
|
"""
|
||||||
|
# TOOD: Have an API for this
|
||||||
|
pm.addPassPipelines("basicpy-type-inference", "convert-basicpy-to-std",
|
||||||
|
"canonicalize")
|
||||||
|
|
||||||
|
def add_backend_passes(self, pm: mlir.passes.PassManager):
|
||||||
|
"""Adds passes for full backend compilation.
|
||||||
|
|
||||||
|
These passes are added after the frontend passes.
|
||||||
|
"""
|
||||||
|
ireec = self._ireec
|
||||||
|
ireec.build_flow_transform_pass_pipeline(pm)
|
||||||
|
ireec.build_hal_transform_pass_pipeline(pm)
|
||||||
|
ireec.build_vm_transform_pass_pipeline(pm)
|
|
@ -6,12 +6,21 @@ import os
|
||||||
import string
|
import string
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
__all__ = ["debug"]
|
__all__ = ["debug", "debug_enabled", "enable"]
|
||||||
|
|
||||||
_ENABLED = "NPCOMP_DEBUG" in os.environ
|
_ENABLED = "NPCOMP_DEBUG" in os.environ
|
||||||
_formatter = string.Formatter()
|
_formatter = string.Formatter()
|
||||||
|
|
||||||
|
|
||||||
|
def enable():
|
||||||
|
global _ENABLED
|
||||||
|
_ENABLED = True
|
||||||
|
|
||||||
|
|
||||||
|
def debug_enabled():
|
||||||
|
return _ENABLED
|
||||||
|
|
||||||
|
|
||||||
def debug(format_string, *args, **kwargs):
|
def debug(format_string, *args, **kwargs):
|
||||||
if not _ENABLED:
|
if not _ENABLED:
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in New Issue