torch-mlir/python/npcomp/compiler/environment.py

288 lines
8.0 KiB
Python

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import inspect
from typing import Optional, Union
from _npcomp.mlir import ir
from . import logging
from .target import *
# Names exported by `from <this module> import *`. The Local*/Const* resolver
# and reference helpers defined below are not part of this list.
__all__ = [
    "BuiltinsValueCoder", "Environment", "NameReference", "NameResolver",
    "ValueCoder", "ValueCoderChain"
]
class ValueCoder:
  """Base protocol for encoding Python values into IR.

  Coders are meant to be daisy-chained: each coder handles only the values it
  understands and returns NotImplemented for everything else, so the next
  coder in the chain gets a chance.
  """
  __slots__ = []

  def create_const(self, env: "Environment", py_value):
    """Emits a constant for `py_value`.

    The base implementation understands nothing and always declines.

    Returns:
      NotImplemented.
    """
    return NotImplemented
class ValueCoderChain(ValueCoder):
  """Codes values by delegating to sub-coders in order.

  The result of the first sub-coder whose create_const() returns something
  other than NotImplemented wins; if every sub-coder declines, the chain
  declines too.
  """
  __slots__ = ["_sub_coders"]

  def __init__(self, sub_coders):
    """Creates the chain.

    Args:
      sub_coders: Iterable of coders, consulted in iteration order.
    """
    # Materialize once: storing a one-shot iterable (e.g. a generator) would
    # let the first create_const() call exhaust it, after which the chain
    # would silently decline every value.
    self._sub_coders = tuple(sub_coders)

  def create_const(self, env: "Environment", py_value):
    """Returns the first sub-coder result that is not NotImplemented."""
    for sc in self._sub_coders:
      result = sc.create_const(env, py_value)
      if result is not NotImplemented:
        return result
    return NotImplemented
class NameReference:
  """Abstract base class for performing operations on a name."""
  __slots__ = [
      "name",
  ]

  def __init__(self, name):
    super().__init__()
    self.name = name

  def load(self,
           env: "Environment",
           ir_h: Optional[ir.DialectHelper] = None) -> Optional[ir.Value]:
    """Loads the IR Value associated with the name.

    The load may either be direct, returning an existing value or
    side-effecting, causing a read from an external context.

    Args:
      env: The Environment the load is performed in.
      ir_h: Optional dialect helper used to emit code. Defaults to None so
        that the base signature accepts the `load(env)` form.

    Returns:
      An SSA value containing the resolved value (or None if not bound).

    Raises:
      NotImplementedError if load is not supported for this name.
    """
    raise NotImplementedError()

  def store(self,
            env: "Environment",
            value: ir.Value,
            ir_h: Optional[ir.DialectHelper] = None):
    """Stores a new value into the name.

    A subsequent call to 'load' should yield the same value, subject to
    typing constraints on value equality.

    Args:
      env: The Environment the store is performed in.
      value: The new value to store into the name.
      ir_h: Optional dialect helper used to emit code. Defaults to None so
        that the base signature accepts the `store(env, value)` form.

    Raises:
      NotImplementedError if store is not supported for this name.
    """
    raise NotImplementedError()
class NameResolver:
  """Abstract base class that can resolve a name.

  Name resolvers are typically stacked.
  """

  def checked_lookup(self, name):
    """Looks up `name` and requires that it resolves.

    Returns:
      The reference returned by lookup().

    Raises:
      AssertionError: if lookup() returns None. This is raised explicitly
        rather than via an `assert` statement so the check still fires when
        Python runs with -O.
    """
    ref = self.lookup(name)
    if ref is None:
      raise AssertionError("Lookup of name {} is required".format(name))
    return ref

  def lookup(self, name) -> Optional["NameReference"]:
    """Returns a reference for `name`, or None if it cannot be resolved.

    The base implementation resolves nothing.
    """
    return None
class Environment(NameResolver):
  """Manages access to the environment of a code region.

  This encapsulates name lookup, access to the containing module, etc.
  Name lookup consults the configured name resolvers in order and returns the
  first reference found.
  """
  __slots__ = [
      "ir_h",
      "_name_resolvers",
      "target",
      "value_coder",
  ]

  def __init__(self,
               ir_h: ir.DialectHelper,
               *,
               target: Target,
               name_resolvers=(),
               value_coder):
    """Creates an environment.

    Args:
      ir_h: The dialect helper used to emit code.
      target: The compilation target.
      name_resolvers: Resolvers consulted, in order, by lookup().
      value_coder: Coder exposed as `self.value_coder` for turning Python
        values into IR constants.
    """
    super().__init__()
    self.ir_h = ir_h
    self.target = target
    # Materialize once so a one-shot iterable is not exhausted by the first
    # lookup().
    self._name_resolvers = tuple(name_resolvers)
    self.value_coder = value_coder

  @classmethod
  def for_const_global_function(cls, ir_h: ir.DialectHelper, f, *,
                                parameter_bindings, **kwargs):
    """Helper to generate an environment for a global function.

    This is a helper for the very common case and will be wholly insufficient
    for advanced cases, including mutable global state, closures, etc.
    Globals from the module are considered immutable.

    Args:
      ir_h: The dialect helper used to emit code.
      f: A user-defined Python function (must expose __code__ and
        __globals__).
      parameter_bindings: Iterable of (name, value) pairs; each value is
        stored into the local of `f` with that name.
      **kwargs: Forwarded to the Environment constructor (e.g. target,
        value_coder).

    Returns:
      The new environment, with parameters already bound.

    Raises:
      AssertionError: if `f` lacks the user-defined function attributes.
        Parameter names are resolved through checked_lookup().
    """
    import builtins

    try:
      code = f.__code__
      globals_dict = f.__globals__
    except AttributeError as e:
      # Raised explicitly (not `assert False`) so that under -O execution
      # cannot fall through with `code` unbound.
      raise AssertionError(
          "Function {} does not have required user-defined function "
          "attributes".format(f)) from e

    # In CPython, `__builtins__` is the builtins module inside __main__ but
    # that module's __dict__ in every other module. Pick the matching lookup
    # mode: treating a dict as a module would resolve names via getattr on the
    # dict (missing `len`, `range`, ... and exposing dict methods instead).
    # Fall back to the real builtins module if the key is absent.
    builtins_module = globals_dict.get("__builtins__", builtins)
    builtins_is_dict = isinstance(builtins_module, dict)

    # Locals resolver.
    # Note that co_varnames should include both parameter and local names.
    locals_resolver = LocalNameResolver(code.co_varnames)
    resolvers = (
        locals_resolver,
        ConstModuleNameResolver(globals_dict, as_dict=True),
        ConstModuleNameResolver(builtins_module, as_dict=builtins_is_dict),
    )
    env = cls(ir_h, name_resolvers=resolvers, **kwargs)

    # Bind parameters.
    for name, value in parameter_bindings:
      logging.debug("STORE PARAM: {} <- {}", name, value)
      locals_resolver.checked_lookup(name).store(env, value)
    return env

  def lookup(self, name) -> Optional[NameReference]:
    """Returns the first non-None reference from the resolvers, in order."""
    for resolver in self._name_resolvers:
      ref = resolver.lookup(name)
      if ref is not None:
        return ref
    return None
class LocalNameReference(NameReference):
  """A name whose value is an SSA value held directly on this reference.

  load() hands back whatever was most recently passed to store(), starting
  from `initial_value`.
  """
  __slots__ = [
      "_current_value",
  ]

  def __init__(self, name, initial_value=None):
    super().__init__(name)
    self._current_value = initial_value

  def load(self, env: "Environment") -> Optional[ir.Value]:
    """Returns the currently held value (None if nothing was stored)."""
    held = self._current_value
    return held

  def store(self, env: "Environment", value: ir.Value):
    """Replaces the held value with `value`."""
    self._current_value = value

  def __repr__(self):
    label = self.name
    return "<LocalNameReference({})>".format(label)
class LocalNameResolver(NameResolver):
  """Resolves names in a local cache of SSA values.

  Each name given at construction gets its own LocalNameReference. This is
  used to manage locals and arguments (that are not referenced through a
  closure).
  """
  __slots__ = [
      "_name_refs",
  ]

  def __init__(self, names):
    super().__init__()
    refs_by_name = {}
    for local_name in names:
      refs_by_name[local_name] = LocalNameReference(local_name)
    self._name_refs = refs_by_name

  def lookup(self, name) -> Optional[NameReference]:
    """Returns the reference for `name`, or None if it is not a known local."""
    return self._name_refs.get(name, None)
class ConstNameReference(NameReference):
  """A name bound to a Python value that is materialized as an IR constant.

  Each load() asks the environment's value coder to emit the constant.
  """
  __slots__ = [
      "_py_value",
  ]

  def __init__(self, name, py_value):
    super().__init__(name)
    self._py_value = py_value

  def load(self, env: "Environment") -> Optional[ir.Value]:
    """Emits the constant via env.value_coder.

    Returns:
      The coded value, or None when the coder returns NotImplemented.
    """
    coded = env.value_coder.create_const(env, self._py_value)
    if coded is not NotImplemented:
      return coded
    logging.debug("Unsupported {}", self)
    return None

  def __repr__(self):
    return "<ConstNameReference({}={})>".format(self.name, self._py_value)
class ConstModuleNameResolver(NameResolver):
  """Resolves names from a module by treating them as immutable and loading
  them as constants into a function scope.

  With `as_dict=True`, `module` is treated as a mapping (membership test plus
  item access); otherwise names are read with getattr().
  """
  __slots__ = [
      "_as_dict",
      "module",
  ]

  def __init__(self, module, *, as_dict=False):
    super().__init__()
    self.module = module
    self._as_dict = as_dict

  def lookup(self, name) -> Optional[NameReference]:
    """Returns a ConstNameReference for `name`, or None if it is absent."""
    if not self._as_dict:
      try:
        py_value = getattr(self.module, name)
      except AttributeError:
        return None
    elif name in self.module:
      py_value = self.module[name]
    else:
      return None
    return ConstNameReference(name, py_value)
class BuiltinsValueCoder:
  """Value coder for builtin python types.

  Handles True/False, None, int, float, str, bytes and Ellipsis; returns
  NotImplemented for any other value.
  """
  __slots__ = []

  def create_const(self, env: "Environment", py_value):
    """Emits a constant op for `py_value` and returns its result.

    Returns:
      The op's `.result`, or NotImplemented for unsupported values.
    """
    helper = env.ir_h
    context = helper.context

    # Identity checks come first: bool is a subclass of int, so True/False
    # must not fall through to the integer branch.
    if py_value is True or py_value is False:
      return helper.basicpy_bool_constant_op(py_value).result
    if py_value is None:
      return helper.basicpy_singleton_op(helper.basicpy_NoneType).result
    if isinstance(py_value, int):
      int_type = env.target.impl_int_type
      int_attr = context.integer_attr(int_type, py_value)
      return helper.constant_op(int_type, int_attr).result
    if isinstance(py_value, float):
      float_type = env.target.impl_float_type
      float_attr = context.float_attr(float_type, py_value)
      return helper.constant_op(float_type, float_attr).result
    if isinstance(py_value, str):
      return helper.basicpy_str_constant_op(py_value).result
    if isinstance(py_value, bytes):
      return helper.basicpy_bytes_constant_op(py_value).result
    if py_value is Ellipsis:
      return helper.basicpy_singleton_op(helper.basicpy_EllipsisType).result
    return NotImplemented