mirror of https://github.com/llvm/torch-mlir
114 lines
3.3 KiB
Python
114 lines
3.3 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
"""Name resolvers for common scenarios."""
|
|
|
|
from typing import Optional
|
|
|
|
from mlir import ir as _ir
|
|
from .interfaces import *
|
|
|
|
# Public surface of this module: only the two resolvers are exported; the
# LocalNameReference / ConstNameReference classes are not listed here.
__all__ = ["ConstModuleNameResolver", "LocalNameResolver"]
|
|
|
|
################################################################################
|
|
# Local name resolution
|
|
# This is used for local names that can be managed purely as SSA values.
|
|
################################################################################
|
|
|
|
|
|
class LocalNameReference(NameReference):
  """Associates a local name with the SSA value currently bound to it.

  The binding starts out as `initial_value` (None meaning "not yet assigned")
  and is overwritten by every call to `store`.
  """

  __slots__ = ["_current_value"]

  def __init__(self, name, initial_value=None):
    super().__init__(name)
    # None doubles as the "unassigned" marker that `load` checks for.
    self._current_value = initial_value

  def load(self, env: Environment) -> PartialEvalResult:
    """Yields the bound SSA value, or an error result if never assigned.

    `env` is not consulted.
    """
    current = self._current_value
    if current is not None:
      return PartialEvalResult.yields_ir_value(current)
    return PartialEvalResult.error_message(
        f"Attempt to access local '{self.name}' before assignment")

  def store(self, env: Environment, value: _ir.Value):
    """Rebinds this name to `value`; `env` is not consulted."""
    self._current_value = value

  def __repr__(self):
    return f"<LocalNameReference({self.name})>"
|
|
|
|
|
|
class LocalNameResolver(NameResolver):
  """Resolves names against a fixed set of locally cached SSA bindings.

  Intended for function locals and arguments that are not referenced through
  a closure. Each name given at construction gets its own LocalNameReference,
  which starts out unassigned; any other name is left unresolved.
  """

  __slots__ = ["_name_refs"]

  def __init__(self, names):
    super().__init__()
    refs = {}
    for local_name in names:
      # A repeated name simply replaces the earlier entry.
      refs[local_name] = LocalNameReference(local_name)
    self._name_refs = refs

  def resolve_name(self, name) -> Optional[NameReference]:
    """Returns the reference for `name`, or None if it is not a known local."""
    refs = self._name_refs
    return refs[name] if name in refs else None
|
|
|
|
|
|
################################################################################
|
|
# Constant name resolution
|
|
# For some DSLs, it can be appropriate to treat some containing scopes as
|
|
# constants. This strategy typically binds to a module and routes loads
|
|
# through the partial evaluation hook.
|
|
################################################################################
|
|
|
|
|
|
class ConstNameReference(NameReference):
  """A name bound to a fixed Python value that is emitted as a constant.

  Loading hands the stored Python value to the environment's partial
  evaluation hook; the reference itself never changes after construction.
  """

  __slots__ = ["_py_value"]

  def __init__(self, name, py_value):
    super().__init__(name)
    self._py_value = py_value

  def load(self, env: Environment) -> PartialEvalResult:
    """Returns whatever `env.partial_evaluate` yields for the stored value."""
    captured = self._py_value
    return env.partial_evaluate(captured)

  def __repr__(self):
    return f"<ConstNameReference({self.name}={self._py_value})>"
|
|
|
|
|
|
class ConstModuleNameResolver(NameResolver):
  """Treats the contents of a module as immutable constants.

  Names found on `module` (as attributes by default, or as mapping keys when
  `as_dict=True`) resolve to ConstNameReference objects, so loads are routed
  through partial evaluation. Names that are absent resolve to None.
  """

  __slots__ = ["_as_dict", "module"]

  def __init__(self, module, *, as_dict=False):
    super().__init__()
    self.module = module
    self._as_dict = as_dict

  def resolve_name(self, name) -> Optional[NameReference]:
    """Looks `name` up on the module; returns None when it is not present."""
    source = self.module
    if not self._as_dict:
      # Attribute mode: a missing attribute means "not resolvable here".
      try:
        found = getattr(source, name)
      except AttributeError:
        return None
      return ConstNameReference(name, found)
    # Mapping mode: check membership before indexing so that mappings with
    # default factories are not mutated by a failed lookup.
    if name not in source:
      return None
    return ConstNameReference(name, source[name])
|