# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from collections import namedtuple
from enum import Enum
import inspect
import sys
from typing import Optional, Union

from _npcomp.mlir import ir

from . import logging
from .py_value_utils import *
from .target import *

__all__ = [
    "BuiltinsValueCoder",
    "Environment",
    "LiveValueRef",
    "NameReference",
    "NameResolver",
    "PartialEvalResult",
    "PartialEvalType",
    "PartialEvalHook",
    "ResolveAttrLiveValueRef",
    "ValueCoder",
    "ValueCoderChain",
]

# Private sentinel meaning "keyword argument was not supplied" (None may be a
# legitimate value to bind against).
_Unspec = object()

################################################################################
# Interfaces and base classes
################################################################################


class ValueCoder:
    """Encodes values in various ways.

    Instances are designed to be daisy-chained and should ignore types that
    they don't understand. Functions return NotImplemented if they cannot
    handle a case locally.
    """
    __slots__ = []

    def create_const(self, env: "Environment", py_value):
        """Creates a constant for 'py_value'.

        The base implementation handles nothing and returns NotImplemented.
        """
        return NotImplemented


class NameReference:
    """Abstract base class for performing operations on a name."""
    __slots__ = [
        "name",
    ]

    def __init__(self, name):
        super().__init__()
        self.name = name

    def load(self,
             env: "Environment",
             ir_h: Optional[ir.DialectHelper] = None) -> "PartialEvalResult":
        """Loads the IR Value associated with the name.

        The load may either be direct, returning an existing value or
        side-effecting, causing a read from an external context.

        Args:
          env: The environment the load is performed in.
          ir_h: Optional dialect helper used to emit code. It defaults to None
            so the base class can be called as `load(env)`, which is the
            signature implemented by the subclasses defined in this module.
        Returns:
          A partial evaluation result. The base implementation always returns
          a NOT_EVALUATED result.
        """
        return PartialEvalResult.not_evaluated()

    def store(self,
              env: "Environment",
              value: ir.Value,
              ir_h: Optional[ir.DialectHelper] = None):
        """Stores a new value into the name.

        A subsequent call to 'load' should yield the same value, subject to
        typing constraints on value equality.

        Args:
          env: The environment the store is performed in.
          value: The new value to store into the name.
          ir_h: Optional dialect helper used to emit code. It defaults to None
            so the base class can be called as `store(env, value)`, which is
            the signature implemented by the subclasses defined in this
            module.
        Raises:
          NotImplementedError if store is not supported for this name.
        """
        raise NotImplementedError()


class NameResolver:
    """Abstract base class that can resolve a name.

    Name resolvers are typically stacked.
    """

    def checked_lookup(self, name):
        """Looks up 'name', failing if it cannot be resolved.

        Raises:
          AssertionError if lookup() returns None for the name.
        """
        ref = self.lookup(name)
        # Raised explicitly (rather than via an assert statement) so that the
        # check is not removed when Python runs with -O.
        if ref is None:
            raise AssertionError(
                "Lookup of name {} is required".format(name))
        return ref

    def lookup(self, name) -> Optional[NameReference]:
        """Returns a reference for 'name' or None if it is not known.

        The base implementation knows no names and always returns None.
        """
        return None


################################################################################
# Partial evaluation
# When the compiler is extracting from a running program, it is likely that
# evaluations produce live values which can be further partially evaluated
# at import time, in the context of the running instance (versus emitting
# program IR to do so). This behavior is controlled through a PartialEvalHook
# on the environment.
################################################################################


class PartialEvalType(Enum):
    """Discriminator for the kinds of PartialEvalResult."""
    # Could not be evaluated immediately and the operation should be
    # code-generated. yields NotImplemented.
    NOT_EVALUATED = 0

    # Yields a LiveValueRef
    YIELDS_LIVE_VALUE = 1

    # Yields an IR value
    YIELDS_IR_VALUE = 2

    # Evaluation yielded an error (yields contains exc_info from
    # sys.exc_info()).
    ERROR = 3


class PartialEvalResult(namedtuple("PartialEvalResult", "type,yields")):
    """Encapsulates the result of a partial evaluation.

    'type' is a PartialEvalType and 'yields' is the payload whose meaning
    depends on the type (see the comments on PartialEvalType).
    """

    @classmethod
    def not_evaluated(cls):
        """Creates a NOT_EVALUATED result (yields NotImplemented)."""
        return cls(PartialEvalType.NOT_EVALUATED, NotImplemented)

    @classmethod
    def yields_live_value(cls, live_value):
        """Creates a YIELDS_LIVE_VALUE result wrapping a LiveValueRef."""
        assert isinstance(live_value, LiveValueRef)
        return cls(PartialEvalType.YIELDS_LIVE_VALUE, live_value)

    @classmethod
    def yields_ir_value(cls, ir_value):
        """Creates a YIELDS_IR_VALUE result wrapping an ir.Value."""
        assert isinstance(ir_value, ir.Value)
        return cls(PartialEvalType.YIELDS_IR_VALUE, ir_value)

    @classmethod
    def error(cls):
        """Creates an ERROR result from sys.exc_info().

        Intended to be called from within an 'except' block so that
        sys.exc_info() describes the exception being handled.
        """
        return cls(PartialEvalType.ERROR, sys.exc_info())

    @classmethod
    def error_message(cls, message):
        """Creates an ERROR result carrying a RuntimeError with 'message'."""
        # Raise and immediately catch so that error() sees populated
        # exception info.
        try:
            raise RuntimeError(message)
        except RuntimeError:
            return cls.error()


class LiveValueRef:
    """Wraps a live value from the containing environment.

    Typically, when expressions encounter a live value, a limited number of
    partial evaluations can be done against it in place (versus emitting the
    code to import it and perform the operation). This default base class will
    not perform any static evaluations.
    """
    __slots__ = [
        "live_value",
    ]

    def __init__(self, live_value):
        super().__init__()
        self.live_value = live_value

    def resolve_getattr(self, env: "Environment",
                        attr_name) -> PartialEvalResult:
        """Gets a named attribute from the live value."""
        return PartialEvalResult.not_evaluated()

    def resolve_call(self, env: "Environment", args,
                     keywords) -> PartialEvalResult:
        """Resolves a function call given 'args' and 'keywords'."""
        return PartialEvalResult.not_evaluated()

    def __repr__(self):
        # NOTE(review): the "MacroValueRef" label does not match the class
        # name; it is kept byte-identical because the repr text may be matched
        # by callers or tests outside this file -- confirm before renaming.
        return "MacroValueRef({}, {})".format(self.__class__.__name__,
                                              self.live_value)


class ResolveAttrLiveValueRef(LiveValueRef):
    """Custom LiveValueRef that will resolve attributes via getattr."""
    __slots__ = []

    def resolve_getattr(self, env: "Environment",
                        attr_name) -> PartialEvalResult:
        """Resolves 'attr_name' on the live value with getattr().

        Returns:
          The result of passing the attribute value to the environment's
          partial_eval_hook, or an ERROR result if getattr() raised.
        """
        # Format arguments are passed separately, matching the other
        # logging.debug() calls in this module, so that braces in the value's
        # string form are not interpreted as format fields.
        logging.debug("RESOLVE_GETATTR '{}' on {}", attr_name,
                      self.live_value)
        try:
            attr_py_value = getattr(self.live_value, attr_name)
        except Exception:
            # Only Exception subclasses are converted into an error result;
            # KeyboardInterrupt/SystemExit propagate.
            return PartialEvalResult.error()
        return env.partial_eval_hook.resolve(attr_py_value)


class TemplateCallLiveValueRef(LiveValueRef):
    """Custom LiveValueRef that resolves calls to a func_template_call op."""
    __slots__ = ["callee_name"]

    def __init__(self, callee_name, live_value):
        super().__init__(live_value)
        self.callee_name = callee_name

    def resolve_call(self, env: "Environment", args,
                     keywords) -> PartialEvalResult:
        """Emits a func_template_call op for a call with the given arguments.

        Args:
          env: Environment providing the dialect helper (env.ir_h).
          args: Positional argument values.
          keywords: Iterable of (name, value) pairs for keyword arguments.
        Returns:
          A YIELDS_IR_VALUE result wrapping the op's result.
        """
        # Keyword values are appended after the positional values; their names
        # are passed separately as 'arg_names'.
        linear_args = list(args)
        kw_arg_names = []
        for kw_name, kw_value in keywords:
            kw_arg_names.append(kw_name)
            linear_args.append(kw_value)

        ir_h = env.ir_h
        result_ir_value = ir_h.basicpy_func_template_call_op(
            result_type=ir_h.basicpy_UnknownType,
            callee_symbol=self.callee_name,
            args=linear_args,
            arg_names=kw_arg_names).result
        return PartialEvalResult.yields_ir_value(result_ir_value)


class PartialEvalHook:
    """Owned by an environment to customize partial evaluation."""
    __slots__ = [
        "_value_map",
    ]

    def __init__(self):
        super().__init__()
        self._value_map = PyValueMap()

    def resolve(self, py_value) -> PartialEvalResult:
        """Performs partial evaluation on a python value.

        The binding registered for 'py_value' in the value map decides the
        outcome:
          - no binding: the value is wrapped in a plain LiveValueRef.
          - a LiveValueRef: it is yielded as the live value.
          - a PartialEvalResult: it is returned as is.
          - anything else: it is called with 'py_value' and must return a
            PartialEvalResult; any Exception raised while doing so is
            returned as an ERROR result.
        """
        binding = self._value_map.lookup(py_value)
        if binding is None:
            logging.debug("PARTIAL EVAL RESOLVE {}: Passthrough", py_value)
            return PartialEvalResult.yields_live_value(LiveValueRef(py_value))
        if isinstance(binding, LiveValueRef):
            logging.debug("PARTIAL EVAL RESOLVE {}: {}", py_value, binding)
            return PartialEvalResult.yields_live_value(binding)
        if isinstance(binding, PartialEvalResult):
            return binding

        # Attempt to call.
        try:
            result = binding(py_value)
            # Explicit raise instead of an assert statement so the check
            # survives -O; it is still converted to an ERROR result below.
            if not isinstance(result, PartialEvalResult):
                raise AssertionError(
                    "Expected PartialEvalResult but got {}".format(result))
            logging.debug("PARTIAL EVAL RESOLVE {}: {}", py_value, result)
            return result
        except Exception:
            return PartialEvalResult.error()

    def _bind(self,
              binding,
              *,
              for_ref=_Unspec,
              for_type=_Unspec,
              for_predicate=_Unspec):
        """Registers 'binding' in the value map.

        The first supplied of 'for_ref', 'for_type' and 'for_predicate' (in
        that order) selects how the binding is keyed.

        Raises:
          ValueError if none of the three keywords is supplied.
        """
        if for_ref is not _Unspec:
            self._value_map.bind_reference(for_ref, binding)
        elif for_type is not _Unspec:
            self._value_map.bind_type(for_type, binding)
        elif for_predicate is not _Unspec:
            self._value_map.bind_predicate(for_predicate, binding)
        else:
            raise ValueError(
                "Must specify one of 'for_ref', 'for_type' or 'for_predicate'")

    def enable_getattr(self, **kwargs):
        """Enables partial evaluation of getattr.

        'kwargs' are forwarded to _bind() (for_ref/for_type/for_predicate).
        """
        self._bind(
            lambda pv: PartialEvalResult.yields_live_value(
                ResolveAttrLiveValueRef(pv)), **kwargs)

    def enable_template_call(self, callee_name, **kwargs):
        """Enables a global template call.

        'kwargs' are forwarded to _bind() (for_ref/for_type/for_predicate).
        """
        self._bind(
            lambda pv: PartialEvalResult.yields_live_value(
                TemplateCallLiveValueRef(callee_name, pv)), **kwargs)


################################################################################
# Environment
# Top level instance encapsulating access to runtime state.
################################################################################


class Environment(NameResolver):
    """Manages access to the environment of a code region.

    This encapsulates name lookup, access to the containing module, etc.
    """
    __slots__ = [
        "ir_h",
        "_name_resolvers",
        "target",
        "value_coder",
        "partial_eval_hook",
    ]

    def __init__(self,
                 ir_h: ir.DialectHelper,
                 *,
                 target: Target,
                 name_resolvers=(),
                 value_coder,
                 partial_eval_hook=None):
        super().__init__()
        self.ir_h = ir_h
        self.target = target
        self._name_resolvers = name_resolvers
        self.value_coder = value_coder
        # A default hook is created when the caller does not supply one.
        self.partial_eval_hook = (partial_eval_hook
                                  if partial_eval_hook else PartialEvalHook())

    @classmethod
    def for_const_global_function(cls, ir_h: ir.DialectHelper, f, *,
                                  parameter_bindings, **kwargs):
        """Helper to generate an environment for a global function.

        This is a helper for the very common case and will be wholly
        insufficient for advanced cases, including mutable global state,
        closures, etc. Globals from the module are considered immutable.

        Args:
          ir_h: The dialect helper for the new environment.
          f: The function; must provide __code__ and __globals__.
          parameter_bindings: Iterable of (name, value) pairs stored into the
            function's locals.
          **kwargs: Forwarded to the Environment constructor.
        Raises:
          AssertionError if 'f' lacks the required function attributes or a
          bound parameter name is not one of the function's locals.
        """
        try:
            code = f.__code__
            globals_dict = f.__globals__
        except AttributeError:
            # Raised explicitly (rather than via `assert False`) so that the
            # check is not removed when Python runs with -O.
            raise AssertionError(
                "Function {} does not have required user-defined function attributes"
                .format(f))
        builtins_module = globals_dict["__builtins__"]

        # Locals resolver.
        # Note that co_varnames should include both parameter and local names.
        locals_resolver = LocalNameResolver(code.co_varnames)
        resolvers = (
            locals_resolver,
            ConstModuleNameResolver(globals_dict, as_dict=True),
            # '__builtins__' in a globals dict may be either the builtins
            # module or that module's __dict__ (a CPython implementation
            # detail), so pick the lookup mode that matches what was found.
            ConstModuleNameResolver(builtins_module,
                                    as_dict=isinstance(builtins_module, dict)),
        )
        env = cls(ir_h, name_resolvers=resolvers, **kwargs)

        # Bind parameters.
        for name, value in parameter_bindings:
            logging.debug("STORE PARAM: {} <- {}", name, value)
            locals_resolver.checked_lookup(name).store(env, value)
        return env

    def lookup(self, name) -> Optional[NameReference]:
        """Returns the first reference any stacked resolver has for 'name'.

        Resolvers are consulted in order; None is returned if none of them
        knows the name.
        """
        for resolver in self._name_resolvers:
            ref = resolver.lookup(name)
            if ref is not None:
                return ref
        return None


################################################################################
# Standard name resolvers
################################################################################


class LocalNameReference(NameReference):
    """Holds an association between a name and SSA value."""
    __slots__ = [
        "_current_value",
    ]

    def __init__(self, name, initial_value=None):
        super().__init__(name)
        self._current_value = initial_value

    def load(self, env: "Environment") -> PartialEvalResult:
        """Yields the current value, or an ERROR result if none is stored."""
        if self._current_value is None:
            return PartialEvalResult.error_message(
                "Attempt to access local '{}' before assignment".format(
                    self.name))
        return PartialEvalResult.yields_ir_value(self._current_value)

    def store(self, env: "Environment", value: ir.Value):
        """Replaces the current value with 'value'."""
        self._current_value = value

    def __repr__(self):
        return "<LocalNameReference({})>".format(self.name)


class LocalNameResolver(NameResolver):
    """Resolves names in a local cache of SSA values.

    This is used to manage locals and arguments (that are not referenced
    through a closure).
    """
    __slots__ = [
        "_name_refs",
    ]

    def __init__(self, names):
        super().__init__()
        # One (initially unassigned) reference per known local name.
        self._name_refs = {name: LocalNameReference(name) for name in names}

    def lookup(self, name) -> Optional[NameReference]:
        """Returns the reference for local 'name', or None if unknown."""
        return self._name_refs.get(name)


class ConstNameReference(NameReference):
    """Represents a name/value mapping that will emit as a constant."""
    __slots__ = [
        "_py_value",
    ]

    def __init__(self, name, py_value):
        super().__init__(name)
        self._py_value = py_value

    def load(self, env: "Environment") -> PartialEvalResult:
        """Resolves the python value via the environment's partial eval hook."""
        return env.partial_eval_hook.resolve(self._py_value)

    def __repr__(self):
        return "<ConstNameReference({}={})>".format(self.name, self._py_value)


class ConstModuleNameResolver(NameResolver):
    """Resolves names from a module by treating them as immutable and loading
    them as constants into a function scope.
    """
    __slots__ = [
        "_as_dict",
        "module",
    ]

    def __init__(self, module, *, as_dict=False):
        super().__init__()
        self.module = module
        # When True, 'module' is indexed like a dict instead of being read
        # with getattr().
        self._as_dict = as_dict

    def lookup(self, name) -> Optional[NameReference]:
        """Returns a ConstNameReference for 'name', or None if not present."""
        try:
            if self._as_dict:
                py_value = self.module[name]
            else:
                py_value = getattr(self.module, name)
        except (KeyError, AttributeError):
            return None
        return ConstNameReference(name, py_value)


################################################################################
# Standard value coders
################################################################################


class BuiltinsValueCoder(ValueCoder):
    """Value coder for builtin python types."""
    __slots__ = []

    def create_const(self, env: "Environment", py_value):
        """Emits a constant op for a builtin python value.

        Handles True/False, None, int, float, str, bytes and Ellipsis.

        Returns:
          The result of the emitted op, or NotImplemented for any other
          value.
        """
        ir_h = env.ir_h
        ir_c = ir_h.context
        # The bool singletons are matched by identity before the int check
        # because bool is a subclass of int.
        if py_value is True:
            return ir_h.basicpy_bool_constant_op(True).result
        elif py_value is False:
            return ir_h.basicpy_bool_constant_op(False).result
        elif py_value is None:
            return ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
        elif isinstance(py_value, int):
            ir_type = env.target.impl_int_type
            ir_attr = ir_c.integer_attr(ir_type, py_value)
            return ir_h.constant_op(ir_type, ir_attr).result
        elif isinstance(py_value, float):
            ir_type = env.target.impl_float_type
            ir_attr = ir_c.float_attr(ir_type, py_value)
            return ir_h.constant_op(ir_type, ir_attr).result
        elif isinstance(py_value, str):
            return ir_h.basicpy_str_constant_op(py_value).result
        elif isinstance(py_value, bytes):
            return ir_h.basicpy_bytes_constant_op(py_value).result
        elif isinstance(py_value, type(...)):
            return ir_h.basicpy_singleton_op(ir_h.basicpy_EllipsisType).result
        return NotImplemented


class ValueCoderChain(ValueCoder):
    """Codes values by delegating to sub-coders in order."""
    __slots__ = ["_sub_coders"]

    def __init__(self, sub_coders):
        super().__init__()
        self._sub_coders = sub_coders

    def create_const(self, env: "Environment", py_value):
        """Returns the first sub-coder result that is not NotImplemented.

        Returns NotImplemented if no sub-coder handles the value.
        """
        for sc in self._sub_coders:
            result = sc.create_const(env, py_value)
            if result is not NotImplemented:
                return result
        return NotImplemented