torch-mlir/python/npcomp/compiler/extensions/numpy/builtin_ops.py

94 lines
3.2 KiB
Python

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Configures evaluation support for numpy builtin ops."""
from typing import Callable, Iterator, Sequence, Tuple
import functools
import numpy as np
from _npcomp.mlir import ir
from ... import logging
from ...interfaces import *
from ...partial_eval_base import *
__all__ = [
"get_ufuncs_from_module",
"bind_ufuncs",
]
################################################################################
# Ufunc evaluation
################################################################################
def _default_ufunc_predicate(ufunc: np.ufunc) -> bool:
"""Filters ufuncs based on ability to evaluate them."""
# Support up to 2 input, 1 output functions.
if ufunc.nin > 2 or ufunc.nout != 1:
return False
return True
def get_ufuncs_from_module(
*,
module=np,
prefix: str = "numpy.",
predicate: Callable[[np.ufunc], bool] = _default_ufunc_predicate,
) -> Iterator[Tuple[str, np.ufunc]]:
"""Iterates over all ufuncs in a module.
Yields:
Tuple of (prefixed_name, ufunc).
"""
ufunc_class = np.ufunc
for local_name in dir(module):
value = getattr(module, local_name)
if isinstance(value, ufunc_class):
if not predicate(value):
logging.debug("Skipped ufunc: {}{} ({})", prefix, local_name, value)
else:
yield (prefix + local_name), value
def bind_ufuncs(ufuncs: Iterator[Tuple[str, np.ufunc]],
pe_hook: MappedPartialEvalHook):
"""Binds a set of ufuncs to a partial eval hook."""
for qualified_name, ufunc in ufuncs:
pe_hook.bind_action(functools.partial(BuiltinUfuncLiveValueRef,
qualified_name, ufunc),
for_ref=ufunc)
class BuiltinUfuncLiveValueRef(LiveValueRef):
"""A partial evaluation that emits IR for invoking a ufunc."""
__slots__ = ["_qualified_name", "_ufunc"]
def __init__(self, qualified_name: str, ufunc: np.ufunc, live_value):
super().__init__(live_value)
self._qualified_name = qualified_name
self._ufunc = ufunc
def resolve_call(self, env: Environment, args: Sequence[ir.Value],
keywords: Sequence[str]) -> PartialEvalResult:
if keywords:
return PartialEvalResult.error_message(
"ufunc call does not currently support keyword args")
if len(args) != self._ufunc.nin:
return PartialEvalResult.error_message(
"ufunc {} expected {} inputs but got {}".format(
self._qualified_name, self._ufunc.nin, len(args)))
ir_h = env.ir_h
# Because a ufunc call is defined in terms of tensors and, at this stage,
# all "public" values are ndarray, do appropriate conversions.
tensor_args = [ir_h.numpy_copy_to_tensor_op(arg).result for arg in args]
result_type = ir_h.numpy_unknown_tensor_type
tensor_result = ir_h.numpy_builtin_ufunc_call_op(
*tensor_args,
qualified_name=self._qualified_name,
result_type=result_type).result
array_result = ir_h.numpy_create_array_from_tensor_op(tensor_result).result
return PartialEvalResult.yields_ir_value(array_result)