torch-mlir/python/npcomp/compiler/numpy/partial_eval_base.py

145 lines
5.0 KiB
Python
Raw Normal View History

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Partial evaluation helpers and support for built-in and common scenarios."""
from typing import Any, Callable, Union
from .interfaces import *
from .py_value_utils import *
from ..utils import logging
# Public names exported by this module (also the set of names bound by
# `from <this module> import *`).
__all__ = [
    "MappedPartialEvalHook",
    "ResolveAttrLiveValueRef",
    "TemplateCallLiveValueRef",
]
# Private sentinel used as the default for MappedPartialEvalHook.bind_action's
# `for_ref` / `for_type` / `for_predicate` keywords. It is compared by
# identity, so an explicitly passed None still counts as "specified".
_Unspec = object()
################################################################################
# LiveValueRef specializations for various kinds of access
################################################################################
class ResolveAttrLiveValueRef(LiveValueRef):
  """Custom LiveValueRef that will resolve attributes via getattr.

  Attribute access on the wrapped live value is performed with the Python
  `getattr` builtin. The fetched value is then handed back to the environment
  for further partial evaluation.
  """
  __slots__ = []

  def resolve_getattr(self, env: "Environment", attr_name) -> PartialEvalResult:
    """Resolves `attr_name` on the wrapped live value.

    Args:
      env: Environment whose `partial_evaluate` is applied to the fetched
        attribute value.
      attr_name: Name of the attribute to fetch from `self.live_value`.
    Returns:
      The result of `env.partial_evaluate` on the attribute value, or
      `PartialEvalResult.error()` if the attribute lookup raises.
    """
    # Lazy format args, consistent with the other logging calls in this file.
    logging.debug("RESOLVE_GETATTR '{}' on {}", attr_name, self.live_value)
    try:
      attr_py_value = getattr(self.live_value, attr_name)
    except Exception:
      # Any lookup failure (a missing attribute, or an exception raised by a
      # property) becomes an error result rather than propagating. This is
      # narrowed from a bare `except:` so KeyboardInterrupt/SystemExit escape.
      return PartialEvalResult.error()
    return env.partial_evaluate(attr_py_value)
class TemplateCallLiveValueRef(LiveValueRef):
  """LiveValueRef whose calls are lowered to a func_template_call op.

  Positional arguments are forwarded in order, followed by the values of any
  keyword arguments. The keyword names are passed separately so the op can
  associate them with the trailing operands.
  """
  __slots__ = ["callee_name"]

  def __init__(self, callee_name, live_value):
    super().__init__(live_value)
    # Emitted as the op's `callee_symbol`.
    self.callee_name = callee_name

  def resolve_call(self, env: "Environment", args,
                   keywords) -> PartialEvalResult:
    """Emits a basicpy func_template_call op targeting `self.callee_name`.

    Args:
      env: Environment providing the IR helper as `env.ir_h`.
      args: Positional argument values, forwarded in order.
      keywords: Iterable of `(name, value)` pairs.
    Returns:
      A PartialEvalResult yielding the op's result IR value, whose type is
      `ir_h.basicpy_UnknownType`.
    """
    operands = list(args)
    # Unpack each keyword entry exactly once, then split names from values.
    keyword_pairs = [(name, value) for name, value in keywords]
    keyword_names = [name for name, _ in keyword_pairs]
    operands.extend(value for _, value in keyword_pairs)

    helper = env.ir_h
    call_op = helper.basicpy_func_template_call_op(
        result_type=helper.basicpy_UnknownType,
        callee_symbol=self.callee_name,
        args=operands,
        arg_names=keyword_names)
    produced = call_op.result
    return PartialEvalResult.yields_ir_value(produced)
################################################################################
# PartialEvalHook implementations
################################################################################
class MappedPartialEvalHook(PartialEvalHook):
  """A PartialEvalHook that maps rules to produce live values.

  Internally, this implementation binds a predicate to an action. The predicate
  can be:
    - A python value matched by reference or value equality
    - A type that a value must be an instance of
    - An arbitrary lambda (should be limited to special cases as it forces
      a linear scan).

  An action can be one of
    - A `lambda python_value: PartialEvalResult...`
    - An object that supports as_partial_eval_result() (either a
      PartialEvalResult or LiveValueRef qualify).
    - None to indicate that the python value should be processed directly
  """
  __slots__ = [
      "_value_map",
  ]

  def __init__(self):
    super().__init__()
    # Holds the rule -> action bindings registered through bind_action.
    self._value_map = PyValueMap()

  def __repr__(self):
    return "MappedPartialEvalHook({})".format(self._value_map)

  def partial_evaluate(self, py_value) -> PartialEvalResult:
    """Performs partial evaluation on a python value.

    Looks up the action bound for `py_value`. If the lookup yields None, the
    value is passed through wrapped in a plain LiveValueRef. Otherwise the
    action is called with `py_value`, and its return value is converted with
    `as_partial_eval_result()`.

    Returns:
      The resulting PartialEvalResult, or `PartialEvalResult.error()` if the
      action or the conversion fails.
    """
    logging.debug("LOOKUP: {}", py_value)
    action = self._value_map.lookup(py_value)
    if action is None:
      # Passthrough.
      return PartialEvalResult.yields_live_value(LiveValueRef(py_value))
    # Attempt to call. Failures are deliberately reported as an error result
    # (best effort) instead of propagating. `Exception` rather than a bare
    # except lets KeyboardInterrupt/SystemExit through.
    try:
      result = action(py_value).as_partial_eval_result()
      assert isinstance(result, PartialEvalResult), (
          "Expected PartialEvalResult but got {}".format(result))
    except Exception as e:
      logging.debug("PARTIAL EVAL ERROR {}: {}", py_value, e)
      return PartialEvalResult.error()
    logging.debug("PARTIAL EVAL RESOLVE {}: {}", py_value, result)
    return result

  def bind_action(self,
                  action: Union[PartialEvalResult, LiveValueRef,
                                Callable[[Any], PartialEvalResult]],
                  *,
                  for_ref=_Unspec,
                  for_type=_Unspec,
                  for_predicate=_Unspec):
    """Binds `action` to a single match rule.

    The rule keywords are checked in the order for_ref, for_type,
    for_predicate, and only the first one specified is used.

    Args:
      action: A callable taking the python value, an object supporting
        `as_partial_eval_result()` (bound as a constant result), or None.
      for_ref: Bind via `PyValueMap.bind_reference` for this python value.
      for_type: Bind via `PyValueMap.bind_type` for this type.
      for_predicate: Bind via `PyValueMap.bind_predicate` for this predicate.
    Raises:
      ValueError: If none of the rule keywords is specified.
    """
    if hasattr(action, "as_partial_eval_result"):
      # Registers a casting action: the bound object itself supplies the
      # result, whatever python value is being evaluated. It is captured under
      # a separate name because closing over `action` (which is rebound just
      # below) would make the lambda call itself. Calling the method on `pv`,
      # as before, would also convert the wrong object.
      castable = action
      action = lambda pv: castable.as_partial_eval_result()
    if for_ref is not _Unspec:
      self._value_map.bind_reference(for_ref, action)
    elif for_type is not _Unspec:
      self._value_map.bind_type(for_type, action)
    elif for_predicate is not _Unspec:
      self._value_map.bind_predicate(for_predicate, action)
    else:
      raise ValueError(
          "Must specify one of 'for_ref', 'for_type' or 'for_predicate'")

  def enable_getattr(self, **kwargs):
    """Enables partial evaluation of getattr.

    Matching values are wrapped in a ResolveAttrLiveValueRef. `kwargs` are the
    rule keywords forwarded to bind_action.
    """
    self.bind_action(
        lambda pv: PartialEvalResult.yields_live_value(
            ResolveAttrLiveValueRef(pv)), **kwargs)

  def enable_template_call(self, callee_name, **kwargs):
    """Enables a global template call.

    Matching values are wrapped in a TemplateCallLiveValueRef targeting
    `callee_name`. `kwargs` are the rule keywords forwarded to bind_action.
    """
    self.bind_action(
        lambda pv: PartialEvalResult.yields_live_value(
            TemplateCallLiveValueRef(callee_name, pv)), **kwargs)