mirror of https://github.com/llvm/torch-mlir
Refactor environment.py into components.
* Creates a new top level Configuration class * Adds a module for creating test configs, getting some hard coding out of core classespull/1/head
parent
d6b428fb60
commit
bccfd5f6fc
|
@ -3,6 +3,7 @@
|
|||
from npcomp.compiler.backend import iree
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler import logging
|
||||
from npcomp.compiler import test_config
|
||||
from npcomp.compiler.target import *
|
||||
|
||||
# TODO: This should all exist in a high level API somewhere.
|
||||
|
@ -13,7 +14,8 @@ logging.enable()
|
|||
|
||||
|
||||
def compile_function(f):
|
||||
fe = ImportFrontend(target_factory=GenericTarget32)
|
||||
fe = ImportFrontend(config=test_config.create_test_config(
|
||||
target_factory=GenericTarget32))
|
||||
fe.import_global_function(f)
|
||||
compiler = iree.CompilerBackend()
|
||||
vm_blob = compiler.compile(fe.ir_module)
|
||||
|
|
|
@ -1,14 +1,8 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler import test_config
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend()
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
|
||||
# Full checking for add. Others just check validity.
|
||||
|
|
|
@ -1,14 +1,8 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler import test_config
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend()
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
|
||||
# CHECK-LABEL: func @logical_and
|
||||
|
|
|
@ -1,14 +1,8 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler import test_config
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend()
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
|
||||
# CHECK-LABEL: func @binary_lt_
|
||||
|
|
|
@ -1,14 +1,8 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler import test_config
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend()
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
|
||||
# CHECK-LABEL: func @integer_constants
|
||||
|
|
|
@ -2,16 +2,11 @@
|
|||
|
||||
# Subset of constant tests which verify against a GenericTarget32.
|
||||
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler import test_config
|
||||
from npcomp.compiler.target import *
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend(target_factory=GenericTarget32)
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
import_global = test_config.create_import_dump_decorator(
|
||||
target_factory=GenericTarget32)
|
||||
|
||||
|
||||
# CHECK-LABEL: func @integer_constants
|
||||
|
|
|
@ -2,15 +2,9 @@
|
|||
|
||||
import collections
|
||||
import math
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler import test_config
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend()
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
|
||||
# CHECK-LABEL: func @module_constant
|
||||
|
|
|
@ -1,15 +1,8 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file -basicpy-type-inference -convert-basicpy-to-std | FileCheck %s --dump-input=fail
|
||||
|
||||
from npcomp.compiler.frontend import *
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend()
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
from npcomp.compiler import test_config
|
||||
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
################################################################################
|
||||
# Integer tests
|
||||
|
|
|
@ -1,16 +1,9 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||
"""Module docstring."""
|
||||
|
||||
from npcomp.compiler.frontend import *
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend()
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
from npcomp.compiler import test_config
|
||||
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
OUTER_ONE = 1
|
||||
OUTER_STRING = "Hello"
|
||||
|
|
|
@ -1,14 +1,8 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler import test_config
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend()
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
|
||||
# CHECK-LABEL: func @positional_args
|
||||
|
|
|
@ -1,15 +1,9 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||
|
||||
import math
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler import test_config
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend()
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
|
||||
# CHECK-LABEL: func @call_ceil_positional
|
||||
|
|
|
@ -1,14 +1,8 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file -basicpy-type-inference | FileCheck %s --dump-input=fail
|
||||
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler import test_config
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend()
|
||||
fe.import_global_function(f)
|
||||
print(fe.ir_module.to_asm())
|
||||
print("// -----")
|
||||
return f
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
|
||||
# CHECK-LABEL: func @arithmetic_expression
|
||||
|
|
|
@ -1,502 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from collections import namedtuple
|
||||
from enum import Enum
|
||||
import inspect
|
||||
import sys
|
||||
from typing import Optional, Union
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
|
||||
from . import logging
|
||||
from .py_value_utils import *
|
||||
from .target import *
|
||||
|
||||
__all__ = [
|
||||
"BuiltinsValueCoder",
|
||||
"Environment",
|
||||
"LiveValueRef",
|
||||
"NameReference",
|
||||
"NameResolver",
|
||||
"PartialEvalResult",
|
||||
"PartialEvalType",
|
||||
"PartialEvalHook",
|
||||
"ResolveAttrLiveValueRef",
|
||||
"ValueCoder",
|
||||
"ValueCoderChain",
|
||||
]
|
||||
|
||||
_Unspec = object()
|
||||
|
||||
################################################################################
|
||||
# Interfaces and base classes
|
||||
################################################################################
|
||||
|
||||
|
||||
class ValueCoder:
|
||||
"""Encodes values in various ways.
|
||||
|
||||
Instances are designed to be daisy-chained and should ignore types that they
|
||||
don't understand. Functions return NotImplemented if they cannot handle a
|
||||
case locally.
|
||||
"""
|
||||
__slots__ = []
|
||||
|
||||
def create_const(self, env: "Environment", py_value):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class NameReference:
|
||||
"""Abstract base class for performing operations on a name."""
|
||||
__slots__ = [
|
||||
"name",
|
||||
]
|
||||
|
||||
def __init__(self, name):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
def load(self, env: "Environment",
|
||||
ir_h: ir.DialectHelper) -> "PartialEvalResult":
|
||||
"""Loads the IR Value associated with the name.
|
||||
|
||||
The load may either be direct, returning an existing value or
|
||||
side-effecting, causing a read from an external context.
|
||||
|
||||
Args:
|
||||
ir_h: The dialect helper used to emit code.
|
||||
Returns:
|
||||
A partial evaluation result.
|
||||
"""
|
||||
return PartialEvalResult.not_evaluated()
|
||||
|
||||
def store(self, env: "Environment", value: ir.Value, ir_h: ir.DialectHelper):
|
||||
"""Stores a new value into the name.
|
||||
|
||||
A subsequent call to 'load' should yield the same value, subject to
|
||||
typing constraints on value equality.
|
||||
|
||||
Args:
|
||||
value: The new value to store into the name.
|
||||
ir_h: The dialect helper used to emit code.
|
||||
Raises:
|
||||
NotImplementedError if store is not supported for this name.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class NameResolver:
|
||||
"""Abstract base class that can resolve a name.
|
||||
|
||||
Name resolvers are typically stacked.
|
||||
"""
|
||||
|
||||
def checked_lookup(self, name):
|
||||
ref = self.lookup(name)
|
||||
assert ref is not None, "Lookup of name {} is required".format(name)
|
||||
return ref
|
||||
|
||||
def lookup(self, name) -> Optional[NameReference]:
|
||||
return None
|
||||
|
||||
|
||||
################################################################################
|
||||
# Partial evaluation
|
||||
# When the compiler is extracting from a running program, it is likely that
|
||||
# evaluations produce live values which can be further partially evaluated
|
||||
# at import time, in the context of the running instance (versus emitting
|
||||
# program IR to do so). This behavior is controlled through a PartialEvalHook
|
||||
# on the environment.
|
||||
################################################################################
|
||||
|
||||
|
||||
class PartialEvalType(Enum):
|
||||
# Could not be evaluated immediately and the operation should be
|
||||
# code-generated. yields NotImplemented.
|
||||
NOT_EVALUATED = 0
|
||||
|
||||
# Yields a LiveValueRef
|
||||
YIELDS_LIVE_VALUE = 1
|
||||
|
||||
# Yields an IR value
|
||||
YIELDS_IR_VALUE = 2
|
||||
|
||||
# Evaluation yielded an error (yields contains exc_info from sys.exc_info()).
|
||||
ERROR = 3
|
||||
|
||||
|
||||
class PartialEvalResult(namedtuple("PartialEvalResult", "type,yields")):
|
||||
"""Encapsulates the result of a partial evaluation."""
|
||||
|
||||
@classmethod
|
||||
def not_evaluated(cls):
|
||||
return cls(PartialEvalType.NOT_EVALUATED, NotImplemented)
|
||||
|
||||
@classmethod
|
||||
def yields_live_value(cls, live_value):
|
||||
assert isinstance(live_value, LiveValueRef)
|
||||
return cls(PartialEvalType.YIELDS_LIVE_VALUE, live_value)
|
||||
|
||||
@classmethod
|
||||
def yields_ir_value(cls, ir_value):
|
||||
assert isinstance(ir_value, ir.Value)
|
||||
return cls(PartialEvalType.YIELDS_IR_VALUE, ir_value)
|
||||
|
||||
@classmethod
|
||||
def error(cls):
|
||||
return cls(PartialEvalType.ERROR, sys.exc_info())
|
||||
|
||||
@classmethod
|
||||
def error_message(cls, message):
|
||||
try:
|
||||
raise RuntimeError(message)
|
||||
except RuntimeError:
|
||||
return cls.error()
|
||||
|
||||
|
||||
class LiveValueRef:
|
||||
"""Wraps a live value from the containing environment.
|
||||
|
||||
Typically, when expressions encounter a live value, a limited number of
|
||||
partial evaluations can be done against it in place (versus emitting the code
|
||||
to import it and perform the operation). This default base class will not
|
||||
perform any static evaluations.
|
||||
"""
|
||||
__slots__ = [
|
||||
"live_value",
|
||||
]
|
||||
|
||||
def __init__(self, live_value):
|
||||
super().__init__()
|
||||
self.live_value = live_value
|
||||
|
||||
def resolve_getattr(self, env: "Environment", attr_name) -> PartialEvalResult:
|
||||
"""Gets a named attribute from the live value."""
|
||||
return PartialEvalResult.not_evaluated()
|
||||
|
||||
def resolve_call(self, env: "Environment", args,
|
||||
keywords) -> PartialEvalResult:
|
||||
"""Resolves a function call given 'args' and 'keywords'."""
|
||||
return PartialEvalResult.not_evaluated()
|
||||
|
||||
def __repr__(self):
|
||||
return "MacroValueRef({}, {})".format(self.__class__.__name__,
|
||||
self.live_value)
|
||||
|
||||
|
||||
class ResolveAttrLiveValueRef(LiveValueRef):
|
||||
"""Custom MacroValueRef that will resolve attributes via getattr."""
|
||||
__slots__ = []
|
||||
|
||||
def resolve_getattr(self, env: "Environment", attr_name) -> PartialEvalResult:
|
||||
logging.debug("RESOLVE_GETATTR '{}' on {}".format(attr_name,
|
||||
self.live_value))
|
||||
try:
|
||||
attr_py_value = getattr(self.live_value, attr_name)
|
||||
except:
|
||||
return PartialEvalResult.error()
|
||||
return env.partial_eval_hook.resolve(attr_py_value)
|
||||
|
||||
|
||||
class TemplateCallLiveValueRef(LiveValueRef):
|
||||
"""Custom LiveValueRef that resolves calls to a func_template_call op."""
|
||||
__slots__ = ["callee_name"]
|
||||
|
||||
def __init__(self, callee_name, live_value):
|
||||
super().__init__(live_value)
|
||||
self.callee_name = callee_name
|
||||
|
||||
def resolve_call(self, env: "Environment", args,
|
||||
keywords) -> PartialEvalResult:
|
||||
linear_args = list(args)
|
||||
kw_arg_names = []
|
||||
for kw_name, kw_value in keywords:
|
||||
kw_arg_names.append(kw_name)
|
||||
linear_args.append(kw_value)
|
||||
|
||||
ir_h = env.ir_h
|
||||
result_ir_value = ir_h.basicpy_func_template_call_op(
|
||||
result_type=ir_h.basicpy_UnknownType,
|
||||
callee_symbol=self.callee_name,
|
||||
args=linear_args,
|
||||
arg_names=kw_arg_names).result
|
||||
return PartialEvalResult.yields_ir_value(result_ir_value)
|
||||
|
||||
|
||||
class PartialEvalHook:
|
||||
"""Owned by an environment to customize partial evaluation."""
|
||||
__slots__ = [
|
||||
"_value_map",
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._value_map = PyValueMap()
|
||||
|
||||
def resolve(self, py_value) -> PartialEvalResult:
|
||||
"""Performs partial evaluation on a python value."""
|
||||
binding = self._value_map.lookup(py_value)
|
||||
if binding is None:
|
||||
logging.debug("PARTIAL EVAL RESOLVE {}: Passthrough", py_value)
|
||||
return PartialEvalResult.yields_live_value(LiveValueRef(py_value))
|
||||
if isinstance(binding, LiveValueRef):
|
||||
logging.debug("PARTIAL EVAL RESOLVE {}: {}", py_value, binding)
|
||||
return PartialEvalResult.yields_live_value(binding)
|
||||
if isinstance(binding, PartialEvalResult):
|
||||
return binding
|
||||
# Attempt to call.
|
||||
try:
|
||||
binding = binding(py_value)
|
||||
assert isinstance(binding, PartialEvalResult), (
|
||||
"Expected PartialEvalResult but got {}".format(binding))
|
||||
logging.debug("PARTIAL EVAL RESOLVE {}: {}", py_value, binding)
|
||||
return binding
|
||||
except:
|
||||
return PartialEvalResult.error()
|
||||
|
||||
def _bind(self,
|
||||
binding,
|
||||
*,
|
||||
for_ref=_Unspec,
|
||||
for_type=_Unspec,
|
||||
for_predicate=_Unspec):
|
||||
if for_ref is not _Unspec:
|
||||
self._value_map.bind_reference(for_ref, binding)
|
||||
elif for_type is not _Unspec:
|
||||
self._value_map.bind_type(for_type, binding)
|
||||
elif for_predicate is not _Unspec:
|
||||
self._value_map.bind_predicate(for_predicate, binding)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Must specify one of 'for_ref', 'for_type' or 'for_predicate")
|
||||
|
||||
def enable_getattr(self, **kwargs):
|
||||
"""Enables partial evaluation of getattr."""
|
||||
self._bind(
|
||||
lambda pv: PartialEvalResult.yields_live_value(
|
||||
ResolveAttrLiveValueRef(pv)), **kwargs)
|
||||
|
||||
def enable_template_call(self, callee_name, **kwargs):
|
||||
""""Enables a global template call."""
|
||||
self._bind(
|
||||
lambda pv: PartialEvalResult.yields_live_value(
|
||||
TemplateCallLiveValueRef(callee_name, pv)), **kwargs)
|
||||
|
||||
|
||||
################################################################################
|
||||
# Environment
|
||||
# Top level instance encapsulating access to runtime state.
|
||||
################################################################################
|
||||
|
||||
|
||||
class Environment(NameResolver):
|
||||
"""Manages access to the environment of a code region.
|
||||
|
||||
This encapsulates name lookup, access to the containing module, etc.
|
||||
"""
|
||||
__slots__ = [
|
||||
"ir_h",
|
||||
"_name_resolvers",
|
||||
"target",
|
||||
"value_coder",
|
||||
"partial_eval_hook",
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
ir_h: ir.DialectHelper,
|
||||
*,
|
||||
target: Target,
|
||||
name_resolvers=(),
|
||||
value_coder,
|
||||
partial_eval_hook=None):
|
||||
super().__init__()
|
||||
self.ir_h = ir_h
|
||||
self.target = target
|
||||
self._name_resolvers = name_resolvers
|
||||
self.value_coder = value_coder
|
||||
self.partial_eval_hook = partial_eval_hook if partial_eval_hook else PartialEvalHook(
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def for_const_global_function(cls, ir_h: ir.DialectHelper, f, *,
|
||||
parameter_bindings, **kwargs):
|
||||
"""Helper to generate an environment for a global function.
|
||||
|
||||
This is a helper for the very common case and will be wholly insufficient
|
||||
for advanced cases, including mutable global state, closures, etc.
|
||||
Globals from the module are considered immutable.
|
||||
"""
|
||||
try:
|
||||
code = f.__code__
|
||||
globals_dict = f.__globals__
|
||||
builtins_module = globals_dict["__builtins__"]
|
||||
except AttributeError:
|
||||
assert False, (
|
||||
"Function {} does not have required user-defined function attributes".
|
||||
format(f))
|
||||
|
||||
# Locals resolver.
|
||||
# Note that co_varnames should include both parameter and local names.
|
||||
locals_resolver = LocalNameResolver(code.co_varnames)
|
||||
resolvers = (
|
||||
locals_resolver,
|
||||
ConstModuleNameResolver(globals_dict, as_dict=True),
|
||||
ConstModuleNameResolver(builtins_module),
|
||||
)
|
||||
env = cls(ir_h, name_resolvers=resolvers, **kwargs)
|
||||
|
||||
# Bind parameters.
|
||||
for name, value in parameter_bindings:
|
||||
logging.debug("STORE PARAM: {} <- {}", name, value)
|
||||
locals_resolver.checked_lookup(name).store(env, value)
|
||||
return env
|
||||
|
||||
def lookup(self, name) -> Optional[NameReference]:
|
||||
for resolver in self._name_resolvers:
|
||||
ref = resolver.lookup(name)
|
||||
if ref is not None:
|
||||
return ref
|
||||
return None
|
||||
|
||||
|
||||
################################################################################
|
||||
# Standard name resolvers
|
||||
################################################################################
|
||||
|
||||
|
||||
class LocalNameReference(NameReference):
|
||||
"""Holds an association between a name and SSA value."""
|
||||
__slots__ = [
|
||||
"_current_value",
|
||||
]
|
||||
|
||||
def __init__(self, name, initial_value=None):
|
||||
super().__init__(name)
|
||||
self._current_value = initial_value
|
||||
|
||||
def load(self, env: "Environment") -> PartialEvalResult:
|
||||
if self._current_value is None:
|
||||
return PartialEvalResult.error_message(
|
||||
"Attempt to access local '{}' before assignment".format(self.name))
|
||||
return PartialEvalResult.yields_ir_value(self._current_value)
|
||||
|
||||
def store(self, env: "Environment", value: ir.Value):
|
||||
self._current_value = value
|
||||
|
||||
def __repr__(self):
|
||||
return "<LocalNameReference({})>".format(self.name)
|
||||
|
||||
|
||||
class LocalNameResolver(NameResolver):
|
||||
"""Resolves names in a local cache of SSA values.
|
||||
|
||||
This is used to manage locals and arguments (that are not referenced through
|
||||
a closure).
|
||||
"""
|
||||
__slots__ = [
|
||||
"_name_refs",
|
||||
]
|
||||
|
||||
def __init__(self, names):
|
||||
super().__init__()
|
||||
self._name_refs = {name: LocalNameReference(name) for name in names}
|
||||
|
||||
def lookup(self, name) -> Optional[NameReference]:
|
||||
return self._name_refs.get(name)
|
||||
|
||||
|
||||
class ConstNameReference(NameReference):
|
||||
"""Represents a name/value mapping that will emit as a constant."""
|
||||
__slots__ = [
|
||||
"_py_value",
|
||||
]
|
||||
|
||||
def __init__(self, name, py_value):
|
||||
super().__init__(name)
|
||||
self._py_value = py_value
|
||||
|
||||
def load(self, env: "Environment") -> PartialEvalResult:
|
||||
return env.partial_eval_hook.resolve(self._py_value)
|
||||
|
||||
def __repr__(self):
|
||||
return "<ConstNameReference({}={})>".format(self.name, self._py_value)
|
||||
|
||||
|
||||
class ConstModuleNameResolver(NameResolver):
|
||||
"""Resolves names from a module by treating them as immutable and loading
|
||||
them as constants into a function scope.
|
||||
"""
|
||||
__slots__ = [
|
||||
"_as_dict",
|
||||
"module",
|
||||
]
|
||||
|
||||
def __init__(self, module, *, as_dict=False):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self._as_dict = as_dict
|
||||
|
||||
def lookup(self, name) -> Optional[NameReference]:
|
||||
if self._as_dict:
|
||||
if name in self.module:
|
||||
py_value = self.module[name]
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
py_value = getattr(self.module, name)
|
||||
except AttributeError:
|
||||
return None
|
||||
return ConstNameReference(name, py_value)
|
||||
|
||||
|
||||
################################################################################
|
||||
# Standard value coders
|
||||
################################################################################
|
||||
|
||||
|
||||
class BuiltinsValueCoder:
|
||||
"""Value coder for builtin python types."""
|
||||
__slots__ = []
|
||||
|
||||
def create_const(self, env: "Environment", py_value):
|
||||
ir_h = env.ir_h
|
||||
ir_c = ir_h.context
|
||||
if py_value is True:
|
||||
return ir_h.basicpy_bool_constant_op(True).result
|
||||
elif py_value is False:
|
||||
return ir_h.basicpy_bool_constant_op(False).result
|
||||
elif py_value is None:
|
||||
return ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
||||
elif isinstance(py_value, int):
|
||||
ir_type = env.target.impl_int_type
|
||||
ir_attr = ir_c.integer_attr(ir_type, py_value)
|
||||
return ir_h.constant_op(ir_type, ir_attr).result
|
||||
elif isinstance(py_value, float):
|
||||
ir_type = env.target.impl_float_type
|
||||
ir_attr = ir_c.float_attr(ir_type, py_value)
|
||||
return ir_h.constant_op(ir_type, ir_attr).result
|
||||
elif isinstance(py_value, str):
|
||||
return ir_h.basicpy_str_constant_op(py_value).result
|
||||
elif isinstance(py_value, bytes):
|
||||
return ir_h.basicpy_bytes_constant_op(py_value).result
|
||||
elif isinstance(py_value, type(...)):
|
||||
return ir_h.basicpy_singleton_op(ir_h.basicpy_EllipsisType).result
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class ValueCoderChain(ValueCoder):
|
||||
"""Codes values by delegating to sub-coders in order."""
|
||||
__slots__ = ["_sub_coders"]
|
||||
|
||||
def __init__(self, sub_coders):
|
||||
self._sub_coders = sub_coders
|
||||
|
||||
def create_const(self, env: "Environment", py_value):
|
||||
for sc in self._sub_coders:
|
||||
result = sc.create_const(env, py_value)
|
||||
if result is not NotImplemented:
|
||||
return result
|
||||
return NotImplemented
|
|
@ -14,8 +14,10 @@ from _npcomp.mlir.dialect import ScfDialectHelper
|
|||
from npcomp.dialect import Numpy
|
||||
|
||||
from . import logging
|
||||
from .environment import *
|
||||
from .importer import *
|
||||
from .interfaces import *
|
||||
from .name_resolver_base import *
|
||||
from .value_coder_base import *
|
||||
from .target import *
|
||||
|
||||
__all__ = [
|
||||
|
@ -28,26 +30,20 @@ class ImportFrontend:
|
|||
__slots__ = [
|
||||
"_ir_context",
|
||||
"_ir_module",
|
||||
"_helper",
|
||||
"_target_factory",
|
||||
"_value_coder",
|
||||
"_partial_eval_hook",
|
||||
"_ir_h",
|
||||
"_config",
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
ir_context: ir.MLIRContext = None,
|
||||
*,
|
||||
target_factory: TargetFactory = GenericTarget64,
|
||||
value_coder: Optional[ValueCoder] = None,
|
||||
partial_eval_hook: Optional[PartialEvalHook] = None):
|
||||
config: Configuration,
|
||||
ir_context: ir.MLIRContext = None):
|
||||
super().__init__()
|
||||
self._ir_context = ir.MLIRContext() if not ir_context else ir_context
|
||||
self._ir_module = self._ir_context.new_module()
|
||||
self._helper = AllDialectHelper(self._ir_context,
|
||||
ir.OpBuilder(self._ir_context))
|
||||
self._target_factory = target_factory
|
||||
self._value_coder = value_coder if value_coder else BuiltinsValueCoder()
|
||||
self._partial_eval_hook = (partial_eval_hook if partial_eval_hook else
|
||||
build_default_partial_eval_hook())
|
||||
self._ir_h = AllDialectHelper(self._ir_context,
|
||||
ir.OpBuilder(self._ir_context))
|
||||
self._config = config
|
||||
|
||||
@property
|
||||
def ir_context(self):
|
||||
|
@ -59,11 +55,7 @@ class ImportFrontend:
|
|||
|
||||
@property
|
||||
def ir_h(self):
|
||||
return self._helper
|
||||
|
||||
@property
|
||||
def partial_eval_hook(self):
|
||||
return self._partial_eval_hook
|
||||
return self._ir_h
|
||||
|
||||
def import_global_function(self, f):
|
||||
"""Imports a global function.
|
||||
|
@ -80,7 +72,7 @@ class ImportFrontend:
|
|||
h = self.ir_h
|
||||
ir_c = self.ir_context
|
||||
ir_m = self.ir_module
|
||||
target = self._target_factory(h)
|
||||
target = self._config.target_factory(h)
|
||||
filename = inspect.getsourcefile(f)
|
||||
source_lines, start_lineno = inspect.getsourcelines(f)
|
||||
source = "".join(source_lines)
|
||||
|
@ -115,26 +107,56 @@ class ImportFrontend:
|
|||
ir_f_type,
|
||||
create_entry_block=True,
|
||||
attrs=attrs)
|
||||
env = Environment.for_const_global_function(
|
||||
h,
|
||||
f,
|
||||
parameter_bindings=zip(f_params.keys(), ir_f.first_block.args),
|
||||
value_coder=self._value_coder,
|
||||
target=target,
|
||||
partial_eval_hook=self._partial_eval_hook)
|
||||
env = self._create_const_global_env(f,
|
||||
parameter_bindings=zip(
|
||||
f_params.keys(),
|
||||
ir_f.first_block.args),
|
||||
target=target)
|
||||
fctx = FunctionContext(ir_c=ir_c,
|
||||
ir_f=ir_f,
|
||||
ir_h=h,
|
||||
filename_ident=filename_ident,
|
||||
target=target,
|
||||
environment=env)
|
||||
|
||||
fdimport = FunctionDefImporter(fctx, ast_fd)
|
||||
fdimport.import_body()
|
||||
return ir_f
|
||||
|
||||
def _create_const_global_env(self, f, parameter_bindings, target):
|
||||
"""Helper to generate an environment for a global function.
|
||||
|
||||
This is a helper for the very common case and will be wholly insufficient
|
||||
for advanced cases, including mutable global state, closures, etc.
|
||||
Globals from the module are considered immutable.
|
||||
"""
|
||||
ir_h = self._ir_h
|
||||
try:
|
||||
code = f.__code__
|
||||
globals_dict = f.__globals__
|
||||
builtins_module = globals_dict["__builtins__"]
|
||||
except AttributeError:
|
||||
assert False, (
|
||||
"Function {} does not have required user-defined function attributes".
|
||||
format(f))
|
||||
|
||||
# Locals resolver.
|
||||
# Note that co_varnames should include both parameter and local names.
|
||||
locals_resolver = LocalNameResolver(code.co_varnames)
|
||||
resolvers = (
|
||||
locals_resolver,
|
||||
ConstModuleNameResolver(globals_dict, as_dict=True),
|
||||
ConstModuleNameResolver(builtins_module),
|
||||
)
|
||||
env = Environment(config=self._config, ir_h=ir_h, name_resolvers=resolvers)
|
||||
|
||||
# Bind parameters.
|
||||
for name, value in parameter_bindings:
|
||||
logging.debug("STORE PARAM: {} <- {}", name, value)
|
||||
locals_resolver.checked_resolve_name(name).store(env, value)
|
||||
return env
|
||||
|
||||
def _resolve_signature_annotation(self, target: Target, annot):
|
||||
ir_h = self._helper
|
||||
ir_h = self._ir_h
|
||||
if annot is inspect.Signature.empty:
|
||||
return ir_h.basicpy_UnknownType
|
||||
|
||||
|
@ -164,20 +186,3 @@ class AllDialectHelper(Numpy.DialectHelper, ScfDialectHelper):
|
|||
def __init__(self, *args, **kwargs):
|
||||
Numpy.DialectHelper.__init__(self, *args, **kwargs)
|
||||
ScfDialectHelper.__init__(self, *args, **kwargs)
|
||||
|
||||
|
||||
def build_default_partial_eval_hook() -> PartialEvalHook:
|
||||
pe = PartialEvalHook()
|
||||
### Modules
|
||||
pe.enable_getattr(for_type=ast.__class__) # The module we use is arbitrary.
|
||||
|
||||
### Tuples
|
||||
# Enable attribute resolution on tuple, which includes namedtuple (which is
|
||||
# really what we want).
|
||||
pe.enable_getattr(for_type=tuple)
|
||||
|
||||
### Temp: resolve a function to a template call for testing
|
||||
import math
|
||||
pe.enable_template_call("__global$math.ceil", for_ref=math.ceil)
|
||||
pe.enable_template_call("__global$math.isclose", for_ref=math.isclose)
|
||||
return pe
|
||||
|
|
|
@ -11,8 +11,7 @@ import traceback
|
|||
from _npcomp.mlir import ir
|
||||
|
||||
from . import logging
|
||||
from .environment import *
|
||||
from .target import *
|
||||
from .interfaces import *
|
||||
|
||||
__all__ = [
|
||||
"FunctionContext",
|
||||
|
@ -27,16 +26,14 @@ class FunctionContext:
|
|||
"ir_c",
|
||||
"ir_f",
|
||||
"ir_h",
|
||||
"target",
|
||||
"filename_ident",
|
||||
"environment",
|
||||
]
|
||||
|
||||
def __init__(self, ir_c, ir_f, ir_h, target, filename_ident, environment):
|
||||
def __init__(self, ir_c, ir_f, ir_h, filename_ident, environment):
|
||||
self.ir_c = ir_c
|
||||
self.ir_f = ir_f
|
||||
self.ir_h = ir_h
|
||||
self.target = target
|
||||
self.filename_ident = filename_ident
|
||||
self.environment = environment
|
||||
|
||||
|
@ -68,7 +65,7 @@ class FunctionContext:
|
|||
|
||||
def lookup_name(self, name) -> NameReference:
|
||||
"""Lookup a name in the environment, requiring it to have evaluated."""
|
||||
ref = self.environment.lookup(name)
|
||||
ref = self.environment.resolve_name(name)
|
||||
if ref is None:
|
||||
self.abort("Could not resolve referenced name '{}'".format(name))
|
||||
logging.debug("Map name({}) -> {}", name, ref)
|
||||
|
@ -77,7 +74,7 @@ class FunctionContext:
|
|||
def emit_const_value(self, py_value) -> ir.Value:
|
||||
"""Codes a value as a constant, returning an ir Value."""
|
||||
env = self.environment
|
||||
result = env.value_coder.create_const(env, py_value)
|
||||
result = env.code_py_value_as_const(py_value)
|
||||
if result is NotImplemented:
|
||||
self.abort("Cannot code python value as constant: {}".format(py_value))
|
||||
return result
|
||||
|
@ -214,7 +211,7 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
|
||||
def emit_constant(self, value):
|
||||
env = self.fctx.environment
|
||||
ir_const_value = env.value_coder.create_const(env, value)
|
||||
ir_const_value = env.code_py_value_as_const(value)
|
||||
if ir_const_value is NotImplemented:
|
||||
self.fctx.abort("unknown constant type '%r'" % (value,))
|
||||
self.value = ir_const_value
|
||||
|
|
|
@ -0,0 +1,295 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Base classes and interfaces."""
|
||||
|
||||
from collections import namedtuple
|
||||
from enum import Enum
|
||||
import sys
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
from .target import *
|
||||
|
||||
__all__ = [
|
||||
"Configuration",
|
||||
"Environment",
|
||||
"NameReference",
|
||||
"NameResolver",
|
||||
"PartialEvalHook",
|
||||
"PartialEvalType",
|
||||
"PartialEvalResult",
|
||||
"LiveValueRef",
|
||||
"ValueCoder",
|
||||
"ValueCoderChain",
|
||||
]
|
||||
|
||||
_NotImplementedType = type(NotImplemented)
|
||||
|
||||
################################################################################
|
||||
# Name resolution
|
||||
################################################################################
|
||||
|
||||
|
||||
class NameReference:
|
||||
"""Abstract base class for performing operations on a name."""
|
||||
__slots__ = [
|
||||
"name",
|
||||
]
|
||||
|
||||
def __init__(self, name):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
def load(self, env: "Environment",
|
||||
ir_h: ir.DialectHelper) -> "PartialEvalResult":
|
||||
"""Loads the IR Value associated with the name.
|
||||
|
||||
The load may either be direct, returning an existing value or
|
||||
side-effecting, causing a read from an external context.
|
||||
|
||||
Args:
|
||||
ir_h: The dialect helper used to emit code.
|
||||
Returns:
|
||||
A partial evaluation result.
|
||||
"""
|
||||
return PartialEvalResult.not_evaluated()
|
||||
|
||||
def store(self, env: "Environment", value: ir.Value, ir_h: ir.DialectHelper):
|
||||
"""Stores a new value into the name.
|
||||
|
||||
A subsequent call to 'load' should yield the same value, subject to
|
||||
typing constraints on value equality.
|
||||
|
||||
Args:
|
||||
value: The new value to store into the name.
|
||||
ir_h: The dialect helper used to emit code.
|
||||
Raises:
|
||||
NotImplementedError if store is not supported for this name.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class NameResolver:
|
||||
"""Abstract base class that can resolve a name.
|
||||
|
||||
Name resolvers are typically stacked.
|
||||
"""
|
||||
__slots__ = []
|
||||
|
||||
def checked_resolve_name(self, name: str) -> Optional[NameReference]:
|
||||
ref = self.resolve_name(name)
|
||||
assert ref is not None, "Lookup of name {} is required".format(name)
|
||||
return ref
|
||||
|
||||
def resolve_name(self, name: str) -> Optional[NameReference]:
|
||||
return None
|
||||
|
||||
|
||||
################################################################################
|
||||
# Value coding
|
||||
# Transforms python values into IR values.
|
||||
################################################################################
|
||||
|
||||
|
||||
class ValueCoder:
|
||||
"""Encodes values in various ways.
|
||||
|
||||
Instances are designed to be daisy-chained and should ignore types that they
|
||||
don't understand. Functions return NotImplemented if they cannot handle a
|
||||
case locally.
|
||||
"""
|
||||
__slots__ = []
|
||||
|
||||
def code_py_value_as_const(self, env: "Environment",
|
||||
py_value) -> Union[_NotImplementedType, ir.Value]:
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class ValueCoderChain(ValueCoder):
|
||||
"""Codes values by delegating to sub-coders in order."""
|
||||
__slots__ = ["_sub_coders"]
|
||||
|
||||
def __init__(self, sub_coders: Sequence[ValueCoder]):
|
||||
self._sub_coders = tuple(sub_coders)
|
||||
|
||||
def code_py_value_as_const(self, env: "Environment",
|
||||
py_value) -> Union[_NotImplementedType, ir.Value]:
|
||||
for sc in self._sub_coders:
|
||||
result = sc.code_py_value_as_const(env, py_value)
|
||||
if result is not NotImplemented:
|
||||
return result
|
||||
return NotImplemented
|
||||
|
||||
|
||||
################################################################################
|
||||
# Partial evaluation
|
||||
# When the compiler is extracting from a running program, it is likely that
|
||||
# evaluations produce live values which can be further partially evaluated
|
||||
# at import time, in the context of the running instance (versus emitting
|
||||
# program IR to do so). This behavior is controlled through a PartialEvalHook
|
||||
# on the environment.
|
||||
################################################################################
|
||||
|
||||
|
||||
class PartialEvalType(Enum):
|
||||
# Could not be evaluated immediately and the operation should be
|
||||
# code-generated. yields NotImplemented.
|
||||
NOT_EVALUATED = 0
|
||||
|
||||
# Yields a LiveValueRef
|
||||
YIELDS_LIVE_VALUE = 1
|
||||
|
||||
# Yields an IR value
|
||||
YIELDS_IR_VALUE = 2
|
||||
|
||||
# Evaluation yielded an error (yields contains exc_info from sys.exc_info()).
|
||||
ERROR = 3
|
||||
|
||||
|
||||
class PartialEvalResult(namedtuple("PartialEvalResult", "type,yields")):
|
||||
"""Encapsulates the result of a partial evaluation."""
|
||||
|
||||
@staticmethod
|
||||
def not_evaluated() -> "PartialEvalResult":
|
||||
return PartialEvalResult(PartialEvalType.NOT_EVALUATED, NotImplemented)
|
||||
|
||||
@staticmethod
|
||||
def yields_live_value(live_value) -> "PartialEvalResult":
|
||||
assert isinstance(live_value, LiveValueRef)
|
||||
return PartialEvalResult(PartialEvalType.YIELDS_LIVE_VALUE, live_value)
|
||||
|
||||
@staticmethod
|
||||
def yields_ir_value(ir_value: ir.Value) -> "PartialEvalResult":
|
||||
assert isinstance(ir_value, ir.Value)
|
||||
return PartialEvalResult(PartialEvalType.YIELDS_IR_VALUE, ir_value)
|
||||
|
||||
@staticmethod
|
||||
def error() -> "PartialEvalResult":
|
||||
return PartialEvalResult(PartialEvalType.ERROR, sys.exc_info())
|
||||
|
||||
@staticmethod
|
||||
def error_message(message: str) -> "PartialEvalResult":
|
||||
try:
|
||||
raise RuntimeError(message)
|
||||
except RuntimeError:
|
||||
return PartialEvalResult.error()
|
||||
|
||||
|
||||
class LiveValueRef:
|
||||
"""Wraps a live value from the containing environment.
|
||||
|
||||
Typically, when expressions encounter a live value, a limited number of
|
||||
partial evaluations can be done against it in place (versus emitting the code
|
||||
to import it and perform the operation). This default base class will not
|
||||
perform any static evaluations.
|
||||
"""
|
||||
__slots__ = [
|
||||
"live_value",
|
||||
]
|
||||
|
||||
def __init__(self, live_value):
|
||||
super().__init__()
|
||||
self.live_value = live_value
|
||||
|
||||
def resolve_getattr(self, env: "Environment",
|
||||
attr_name: str) -> PartialEvalResult:
|
||||
"""Gets a named attribute from the live value."""
|
||||
return PartialEvalResult.not_evaluated()
|
||||
|
||||
def resolve_call(self, env: "Environment", args,
|
||||
keywords: Sequence[str]) -> PartialEvalResult:
|
||||
"""Resolves a function call given 'args' and 'keywords'."""
|
||||
return PartialEvalResult.not_evaluated()
|
||||
|
||||
def __repr__(self):
|
||||
return "MacroValueRef({}, {})".format(self.__class__.__name__,
|
||||
self.live_value)
|
||||
|
||||
|
||||
class PartialEvalHook:
|
||||
"""Hook interface for performing partial evaluation."""
|
||||
__slots__ = []
|
||||
|
||||
def partial_evaluate(self, py_value) -> PartialEvalResult:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
################################################################################
|
||||
# Configuration and environment
|
||||
################################################################################
|
||||
|
||||
|
||||
class Configuration:
|
||||
"""Base class providing global configuration objects."""
|
||||
__slots__ = [
|
||||
"target_factory",
|
||||
"base_name_resolvers",
|
||||
"value_coder",
|
||||
"partial_eval_hook",
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
target_factory: TargetFactory,
|
||||
base_name_resolvers: Sequence[NameResolver] = (),
|
||||
value_coder: Optional[ValueCoder] = None,
|
||||
partial_eval_hook: PartialEvalHook = None):
|
||||
super().__init__()
|
||||
self.target_factory = target_factory
|
||||
self.base_name_resolvers = tuple(base_name_resolvers)
|
||||
self.value_coder = value_coder if value_coder else ValueCoderChain(())
|
||||
self.partial_eval_hook = partial_eval_hook
|
||||
|
||||
def __repr__(self):
|
||||
return ("Configuration(target_factory={}, base_name_resolvers={}, "
|
||||
"value_code={}, partial_eval_hook={})").format(
|
||||
self.target_factory, self.base_name_resolvers, self.value_coder,
|
||||
self.partial_eval_hook)
|
||||
|
||||
|
||||
class Environment:
|
||||
"""Instantiated configuration for emitting code in a specific context.
|
||||
|
||||
This brings together:
|
||||
- The code generation context (ir_h)
|
||||
- An instantiated target
|
||||
- Delegating interfaces for other configuration objects.
|
||||
|
||||
Note that this class does not actually implement most of the delegate
|
||||
interfaces because it hides the fact that some may require more obtuse
|
||||
APIs than should be exposed to end callers (i.e. expecting environment or
|
||||
other config objects).
|
||||
"""
|
||||
__slots__ = [
|
||||
"config",
|
||||
"ir_h",
|
||||
"_name_resolvers",
|
||||
"target",
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
config: Configuration,
|
||||
ir_h: ir.DialectHelper,
|
||||
name_resolvers: Sequence[NameResolver] = ()):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.ir_h = ir_h
|
||||
self.target = config.target_factory(self.ir_h)
|
||||
self._name_resolvers = (tuple(name_resolvers) +
|
||||
self.config.base_name_resolvers)
|
||||
|
||||
def resolve_name(self, name: str) -> Optional[NameReference]:
|
||||
for resolver in self._name_resolvers:
|
||||
ref = resolver.resolve_name(name)
|
||||
if ref is not None:
|
||||
return ref
|
||||
return None
|
||||
|
||||
def partial_evaluate(self, py_value) -> PartialEvalResult:
|
||||
return self.config.partial_eval_hook.partial_evaluate(py_value)
|
||||
|
||||
def code_py_value_as_const(self,
|
||||
py_value) -> Union[_NotImplementedType, ir.Value]:
|
||||
return self.config.value_coder.code_py_value_as_const(self, py_value)
|
|
@ -0,0 +1,114 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Name resolvers for common scenarios."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
|
||||
from .interfaces import *
|
||||
|
||||
__all__ = [
|
||||
"ConstModuleNameResolver",
|
||||
"LocalNameResolver",
|
||||
]
|
||||
|
||||
################################################################################
|
||||
# Local name resolution
|
||||
# This is used for local names that can be managed purely as SSA values.
|
||||
################################################################################
|
||||
|
||||
|
||||
class LocalNameReference(NameReference):
|
||||
"""Holds an association between a name and SSA value."""
|
||||
__slots__ = [
|
||||
"_current_value",
|
||||
]
|
||||
|
||||
def __init__(self, name, initial_value=None):
|
||||
super().__init__(name)
|
||||
self._current_value = initial_value
|
||||
|
||||
def load(self, env: Environment) -> PartialEvalResult:
|
||||
if self._current_value is None:
|
||||
return PartialEvalResult.error_message(
|
||||
"Attempt to access local '{}' before assignment".format(self.name))
|
||||
return PartialEvalResult.yields_ir_value(self._current_value)
|
||||
|
||||
def store(self, env: Environment, value: ir.Value):
|
||||
self._current_value = value
|
||||
|
||||
def __repr__(self):
|
||||
return "<LocalNameReference({})>".format(self.name)
|
||||
|
||||
|
||||
class LocalNameResolver(NameResolver):
|
||||
"""Resolves names in a local cache of SSA values.
|
||||
|
||||
This is used to manage locals and arguments (that are not referenced through
|
||||
a closure).
|
||||
"""
|
||||
__slots__ = [
|
||||
"_name_refs",
|
||||
]
|
||||
|
||||
def __init__(self, names):
|
||||
super().__init__()
|
||||
self._name_refs = {name: LocalNameReference(name) for name in names}
|
||||
|
||||
def resolve_name(self, name) -> Optional[NameReference]:
|
||||
return self._name_refs.get(name)
|
||||
|
||||
|
||||
################################################################################
|
||||
# Constant name resolution
|
||||
# For some DSLs, it can be appropriate to treat some containing scopes as
|
||||
# constants. This strategy typically binds to a module and routes loads
|
||||
# through the partial evaluation hook.
|
||||
################################################################################
|
||||
|
||||
|
||||
class ConstNameReference(NameReference):
|
||||
"""Represents a name/value mapping that will emit as a constant."""
|
||||
__slots__ = [
|
||||
"_py_value",
|
||||
]
|
||||
|
||||
def __init__(self, name, py_value):
|
||||
super().__init__(name)
|
||||
self._py_value = py_value
|
||||
|
||||
def load(self, env: Environment) -> PartialEvalResult:
|
||||
return env.partial_evaluate(self._py_value)
|
||||
|
||||
def __repr__(self):
|
||||
return "<ConstNameReference({}={})>".format(self.name, self._py_value)
|
||||
|
||||
|
||||
class ConstModuleNameResolver(NameResolver):
|
||||
"""Resolves names from a module by treating them as immutable and loading
|
||||
them as constants into a function scope.
|
||||
"""
|
||||
__slots__ = [
|
||||
"_as_dict",
|
||||
"module",
|
||||
]
|
||||
|
||||
def __init__(self, module, *, as_dict=False):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self._as_dict = as_dict
|
||||
|
||||
def resolve_name(self, name) -> Optional[NameReference]:
|
||||
if self._as_dict:
|
||||
if name in self.module:
|
||||
py_value = self.module[name]
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
py_value = getattr(self.module, name)
|
||||
except AttributeError:
|
||||
return None
|
||||
return ConstNameReference(name, py_value)
|
|
@ -0,0 +1,134 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Partial evaluation helpers and support for built-in and common scenarios."""
|
||||
|
||||
from .interfaces import *
|
||||
from .py_value_utils import *
|
||||
from . import logging
|
||||
|
||||
__all__ = [
|
||||
"MappedPartialEvalHook",
|
||||
"ResolveAttrLiveValueRef",
|
||||
"TemplateCallLiveValueRef",
|
||||
]
|
||||
|
||||
_Unspec = object()
|
||||
|
||||
################################################################################
|
||||
# LiveValueRef specializations for various kinds of access
|
||||
################################################################################
|
||||
|
||||
|
||||
class ResolveAttrLiveValueRef(LiveValueRef):
|
||||
"""Custom LiveValueRef that will resolve attributes via getattr."""
|
||||
__slots__ = []
|
||||
|
||||
def resolve_getattr(self, env: "Environment", attr_name) -> PartialEvalResult:
|
||||
logging.debug("RESOLVE_GETATTR '{}' on {}".format(attr_name,
|
||||
self.live_value))
|
||||
try:
|
||||
attr_py_value = getattr(self.live_value, attr_name)
|
||||
except:
|
||||
return PartialEvalResult.error()
|
||||
return env.partial_evaluate(attr_py_value)
|
||||
|
||||
|
||||
class TemplateCallLiveValueRef(LiveValueRef):
|
||||
"""Custom LiveValueRef that resolves calls to a func_template_call op."""
|
||||
__slots__ = ["callee_name"]
|
||||
|
||||
def __init__(self, callee_name, live_value):
|
||||
super().__init__(live_value)
|
||||
self.callee_name = callee_name
|
||||
|
||||
def resolve_call(self, env: "Environment", args,
|
||||
keywords) -> PartialEvalResult:
|
||||
linear_args = list(args)
|
||||
kw_arg_names = []
|
||||
for kw_name, kw_value in keywords:
|
||||
kw_arg_names.append(kw_name)
|
||||
linear_args.append(kw_value)
|
||||
|
||||
ir_h = env.ir_h
|
||||
result_ir_value = ir_h.basicpy_func_template_call_op(
|
||||
result_type=ir_h.basicpy_UnknownType,
|
||||
callee_symbol=self.callee_name,
|
||||
args=linear_args,
|
||||
arg_names=kw_arg_names).result
|
||||
return PartialEvalResult.yields_ir_value(result_ir_value)
|
||||
|
||||
|
||||
################################################################################
|
||||
# PartialEvalHook implementations
|
||||
################################################################################
|
||||
|
||||
|
||||
class MappedPartialEvalHook(PartialEvalHook):
|
||||
"""A PartialEvalHook that maps rules to produce live values.
|
||||
|
||||
Internally, this implementation binds a predicate to an action. The predicate
|
||||
can be:
|
||||
- A python value matched by reference or value equality
|
||||
- A type that a value must be an instance of
|
||||
- An arbitrary lambda (should be limited to special cases as it forces
|
||||
a linear scan).
|
||||
|
||||
An action can be one of
|
||||
- A `lambda python_value: PartialEvalResult...`
|
||||
- A PartialEvalResult to directly return
|
||||
- None to indicate that the python value should be processed directly
|
||||
"""
|
||||
__slots__ = [
|
||||
"_value_map",
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._value_map = PyValueMap()
|
||||
|
||||
def partial_evaluate(self, py_value) -> PartialEvalResult:
|
||||
"""Performs partial evaluation on a python value."""
|
||||
binding = self._value_map.lookup(py_value)
|
||||
if binding is None:
|
||||
logging.debug("PARTIAL EVAL RESOLVE {}: Passthrough", py_value)
|
||||
return PartialEvalResult.yields_live_value(LiveValueRef(py_value))
|
||||
if isinstance(binding, PartialEvalResult):
|
||||
return binding
|
||||
# Attempt to call.
|
||||
try:
|
||||
binding = binding(py_value)
|
||||
assert isinstance(binding, PartialEvalResult), (
|
||||
"Expected PartialEvalResult but got {}".format(binding))
|
||||
logging.debug("PARTIAL EVAL RESOLVE {}: {}", py_value, binding)
|
||||
return binding
|
||||
except:
|
||||
return PartialEvalResult.error()
|
||||
|
||||
def _bind_action(self,
|
||||
action,
|
||||
*,
|
||||
for_ref=_Unspec,
|
||||
for_type=_Unspec,
|
||||
for_predicate=_Unspec):
|
||||
if for_ref is not _Unspec:
|
||||
self._value_map.bind_reference(for_ref, action)
|
||||
elif for_type is not _Unspec:
|
||||
self._value_map.bind_type(for_type, action)
|
||||
elif for_predicate is not _Unspec:
|
||||
self._value_map.bind_predicate(for_predicate, action)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Must specify one of 'for_ref', 'for_type' or 'for_predicate")
|
||||
|
||||
def enable_getattr(self, **kwargs):
|
||||
"""Enables partial evaluation of getattr."""
|
||||
self._bind_action(
|
||||
lambda pv: PartialEvalResult.yields_live_value(
|
||||
ResolveAttrLiveValueRef(pv)), **kwargs)
|
||||
|
||||
def enable_template_call(self, callee_name, **kwargs):
|
||||
""""Enables a global template call."""
|
||||
self._bind_action(
|
||||
lambda pv: PartialEvalResult.yields_live_value(
|
||||
TemplateCallLiveValueRef(callee_name, pv)), **kwargs)
|
|
@ -0,0 +1,55 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Various configuration helpers for testing."""
|
||||
|
||||
import ast
|
||||
|
||||
from . import logging
|
||||
from .frontend import *
|
||||
from .interfaces import *
|
||||
from .partial_eval_base import *
|
||||
from .target import *
|
||||
from .value_coder_base import *
|
||||
|
||||
|
||||
def create_import_dump_decorator(*,
|
||||
target_factory: TargetFactory = GenericTarget64
|
||||
):
|
||||
config = create_test_config(target_factory=target_factory)
|
||||
logging.debug("Testing with config: {}", config)
|
||||
|
||||
def decorator(f):
|
||||
fe = ImportFrontend(config=config)
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def create_test_config(target_factory: TargetFactory = GenericTarget64):
|
||||
value_coder = BuiltinsValueCoder()
|
||||
pe_hook = build_default_partial_eval_hook()
|
||||
|
||||
return Configuration(target_factory=target_factory,
|
||||
value_coder=value_coder,
|
||||
partial_eval_hook=pe_hook)
|
||||
|
||||
|
||||
def build_default_partial_eval_hook() -> PartialEvalHook:
|
||||
pe = MappedPartialEvalHook()
|
||||
### Modules
|
||||
pe.enable_getattr(for_type=ast.__class__) # The module we use is arbitrary.
|
||||
|
||||
### Tuples
|
||||
# Enable attribute resolution on tuple, which includes namedtuple (which is
|
||||
# really what we want).
|
||||
pe.enable_getattr(for_type=tuple)
|
||||
|
||||
### Temp: resolve a function to a template call for testing
|
||||
import math
|
||||
pe.enable_template_call("__global$math.ceil", for_ref=math.ceil)
|
||||
pe.enable_template_call("__global$math.isclose", for_ref=math.isclose)
|
||||
return pe
|
|
@ -0,0 +1,47 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""Value coders for built-in and common scenarios."""
|
||||
|
||||
from typing import Union
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
|
||||
from .interfaces import *
|
||||
|
||||
__all__ = [
|
||||
"BuiltinsValueCoder",
|
||||
]
|
||||
|
||||
_NotImplementedType = type(NotImplemented)
|
||||
|
||||
|
||||
class BuiltinsValueCoder(ValueCoder):
|
||||
"""Value coder for builtin python types."""
|
||||
__slots__ = []
|
||||
|
||||
def code_py_value_as_const(self, env: Environment,
|
||||
py_value) -> Union[_NotImplementedType, ir.Value]:
|
||||
ir_h = env.ir_h
|
||||
ir_c = ir_h.context
|
||||
if py_value is True:
|
||||
return ir_h.basicpy_bool_constant_op(True).result
|
||||
elif py_value is False:
|
||||
return ir_h.basicpy_bool_constant_op(False).result
|
||||
elif py_value is None:
|
||||
return ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
||||
elif isinstance(py_value, int):
|
||||
ir_type = env.target.impl_int_type
|
||||
ir_attr = ir_c.integer_attr(ir_type, py_value)
|
||||
return ir_h.constant_op(ir_type, ir_attr).result
|
||||
elif isinstance(py_value, float):
|
||||
ir_type = env.target.impl_float_type
|
||||
ir_attr = ir_c.float_attr(ir_type, py_value)
|
||||
return ir_h.constant_op(ir_type, ir_attr).result
|
||||
elif isinstance(py_value, str):
|
||||
return ir_h.basicpy_str_constant_op(py_value).result
|
||||
elif isinstance(py_value, bytes):
|
||||
return ir_h.basicpy_bytes_constant_op(py_value).result
|
||||
elif isinstance(py_value, type(...)):
|
||||
return ir_h.basicpy_singleton_op(ir_h.basicpy_EllipsisType).result
|
||||
return NotImplemented
|
Loading…
Reference in New Issue