diff --git a/backend_test/iree/Sample/simple_invoke.py b/backend_test/iree/Sample/simple_invoke.py index d7c32e4bc..c4eac3123 100644 --- a/backend_test/iree/Sample/simple_invoke.py +++ b/backend_test/iree/Sample/simple_invoke.py @@ -3,6 +3,7 @@ from npcomp.compiler.backend import iree from npcomp.compiler.frontend import * from npcomp.compiler import logging +from npcomp.compiler import test_config from npcomp.compiler.target import * # TODO: This should all exist in a high level API somewhere. @@ -13,7 +14,8 @@ logging.enable() def compile_function(f): - fe = ImportFrontend(target_factory=GenericTarget32) + fe = ImportFrontend(config=test_config.create_test_config( + target_factory=GenericTarget32)) fe.import_global_function(f) compiler = iree.CompilerBackend() vm_blob = compiler.compile(fe.ir_module) diff --git a/pytest/Compiler/binary_expressions.py b/pytest/Compiler/binary_expressions.py index 26a71336b..c48369850 100644 --- a/pytest/Compiler/binary_expressions.py +++ b/pytest/Compiler/binary_expressions.py @@ -1,14 +1,8 @@ # RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail -from npcomp.compiler.frontend import * +from npcomp.compiler import test_config - -def import_global(f): - fe = ImportFrontend() - fe.import_global_function(f) - print("// -----") - print(fe.ir_module.to_asm()) - return f +import_global = test_config.create_import_dump_decorator() # Full checking for add. Others just check validity. diff --git a/pytest/Compiler/booleans.py b/pytest/Compiler/booleans.py index c3e1b2c8d..7382008dc 100644 --- a/pytest/Compiler/booleans.py +++ b/pytest/Compiler/booleans.py @@ -1,14 +1,8 @@ # RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail -from npcomp.compiler.frontend import * +from npcomp.compiler import test_config - -def import_global(f): - fe = ImportFrontend() - fe.import_global_function(f) - print("// -----") - print(fe.ir_module.to_asm()) - return f +import_global = test_config.create_import_dump_decorator() # CHECK-LABEL: func @logical_and diff --git a/pytest/Compiler/comparisons.py b/pytest/Compiler/comparisons.py index 3ea2c040f..b46fdc12f 100644 --- a/pytest/Compiler/comparisons.py +++ b/pytest/Compiler/comparisons.py @@ -1,14 +1,8 @@ # RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail -from npcomp.compiler.frontend import * +from npcomp.compiler import test_config - -def import_global(f): - fe = ImportFrontend() - fe.import_global_function(f) - print("// -----") - print(fe.ir_module.to_asm()) - return f +import_global = test_config.create_import_dump_decorator() # CHECK-LABEL: func @binary_lt_ diff --git a/pytest/Compiler/constants.py b/pytest/Compiler/constants.py index cddcf136a..688c39178 100644 --- a/pytest/Compiler/constants.py +++ b/pytest/Compiler/constants.py @@ -1,14 +1,8 @@ # RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail -from npcomp.compiler.frontend import * +from npcomp.compiler import test_config - -def import_global(f): - fe = ImportFrontend() - fe.import_global_function(f) - print("// -----") - print(fe.ir_module.to_asm()) - return f +import_global = test_config.create_import_dump_decorator() # CHECK-LABEL: func @integer_constants diff --git a/pytest/Compiler/constants32.py b/pytest/Compiler/constants32.py index 233007dc4..df0f657cf 100644 --- a/pytest/Compiler/constants32.py +++ b/pytest/Compiler/constants32.py @@ -2,16 +2,11 @@ # Subset of constant tests which verify against a GenericTarget32. -from npcomp.compiler.frontend import * +from npcomp.compiler import test_config from npcomp.compiler.target import * - -def import_global(f): - fe = ImportFrontend(target_factory=GenericTarget32) - fe.import_global_function(f) - print("// -----") - print(fe.ir_module.to_asm()) - return f +import_global = test_config.create_import_dump_decorator( + target_factory=GenericTarget32) # CHECK-LABEL: func @integer_constants diff --git a/pytest/Compiler/partial_eval_getattr.py b/pytest/Compiler/partial_eval_getattr.py index 45a389a56..ee8a04e0f 100644 --- a/pytest/Compiler/partial_eval_getattr.py +++ b/pytest/Compiler/partial_eval_getattr.py @@ -2,15 +2,9 @@ import collections import math -from npcomp.compiler.frontend import * +from npcomp.compiler import test_config - -def import_global(f): - fe = ImportFrontend() - fe.import_global_function(f) - print("// -----") - print(fe.ir_module.to_asm()) - return f +import_global = test_config.create_import_dump_decorator() # CHECK-LABEL: func @module_constant diff --git a/pytest/Compiler/primitive_ops_to_std.py b/pytest/Compiler/primitive_ops_to_std.py index ef7f21795..88385c4b4 100644 --- a/pytest/Compiler/primitive_ops_to_std.py +++ b/pytest/Compiler/primitive_ops_to_std.py @@ -1,15 +1,8 @@ # RUN: %PYTHON %s | npcomp-opt -split-input-file -basicpy-type-inference -convert-basicpy-to-std | FileCheck %s --dump-input=fail -from npcomp.compiler.frontend import * - - -def import_global(f): - fe = ImportFrontend() - fe.import_global_function(f) - print("// -----") - print(fe.ir_module.to_asm()) - return f +from npcomp.compiler import test_config +import_global = test_config.create_import_dump_decorator() ################################################################################ # Integer tests diff --git a/pytest/Compiler/resolve_const.py b/pytest/Compiler/resolve_const.py index 535bef8eb..2713fe9a8 100644 --- a/pytest/Compiler/resolve_const.py +++ b/pytest/Compiler/resolve_const.py @@ -1,16 +1,9 @@ # RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail """Module docstring.""" -from npcomp.compiler.frontend import * - - -def import_global(f): - fe = ImportFrontend() - fe.import_global_function(f) - print("// -----") - print(fe.ir_module.to_asm()) - return f +from npcomp.compiler import test_config +import_global = test_config.create_import_dump_decorator() OUTER_ONE = 1 OUTER_STRING = "Hello" diff --git a/pytest/Compiler/structure.py b/pytest/Compiler/structure.py index 090bddcac..849a0d0da 100644 --- a/pytest/Compiler/structure.py +++ b/pytest/Compiler/structure.py @@ -1,14 +1,8 @@ # RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail -from npcomp.compiler.frontend import * +from npcomp.compiler import test_config - -def import_global(f): - fe = ImportFrontend() - fe.import_global_function(f) - print("// -----") - print(fe.ir_module.to_asm()) - return f +import_global = test_config.create_import_dump_decorator() # CHECK-LABEL: func @positional_args diff --git a/pytest/Compiler/template_call.py b/pytest/Compiler/template_call.py index acef7dee9..c49717133 100644 --- a/pytest/Compiler/template_call.py +++ b/pytest/Compiler/template_call.py @@ -1,15 +1,9 @@ # RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail import math -from npcomp.compiler.frontend import * +from npcomp.compiler import test_config - -def import_global(f): - fe = ImportFrontend() - fe.import_global_function(f) - print("// -----") - print(fe.ir_module.to_asm()) - return f +import_global = test_config.create_import_dump_decorator() # CHECK-LABEL: func @call_ceil_positional diff --git a/pytest/Compiler/type_inference.py b/pytest/Compiler/type_inference.py index 57d5032ed..ec732a4ce 100644 --- a/pytest/Compiler/type_inference.py +++ b/pytest/Compiler/type_inference.py @@ -1,14 +1,8 @@ # RUN: %PYTHON %s | npcomp-opt -split-input-file -basicpy-type-inference | FileCheck %s --dump-input=fail -from npcomp.compiler.frontend import * +from npcomp.compiler import test_config - -def import_global(f): - fe = ImportFrontend() - fe.import_global_function(f) - print(fe.ir_module.to_asm()) - print("// -----") - return f +import_global = test_config.create_import_dump_decorator() # CHECK-LABEL: func @arithmetic_expression diff --git a/python/npcomp/compiler/environment.py b/python/npcomp/compiler/environment.py deleted file mode 100644 index 3f85e5486..000000000 --- a/python/npcomp/compiler/environment.py +++ /dev/null @@ -1,502 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from collections import namedtuple -from enum import Enum -import inspect -import sys -from typing import Optional, Union - -from _npcomp.mlir import ir - -from . import logging -from .py_value_utils import * -from .target import * - -__all__ = [ - "BuiltinsValueCoder", - "Environment", - "LiveValueRef", - "NameReference", - "NameResolver", - "PartialEvalResult", - "PartialEvalType", - "PartialEvalHook", - "ResolveAttrLiveValueRef", - "ValueCoder", - "ValueCoderChain", -] - -_Unspec = object() - -################################################################################ -# Interfaces and base classes -################################################################################ - - -class ValueCoder: - """Encodes values in various ways. - - Instances are designed to be daisy-chained and should ignore types that they - don't understand. Functions return NotImplemented if they cannot handle a - case locally. - """ - __slots__ = [] - - def create_const(self, env: "Environment", py_value): - return NotImplemented - - -class NameReference: - """Abstract base class for performing operations on a name.""" - __slots__ = [ - "name", - ] - - def __init__(self, name): - super().__init__() - self.name = name - - def load(self, env: "Environment", - ir_h: ir.DialectHelper) -> "PartialEvalResult": - """Loads the IR Value associated with the name. - - The load may either be direct, returning an existing value or - side-effecting, causing a read from an external context. - - Args: - ir_h: The dialect helper used to emit code. - Returns: - A partial evaluation result. - """ - return PartialEvalResult.not_evaluated() - - def store(self, env: "Environment", value: ir.Value, ir_h: ir.DialectHelper): - """Stores a new value into the name. - - A subsequent call to 'load' should yield the same value, subject to - typing constraints on value equality. - - Args: - value: The new value to store into the name. - ir_h: The dialect helper used to emit code. - Raises: - NotImplementedError if store is not supported for this name. - """ - raise NotImplementedError() - - -class NameResolver: - """Abstract base class that can resolve a name. - - Name resolvers are typically stacked. - """ - - def checked_lookup(self, name): - ref = self.lookup(name) - assert ref is not None, "Lookup of name {} is required".format(name) - return ref - - def lookup(self, name) -> Optional[NameReference]: - return None - - -################################################################################ -# Partial evaluation -# When the compiler is extracting from a running program, it is likely that -# evaluations produce live values which can be further partially evaluated -# at import time, in the context of the running instance (versus emitting -# program IR to do so). This behavior is controlled through a PartialEvalHook -# on the environment. -################################################################################ - - -class PartialEvalType(Enum): - # Could not be evaluated immediately and the operation should be - # code-generated. yields NotImplemented. - NOT_EVALUATED = 0 - - # Yields a LiveValueRef - YIELDS_LIVE_VALUE = 1 - - # Yields an IR value - YIELDS_IR_VALUE = 2 - - # Evaluation yielded an error (yields contains exc_info from sys.exc_info()). - ERROR = 3 - - -class PartialEvalResult(namedtuple("PartialEvalResult", "type,yields")): - """Encapsulates the result of a partial evaluation.""" - - @classmethod - def not_evaluated(cls): - return cls(PartialEvalType.NOT_EVALUATED, NotImplemented) - - @classmethod - def yields_live_value(cls, live_value): - assert isinstance(live_value, LiveValueRef) - return cls(PartialEvalType.YIELDS_LIVE_VALUE, live_value) - - @classmethod - def yields_ir_value(cls, ir_value): - assert isinstance(ir_value, ir.Value) - return cls(PartialEvalType.YIELDS_IR_VALUE, ir_value) - - @classmethod - def error(cls): - return cls(PartialEvalType.ERROR, sys.exc_info()) - - @classmethod - def error_message(cls, message): - try: - raise RuntimeError(message) - except RuntimeError: - return cls.error() - - -class LiveValueRef: - """Wraps a live value from the containing environment. - - Typically, when expressions encounter a live value, a limited number of - partial evaluations can be done against it in place (versus emitting the code - to import it and perform the operation). This default base class will not - perform any static evaluations. - """ - __slots__ = [ - "live_value", - ] - - def __init__(self, live_value): - super().__init__() - self.live_value = live_value - - def resolve_getattr(self, env: "Environment", attr_name) -> PartialEvalResult: - """Gets a named attribute from the live value.""" - return PartialEvalResult.not_evaluated() - - def resolve_call(self, env: "Environment", args, - keywords) -> PartialEvalResult: - """Resolves a function call given 'args' and 'keywords'.""" - return PartialEvalResult.not_evaluated() - - def __repr__(self): - return "MacroValueRef({}, {})".format(self.__class__.__name__, - self.live_value) - - -class ResolveAttrLiveValueRef(LiveValueRef): - """Custom MacroValueRef that will resolve attributes via getattr.""" - __slots__ = [] - - def resolve_getattr(self, env: "Environment", attr_name) -> PartialEvalResult: - logging.debug("RESOLVE_GETATTR '{}' on {}".format(attr_name, - self.live_value)) - try: - attr_py_value = getattr(self.live_value, attr_name) - except: - return PartialEvalResult.error() - return env.partial_eval_hook.resolve(attr_py_value) - - -class TemplateCallLiveValueRef(LiveValueRef): - """Custom LiveValueRef that resolves calls to a func_template_call op.""" - __slots__ = ["callee_name"] - - def __init__(self, callee_name, live_value): - super().__init__(live_value) - self.callee_name = callee_name - - def resolve_call(self, env: "Environment", args, - keywords) -> PartialEvalResult: - linear_args = list(args) - kw_arg_names = [] - for kw_name, kw_value in keywords: - kw_arg_names.append(kw_name) - linear_args.append(kw_value) - - ir_h = env.ir_h - result_ir_value = ir_h.basicpy_func_template_call_op( - result_type=ir_h.basicpy_UnknownType, - callee_symbol=self.callee_name, - args=linear_args, - arg_names=kw_arg_names).result - return PartialEvalResult.yields_ir_value(result_ir_value) - - -class PartialEvalHook: - """Owned by an environment to customize partial evaluation.""" - __slots__ = [ - "_value_map", - ] - - def __init__(self): - super().__init__() - self._value_map = PyValueMap() - - def resolve(self, py_value) -> PartialEvalResult: - """Performs partial evaluation on a python value.""" - binding = self._value_map.lookup(py_value) - if binding is None: - logging.debug("PARTIAL EVAL RESOLVE {}: Passthrough", py_value) - return PartialEvalResult.yields_live_value(LiveValueRef(py_value)) - if isinstance(binding, LiveValueRef): - logging.debug("PARTIAL EVAL RESOLVE {}: {}", py_value, binding) - return PartialEvalResult.yields_live_value(binding) - if isinstance(binding, PartialEvalResult): - return binding - # Attempt to call. - try: - binding = binding(py_value) - assert isinstance(binding, PartialEvalResult), ( - "Expected PartialEvalResult but got {}".format(binding)) - logging.debug("PARTIAL EVAL RESOLVE {}: {}", py_value, binding) - return binding - except: - return PartialEvalResult.error() - - def _bind(self, - binding, - *, - for_ref=_Unspec, - for_type=_Unspec, - for_predicate=_Unspec): - if for_ref is not _Unspec: - self._value_map.bind_reference(for_ref, binding) - elif for_type is not _Unspec: - self._value_map.bind_type(for_type, binding) - elif for_predicate is not _Unspec: - self._value_map.bind_predicate(for_predicate, binding) - else: - raise ValueError( - "Must specify one of 'for_ref', 'for_type' or 'for_predicate") - - def enable_getattr(self, **kwargs): - """Enables partial evaluation of getattr.""" - self._bind( - lambda pv: PartialEvalResult.yields_live_value( - ResolveAttrLiveValueRef(pv)), **kwargs) - - def enable_template_call(self, callee_name, **kwargs): - """"Enables a global template call.""" - self._bind( - lambda pv: PartialEvalResult.yields_live_value( - TemplateCallLiveValueRef(callee_name, pv)), **kwargs) - - -################################################################################ -# Environment -# Top level instance encapsulating access to runtime state. -################################################################################ - - -class Environment(NameResolver): - """Manages access to the environment of a code region. - - This encapsulates name lookup, access to the containing module, etc. - """ - __slots__ = [ - "ir_h", - "_name_resolvers", - "target", - "value_coder", - "partial_eval_hook", - ] - - def __init__(self, - ir_h: ir.DialectHelper, - *, - target: Target, - name_resolvers=(), - value_coder, - partial_eval_hook=None): - super().__init__() - self.ir_h = ir_h - self.target = target - self._name_resolvers = name_resolvers - self.value_coder = value_coder - self.partial_eval_hook = partial_eval_hook if partial_eval_hook else PartialEvalHook( - ) - - @classmethod - def for_const_global_function(cls, ir_h: ir.DialectHelper, f, *, - parameter_bindings, **kwargs): - """Helper to generate an environment for a global function. - - This is a helper for the very common case and will be wholly insufficient - for advanced cases, including mutable global state, closures, etc. - Globals from the module are considered immutable. - """ - try: - code = f.__code__ - globals_dict = f.__globals__ - builtins_module = globals_dict["__builtins__"] - except AttributeError: - assert False, ( - "Function {} does not have required user-defined function attributes". - format(f)) - - # Locals resolver. - # Note that co_varnames should include both parameter and local names. - locals_resolver = LocalNameResolver(code.co_varnames) - resolvers = ( - locals_resolver, - ConstModuleNameResolver(globals_dict, as_dict=True), - ConstModuleNameResolver(builtins_module), - ) - env = cls(ir_h, name_resolvers=resolvers, **kwargs) - - # Bind parameters. - for name, value in parameter_bindings: - logging.debug("STORE PARAM: {} <- {}", name, value) - locals_resolver.checked_lookup(name).store(env, value) - return env - - def lookup(self, name) -> Optional[NameReference]: - for resolver in self._name_resolvers: - ref = resolver.lookup(name) - if ref is not None: - return ref - return None - - -################################################################################ -# Standard name resolvers -################################################################################ - - -class LocalNameReference(NameReference): - """Holds an association between a name and SSA value.""" - __slots__ = [ - "_current_value", - ] - - def __init__(self, name, initial_value=None): - super().__init__(name) - self._current_value = initial_value - - def load(self, env: "Environment") -> PartialEvalResult: - if self._current_value is None: - return PartialEvalResult.error_message( - "Attempt to access local '{}' before assignment".format(self.name)) - return PartialEvalResult.yields_ir_value(self._current_value) - - def store(self, env: "Environment", value: ir.Value): - self._current_value = value - - def __repr__(self): - return "".format(self.name) - - -class LocalNameResolver(NameResolver): - """Resolves names in a local cache of SSA values. - - This is used to manage locals and arguments (that are not referenced through - a closure). - """ - __slots__ = [ - "_name_refs", - ] - - def __init__(self, names): - super().__init__() - self._name_refs = {name: LocalNameReference(name) for name in names} - - def lookup(self, name) -> Optional[NameReference]: - return self._name_refs.get(name) - - -class ConstNameReference(NameReference): - """Represents a name/value mapping that will emit as a constant.""" - __slots__ = [ - "_py_value", - ] - - def __init__(self, name, py_value): - super().__init__(name) - self._py_value = py_value - - def load(self, env: "Environment") -> PartialEvalResult: - return env.partial_eval_hook.resolve(self._py_value) - - def __repr__(self): - return "".format(self.name, self._py_value) - - -class ConstModuleNameResolver(NameResolver): - """Resolves names from a module by treating them as immutable and loading - them as constants into a function scope. - """ - __slots__ = [ - "_as_dict", - "module", - ] - - def __init__(self, module, *, as_dict=False): - super().__init__() - self.module = module - self._as_dict = as_dict - - def lookup(self, name) -> Optional[NameReference]: - if self._as_dict: - if name in self.module: - py_value = self.module[name] - else: - return None - else: - try: - py_value = getattr(self.module, name) - except AttributeError: - return None - return ConstNameReference(name, py_value) - - -################################################################################ -# Standard value coders -################################################################################ - - -class BuiltinsValueCoder: - """Value coder for builtin python types.""" - __slots__ = [] - - def create_const(self, env: "Environment", py_value): - ir_h = env.ir_h - ir_c = ir_h.context - if py_value is True: - return ir_h.basicpy_bool_constant_op(True).result - elif py_value is False: - return ir_h.basicpy_bool_constant_op(False).result - elif py_value is None: - return ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result - elif isinstance(py_value, int): - ir_type = env.target.impl_int_type - ir_attr = ir_c.integer_attr(ir_type, py_value) - return ir_h.constant_op(ir_type, ir_attr).result - elif isinstance(py_value, float): - ir_type = env.target.impl_float_type - ir_attr = ir_c.float_attr(ir_type, py_value) - return ir_h.constant_op(ir_type, ir_attr).result - elif isinstance(py_value, str): - return ir_h.basicpy_str_constant_op(py_value).result - elif isinstance(py_value, bytes): - return ir_h.basicpy_bytes_constant_op(py_value).result - elif isinstance(py_value, type(...)): - return ir_h.basicpy_singleton_op(ir_h.basicpy_EllipsisType).result - return NotImplemented - - -class ValueCoderChain(ValueCoder): - """Codes values by delegating to sub-coders in order.""" - __slots__ = ["_sub_coders"] - - def __init__(self, sub_coders): - self._sub_coders = sub_coders - - def create_const(self, env: "Environment", py_value): - for sc in self._sub_coders: - result = sc.create_const(env, py_value) - if result is not NotImplemented: - return result - return NotImplemented diff --git a/python/npcomp/compiler/frontend.py b/python/npcomp/compiler/frontend.py index 3cc3974c8..7f830bcaf 100644 --- a/python/npcomp/compiler/frontend.py +++ b/python/npcomp/compiler/frontend.py @@ -14,8 +14,10 @@ from _npcomp.mlir.dialect import ScfDialectHelper from npcomp.dialect import Numpy from . import logging -from .environment import * from .importer import * +from .interfaces import * +from .name_resolver_base import * +from .value_coder_base import * from .target import * __all__ = [ @@ -28,26 +30,20 @@ class ImportFrontend: __slots__ = [ "_ir_context", "_ir_module", - "_helper", - "_target_factory", - "_value_coder", - "_partial_eval_hook", + "_ir_h", + "_config", ] def __init__(self, - ir_context: ir.MLIRContext = None, *, - target_factory: TargetFactory = GenericTarget64, - value_coder: Optional[ValueCoder] = None, - partial_eval_hook: Optional[PartialEvalHook] = None): + config: Configuration, + ir_context: ir.MLIRContext = None): + super().__init__() self._ir_context = ir.MLIRContext() if not ir_context else ir_context self._ir_module = self._ir_context.new_module() - self._helper = AllDialectHelper(self._ir_context, - ir.OpBuilder(self._ir_context)) - self._target_factory = target_factory - self._value_coder = value_coder if value_coder else BuiltinsValueCoder() - self._partial_eval_hook = (partial_eval_hook if partial_eval_hook else - build_default_partial_eval_hook()) + self._ir_h = AllDialectHelper(self._ir_context, + ir.OpBuilder(self._ir_context)) + self._config = config @property def ir_context(self): @@ -59,11 +55,7 @@ class ImportFrontend: @property def ir_h(self): - return self._helper - - @property - def partial_eval_hook(self): - return self._partial_eval_hook + return self._ir_h def import_global_function(self, f): """Imports a global function. @@ -80,7 +72,7 @@ class ImportFrontend: h = self.ir_h ir_c = self.ir_context ir_m = self.ir_module - target = self._target_factory(h) + target = self._config.target_factory(h) filename = inspect.getsourcefile(f) source_lines, start_lineno = inspect.getsourcelines(f) source = "".join(source_lines) @@ -115,26 +107,56 @@ class ImportFrontend: ir_f_type, create_entry_block=True, attrs=attrs) - env = Environment.for_const_global_function( - h, - f, - parameter_bindings=zip(f_params.keys(), ir_f.first_block.args), - value_coder=self._value_coder, - target=target, - partial_eval_hook=self._partial_eval_hook) + env = self._create_const_global_env(f, + parameter_bindings=zip( + f_params.keys(), + ir_f.first_block.args), + target=target) fctx = FunctionContext(ir_c=ir_c, ir_f=ir_f, ir_h=h, filename_ident=filename_ident, - target=target, environment=env) fdimport = FunctionDefImporter(fctx, ast_fd) fdimport.import_body() return ir_f + def _create_const_global_env(self, f, parameter_bindings, target): + """Helper to generate an environment for a global function. + + This is a helper for the very common case and will be wholly insufficient + for advanced cases, including mutable global state, closures, etc. + Globals from the module are considered immutable. + """ + ir_h = self._ir_h + try: + code = f.__code__ + globals_dict = f.__globals__ + builtins_module = globals_dict["__builtins__"] + except AttributeError: + assert False, ( + "Function {} does not have required user-defined function attributes". + format(f)) + + # Locals resolver. + # Note that co_varnames should include both parameter and local names. + locals_resolver = LocalNameResolver(code.co_varnames) + resolvers = ( + locals_resolver, + ConstModuleNameResolver(globals_dict, as_dict=True), + ConstModuleNameResolver(builtins_module), + ) + env = Environment(config=self._config, ir_h=ir_h, name_resolvers=resolvers) + + # Bind parameters. + for name, value in parameter_bindings: + logging.debug("STORE PARAM: {} <- {}", name, value) + locals_resolver.checked_resolve_name(name).store(env, value) + return env + def _resolve_signature_annotation(self, target: Target, annot): - ir_h = self._helper + ir_h = self._ir_h if annot is inspect.Signature.empty: return ir_h.basicpy_UnknownType @@ -164,20 +186,3 @@ class AllDialectHelper(Numpy.DialectHelper, ScfDialectHelper): def __init__(self, *args, **kwargs): Numpy.DialectHelper.__init__(self, *args, **kwargs) ScfDialectHelper.__init__(self, *args, **kwargs) - - -def build_default_partial_eval_hook() -> PartialEvalHook: - pe = PartialEvalHook() - ### Modules - pe.enable_getattr(for_type=ast.__class__) # The module we use is arbitrary. - - ### Tuples - # Enable attribute resolution on tuple, which includes namedtuple (which is - # really what we want). - pe.enable_getattr(for_type=tuple) - - ### Temp: resolve a function to a template call for testing - import math - pe.enable_template_call("__global$math.ceil", for_ref=math.ceil) - pe.enable_template_call("__global$math.isclose", for_ref=math.isclose) - return pe diff --git a/python/npcomp/compiler/importer.py b/python/npcomp/compiler/importer.py index 53b46dad5..c5065f0e5 100644 --- a/python/npcomp/compiler/importer.py +++ b/python/npcomp/compiler/importer.py @@ -11,8 +11,7 @@ import traceback from _npcomp.mlir import ir from . import logging -from .environment import * -from .target import * +from .interfaces import * __all__ = [ "FunctionContext", @@ -27,16 +26,14 @@ class FunctionContext: "ir_c", "ir_f", "ir_h", - "target", "filename_ident", "environment", ] - def __init__(self, ir_c, ir_f, ir_h, target, filename_ident, environment): + def __init__(self, ir_c, ir_f, ir_h, filename_ident, environment): self.ir_c = ir_c self.ir_f = ir_f self.ir_h = ir_h - self.target = target self.filename_ident = filename_ident self.environment = environment @@ -68,7 +65,7 @@ class FunctionContext: def lookup_name(self, name) -> NameReference: """Lookup a name in the environment, requiring it to have evaluated.""" - ref = self.environment.lookup(name) + ref = self.environment.resolve_name(name) if ref is None: self.abort("Could not resolve referenced name '{}'".format(name)) logging.debug("Map name({}) -> {}", name, ref) @@ -77,7 +74,7 @@ class FunctionContext: def emit_const_value(self, py_value) -> ir.Value: """Codes a value as a constant, returning an ir Value.""" env = self.environment - result = env.value_coder.create_const(env, py_value) + result = env.code_py_value_as_const(py_value) if result is NotImplemented: self.abort("Cannot code python value as constant: {}".format(py_value)) return result @@ -214,7 +211,7 @@ class ExpressionImporter(BaseNodeVisitor): def emit_constant(self, value): env = self.fctx.environment - ir_const_value = env.value_coder.create_const(env, value) + ir_const_value = env.code_py_value_as_const(value) if ir_const_value is NotImplemented: self.fctx.abort("unknown constant type '%r'" % (value,)) self.value = ir_const_value diff --git a/python/npcomp/compiler/interfaces.py b/python/npcomp/compiler/interfaces.py new file mode 100644 index 000000000..31f83cc83 --- /dev/null +++ b/python/npcomp/compiler/interfaces.py @@ -0,0 +1,295 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Base classes and interfaces.""" + +from collections import namedtuple +from enum import Enum +import sys +from typing import List, Optional, Sequence, Union + +from _npcomp.mlir import ir +from .target import * + +__all__ = [ + "Configuration", + "Environment", + "NameReference", + "NameResolver", + "PartialEvalHook", + "PartialEvalType", + "PartialEvalResult", + "LiveValueRef", + "ValueCoder", + "ValueCoderChain", +] + +_NotImplementedType = type(NotImplemented) + +################################################################################ +# Name resolution +################################################################################ + + +class NameReference: + """Abstract base class for performing operations on a name.""" + __slots__ = [ + "name", + ] + + def __init__(self, name): + super().__init__() + self.name = name + + def load(self, env: "Environment", + ir_h: ir.DialectHelper) -> "PartialEvalResult": + """Loads the IR Value associated with the name. + + The load may either be direct, returning an existing value or + side-effecting, causing a read from an external context. + + Args: + ir_h: The dialect helper used to emit code. + Returns: + A partial evaluation result. + """ + return PartialEvalResult.not_evaluated() + + def store(self, env: "Environment", value: ir.Value, ir_h: ir.DialectHelper): + """Stores a new value into the name. + + A subsequent call to 'load' should yield the same value, subject to + typing constraints on value equality. + + Args: + value: The new value to store into the name. + ir_h: The dialect helper used to emit code. + Raises: + NotImplementedError if store is not supported for this name. + """ + raise NotImplementedError() + + +class NameResolver: + """Abstract base class that can resolve a name. + + Name resolvers are typically stacked. + """ + __slots__ = [] + + def checked_resolve_name(self, name: str) -> Optional[NameReference]: + ref = self.resolve_name(name) + assert ref is not None, "Lookup of name {} is required".format(name) + return ref + + def resolve_name(self, name: str) -> Optional[NameReference]: + return None + + +################################################################################ +# Value coding +# Transforms python values into IR values. +################################################################################ + + +class ValueCoder: + """Encodes values in various ways. + + Instances are designed to be daisy-chained and should ignore types that they + don't understand. Functions return NotImplemented if they cannot handle a + case locally. + """ + __slots__ = [] + + def code_py_value_as_const(self, env: "Environment", + py_value) -> Union[_NotImplementedType, ir.Value]: + return NotImplemented + + +class ValueCoderChain(ValueCoder): + """Codes values by delegating to sub-coders in order.""" + __slots__ = ["_sub_coders"] + + def __init__(self, sub_coders: Sequence[ValueCoder]): + self._sub_coders = tuple(sub_coders) + + def code_py_value_as_const(self, env: "Environment", + py_value) -> Union[_NotImplementedType, ir.Value]: + for sc in self._sub_coders: + result = sc.code_py_value_as_const(env, py_value) + if result is not NotImplemented: + return result + return NotImplemented + + +################################################################################ +# Partial evaluation +# When the compiler is extracting from a running program, it is likely that +# evaluations produce live values which can be further partially evaluated +# at import time, in the context of the running instance (versus emitting +# program IR to do so). This behavior is controlled through a PartialEvalHook +# on the environment. +################################################################################ + + +class PartialEvalType(Enum): + # Could not be evaluated immediately and the operation should be + # code-generated. yields NotImplemented. + NOT_EVALUATED = 0 + + # Yields a LiveValueRef + YIELDS_LIVE_VALUE = 1 + + # Yields an IR value + YIELDS_IR_VALUE = 2 + + # Evaluation yielded an error (yields contains exc_info from sys.exc_info()). + ERROR = 3 + + +class PartialEvalResult(namedtuple("PartialEvalResult", "type,yields")): + """Encapsulates the result of a partial evaluation.""" + + @staticmethod + def not_evaluated() -> "PartialEvalResult": + return PartialEvalResult(PartialEvalType.NOT_EVALUATED, NotImplemented) + + @staticmethod + def yields_live_value(live_value) -> "PartialEvalResult": + assert isinstance(live_value, LiveValueRef) + return PartialEvalResult(PartialEvalType.YIELDS_LIVE_VALUE, live_value) + + @staticmethod + def yields_ir_value(ir_value: ir.Value) -> "PartialEvalResult": + assert isinstance(ir_value, ir.Value) + return PartialEvalResult(PartialEvalType.YIELDS_IR_VALUE, ir_value) + + @staticmethod + def error() -> "PartialEvalResult": + return PartialEvalResult(PartialEvalType.ERROR, sys.exc_info()) + + @staticmethod + def error_message(message: str) -> "PartialEvalResult": + try: + raise RuntimeError(message) + except RuntimeError: + return PartialEvalResult.error() + + +class LiveValueRef: + """Wraps a live value from the containing environment. + + Typically, when expressions encounter a live value, a limited number of + partial evaluations can be done against it in place (versus emitting the code + to import it and perform the operation). This default base class will not + perform any static evaluations. + """ + __slots__ = [ + "live_value", + ] + + def __init__(self, live_value): + super().__init__() + self.live_value = live_value + + def resolve_getattr(self, env: "Environment", + attr_name: str) -> PartialEvalResult: + """Gets a named attribute from the live value.""" + return PartialEvalResult.not_evaluated() + + def resolve_call(self, env: "Environment", args, + keywords: Sequence[str]) -> PartialEvalResult: + """Resolves a function call given 'args' and 'keywords'.""" + return PartialEvalResult.not_evaluated() + + def __repr__(self): + return "MacroValueRef({}, {})".format(self.__class__.__name__, + self.live_value) + + +class PartialEvalHook: + """Hook interface for performing partial evaluation.""" + __slots__ = [] + + def partial_evaluate(self, py_value) -> PartialEvalResult: + raise NotImplementedError + + +################################################################################ +# Configuration and environment +################################################################################ + + +class Configuration: + """Base class providing global configuration objects.""" + __slots__ = [ + "target_factory", + "base_name_resolvers", + "value_coder", + "partial_eval_hook", + ] + + def __init__(self, + *, + target_factory: TargetFactory, + base_name_resolvers: Sequence[NameResolver] = (), + value_coder: Optional[ValueCoder] = None, + partial_eval_hook: PartialEvalHook = None): + super().__init__() + self.target_factory = target_factory + self.base_name_resolvers = tuple(base_name_resolvers) + self.value_coder = value_coder if value_coder else ValueCoderChain(()) + self.partial_eval_hook = partial_eval_hook + + def __repr__(self): + return ("Configuration(target_factory={}, base_name_resolvers={}, " + "value_code={}, partial_eval_hook={})").format( + self.target_factory, self.base_name_resolvers, self.value_coder, + self.partial_eval_hook) + + +class Environment: + """Instantiated configuration for emitting code in a specific context. + + This brings together: + - The code generation context (ir_h) + - An instantiated target + - Delegating interfaces for other configuration objects. + + Note that this class does not actually implement most of the delegate + interfaces because it hides the fact that some may require more obtuse + APIs than should be exposed to end callers (i.e. expecting environment or + other config objects). + """ + __slots__ = [ + "config", + "ir_h", + "_name_resolvers", + "target", + ] + + def __init__(self, + *, + config: Configuration, + ir_h: ir.DialectHelper, + name_resolvers: Sequence[NameResolver] = ()): + super().__init__() + self.config = config + self.ir_h = ir_h + self.target = config.target_factory(self.ir_h) + self._name_resolvers = (tuple(name_resolvers) + + self.config.base_name_resolvers) + + def resolve_name(self, name: str) -> Optional[NameReference]: + for resolver in self._name_resolvers: + ref = resolver.resolve_name(name) + if ref is not None: + return ref + return None + + def partial_evaluate(self, py_value) -> PartialEvalResult: + return self.config.partial_eval_hook.partial_evaluate(py_value) + + def code_py_value_as_const(self, + py_value) -> Union[_NotImplementedType, ir.Value]: + return self.config.value_coder.code_py_value_as_const(self, py_value) diff --git a/python/npcomp/compiler/name_resolver_base.py b/python/npcomp/compiler/name_resolver_base.py new file mode 100644 index 000000000..8986ed82e --- /dev/null +++ b/python/npcomp/compiler/name_resolver_base.py @@ -0,0 +1,114 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Name resolvers for common scenarios.""" + +from typing import Optional + +from _npcomp.mlir import ir + +from .interfaces import * + +__all__ = [ + "ConstModuleNameResolver", + "LocalNameResolver", +] + +################################################################################ +# Local name resolution +# This is used for local names that can be managed purely as SSA values. +################################################################################ + + +class LocalNameReference(NameReference): + """Holds an association between a name and SSA value.""" + __slots__ = [ + "_current_value", + ] + + def __init__(self, name, initial_value=None): + super().__init__(name) + self._current_value = initial_value + + def load(self, env: Environment) -> PartialEvalResult: + if self._current_value is None: + return PartialEvalResult.error_message( + "Attempt to access local '{}' before assignment".format(self.name)) + return PartialEvalResult.yields_ir_value(self._current_value) + + def store(self, env: Environment, value: ir.Value): + self._current_value = value + + def __repr__(self): + return "".format(self.name) + + +class LocalNameResolver(NameResolver): + """Resolves names in a local cache of SSA values. + + This is used to manage locals and arguments (that are not referenced through + a closure). + """ + __slots__ = [ + "_name_refs", + ] + + def __init__(self, names): + super().__init__() + self._name_refs = {name: LocalNameReference(name) for name in names} + + def resolve_name(self, name) -> Optional[NameReference]: + return self._name_refs.get(name) + + +################################################################################ +# Constant name resolution +# For some DSLs, it can be appropriate to treat some containing scopes as +# constants. This strategy typically binds to a module and routes loads +# through the partial evaluation hook. +################################################################################ + + +class ConstNameReference(NameReference): + """Represents a name/value mapping that will emit as a constant.""" + __slots__ = [ + "_py_value", + ] + + def __init__(self, name, py_value): + super().__init__(name) + self._py_value = py_value + + def load(self, env: Environment) -> PartialEvalResult: + return env.partial_evaluate(self._py_value) + + def __repr__(self): + return "".format(self.name, self._py_value) + + +class ConstModuleNameResolver(NameResolver): + """Resolves names from a module by treating them as immutable and loading + them as constants into a function scope. + """ + __slots__ = [ + "_as_dict", + "module", + ] + + def __init__(self, module, *, as_dict=False): + super().__init__() + self.module = module + self._as_dict = as_dict + + def resolve_name(self, name) -> Optional[NameReference]: + if self._as_dict: + if name in self.module: + py_value = self.module[name] + else: + return None + else: + try: + py_value = getattr(self.module, name) + except AttributeError: + return None + return ConstNameReference(name, py_value) diff --git a/python/npcomp/compiler/partial_eval_base.py b/python/npcomp/compiler/partial_eval_base.py new file mode 100644 index 000000000..6f3e42db0 --- /dev/null +++ b/python/npcomp/compiler/partial_eval_base.py @@ -0,0 +1,134 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Partial evaluation helpers and support for built-in and common scenarios.""" + +from .interfaces import * +from .py_value_utils import * +from . import logging + +__all__ = [ + "MappedPartialEvalHook", + "ResolveAttrLiveValueRef", + "TemplateCallLiveValueRef", +] + +_Unspec = object() + +################################################################################ +# LiveValueRef specializations for various kinds of access +################################################################################ + + +class ResolveAttrLiveValueRef(LiveValueRef): + """Custom LiveValueRef that will resolve attributes via getattr.""" + __slots__ = [] + + def resolve_getattr(self, env: "Environment", attr_name) -> PartialEvalResult: + logging.debug("RESOLVE_GETATTR '{}' on {}".format(attr_name, + self.live_value)) + try: + attr_py_value = getattr(self.live_value, attr_name) + except: + return PartialEvalResult.error() + return env.partial_evaluate(attr_py_value) + + +class TemplateCallLiveValueRef(LiveValueRef): + """Custom LiveValueRef that resolves calls to a func_template_call op.""" + __slots__ = ["callee_name"] + + def __init__(self, callee_name, live_value): + super().__init__(live_value) + self.callee_name = callee_name + + def resolve_call(self, env: "Environment", args, + keywords) -> PartialEvalResult: + linear_args = list(args) + kw_arg_names = [] + for kw_name, kw_value in keywords: + kw_arg_names.append(kw_name) + linear_args.append(kw_value) + + ir_h = env.ir_h + result_ir_value = ir_h.basicpy_func_template_call_op( + result_type=ir_h.basicpy_UnknownType, + callee_symbol=self.callee_name, + args=linear_args, + arg_names=kw_arg_names).result + return PartialEvalResult.yields_ir_value(result_ir_value) + + +################################################################################ +# PartialEvalHook implementations +################################################################################ + + +class MappedPartialEvalHook(PartialEvalHook): + """A PartialEvalHook that maps rules to produce live values. + + Internally, this implementation binds a predicate to an action. The predicate + can be: + - A python value matched by reference or value equality + - A type that a value must be an instance of + - An arbitrary lambda (should be limited to special cases as it forces + a linear scan). + + An action can be one of + - A `lambda python_value: PartialEvalResult...` + - A PartialEvalResult to directly return + - None to indicate that the python value should be processed directly + """ + __slots__ = [ + "_value_map", + ] + + def __init__(self): + super().__init__() + self._value_map = PyValueMap() + + def partial_evaluate(self, py_value) -> PartialEvalResult: + """Performs partial evaluation on a python value.""" + binding = self._value_map.lookup(py_value) + if binding is None: + logging.debug("PARTIAL EVAL RESOLVE {}: Passthrough", py_value) + return PartialEvalResult.yields_live_value(LiveValueRef(py_value)) + if isinstance(binding, PartialEvalResult): + return binding + # Attempt to call. + try: + binding = binding(py_value) + assert isinstance(binding, PartialEvalResult), ( + "Expected PartialEvalResult but got {}".format(binding)) + logging.debug("PARTIAL EVAL RESOLVE {}: {}", py_value, binding) + return binding + except: + return PartialEvalResult.error() + + def _bind_action(self, + action, + *, + for_ref=_Unspec, + for_type=_Unspec, + for_predicate=_Unspec): + if for_ref is not _Unspec: + self._value_map.bind_reference(for_ref, action) + elif for_type is not _Unspec: + self._value_map.bind_type(for_type, action) + elif for_predicate is not _Unspec: + self._value_map.bind_predicate(for_predicate, action) + else: + raise ValueError( + "Must specify one of 'for_ref', 'for_type' or 'for_predicate") + + def enable_getattr(self, **kwargs): + """Enables partial evaluation of getattr.""" + self._bind_action( + lambda pv: PartialEvalResult.yields_live_value( + ResolveAttrLiveValueRef(pv)), **kwargs) + + def enable_template_call(self, callee_name, **kwargs): + """"Enables a global template call.""" + self._bind_action( + lambda pv: PartialEvalResult.yields_live_value( + TemplateCallLiveValueRef(callee_name, pv)), **kwargs) diff --git a/python/npcomp/compiler/test_config.py b/python/npcomp/compiler/test_config.py new file mode 100644 index 000000000..4c4d575b0 --- /dev/null +++ b/python/npcomp/compiler/test_config.py @@ -0,0 +1,55 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Various configuration helpers for testing.""" + +import ast + +from . import logging +from .frontend import * +from .interfaces import * +from .partial_eval_base import * +from .target import * +from .value_coder_base import * + + +def create_import_dump_decorator(*, + target_factory: TargetFactory = GenericTarget64 + ): + config = create_test_config(target_factory=target_factory) + logging.debug("Testing with config: {}", config) + + def decorator(f): + fe = ImportFrontend(config=config) + fe.import_global_function(f) + print("// -----") + print(fe.ir_module.to_asm()) + return f + + return decorator + + +def create_test_config(target_factory: TargetFactory = GenericTarget64): + value_coder = BuiltinsValueCoder() + pe_hook = build_default_partial_eval_hook() + + return Configuration(target_factory=target_factory, + value_coder=value_coder, + partial_eval_hook=pe_hook) + + +def build_default_partial_eval_hook() -> PartialEvalHook: + pe = MappedPartialEvalHook() + ### Modules + pe.enable_getattr(for_type=ast.__class__) # The module we use is arbitrary. + + ### Tuples + # Enable attribute resolution on tuple, which includes namedtuple (which is + # really what we want). + pe.enable_getattr(for_type=tuple) + + ### Temp: resolve a function to a template call for testing + import math + pe.enable_template_call("__global$math.ceil", for_ref=math.ceil) + pe.enable_template_call("__global$math.isclose", for_ref=math.isclose) + return pe diff --git a/python/npcomp/compiler/value_coder_base.py b/python/npcomp/compiler/value_coder_base.py new file mode 100644 index 000000000..825d7f7aa --- /dev/null +++ b/python/npcomp/compiler/value_coder_base.py @@ -0,0 +1,47 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Value coders for built-in and common scenarios.""" + +from typing import Union + +from _npcomp.mlir import ir + +from .interfaces import * + +__all__ = [ + "BuiltinsValueCoder", +] + +_NotImplementedType = type(NotImplemented) + + +class BuiltinsValueCoder(ValueCoder): + """Value coder for builtin python types.""" + __slots__ = [] + + def code_py_value_as_const(self, env: Environment, + py_value) -> Union[_NotImplementedType, ir.Value]: + ir_h = env.ir_h + ir_c = ir_h.context + if py_value is True: + return ir_h.basicpy_bool_constant_op(True).result + elif py_value is False: + return ir_h.basicpy_bool_constant_op(False).result + elif py_value is None: + return ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result + elif isinstance(py_value, int): + ir_type = env.target.impl_int_type + ir_attr = ir_c.integer_attr(ir_type, py_value) + return ir_h.constant_op(ir_type, ir_attr).result + elif isinstance(py_value, float): + ir_type = env.target.impl_float_type + ir_attr = ir_c.float_attr(ir_type, py_value) + return ir_h.constant_op(ir_type, ir_attr).result + elif isinstance(py_value, str): + return ir_h.basicpy_str_constant_op(py_value).result + elif isinstance(py_value, bytes): + return ir_h.basicpy_bytes_constant_op(py_value).result + elif isinstance(py_value, type(...)): + return ir_h.basicpy_singleton_op(ir_h.basicpy_EllipsisType).result + return NotImplemented