mirror of https://github.com/llvm/torch-mlir
Implement __array_func__ hook and use it to trace np.dot.
* Creates an abstraction/registry around emitters (intended to generalize to AST compilation as well). * Reworks ufuncs to use the same mechanism as array funcs. * Adds the numpy.dot op.pull/1/head
parent
1f54838d2e
commit
a5f755d406
|
@ -13,6 +13,32 @@ include "NumpyDialect.td"
|
||||||
include "mlir/Interfaces/SideEffects.td"
|
include "mlir/Interfaces/SideEffects.td"
|
||||||
include "mlir/IR/SymbolInterfaces.td"
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
|
|
||||||
|
//----------------------------------------------------------------------------//
|
||||||
|
// IR casting and conversions
|
||||||
|
//----------------------------------------------------------------------------//
|
||||||
|
|
||||||
|
def Numpy_NarrowOp : Numpy_Op<"narrow", []> {
|
||||||
|
let summary = "Narrows an array to a known type at boundaries.";
|
||||||
|
let description = [{
|
||||||
|
During tracing, specific data types are often unknown. This op generically
|
||||||
|
narrows from an unknown to a known data type at boundaries.
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
Numpy_AnyArray:$operand
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Numpy_AnyArray:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand attr-dict `:` functional-type($operand, $result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//----------------------------------------------------------------------------//
|
||||||
|
// Universal function ops (ufunc)
|
||||||
|
// See: https://docs.scipy.org/doc/numpy/reference/ufuncs.html
|
||||||
|
//----------------------------------------------------------------------------//
|
||||||
|
|
||||||
def Numpy_BuiltinUfuncOp : Numpy_Op<"builtin_ufunc", [Symbol]> {
|
def Numpy_BuiltinUfuncOp : Numpy_Op<"builtin_ufunc", [Symbol]> {
|
||||||
let summary = "References a built-in universal function";
|
let summary = "References a built-in universal function";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -69,20 +95,37 @@ def Numpy_UfuncCallOp : Numpy_Op<"ufunc_call", []> {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Numpy_Narrow : Numpy_Op<"narrow", []> {
|
//----------------------------------------------------------------------------//
|
||||||
let summary = "Narrows an array to a known type at boundaries.";
|
// Built-in array functions
|
||||||
|
//
|
||||||
|
// These are ops that mirror supported array functions in numpy or related
|
||||||
|
// libraries. Note that there is some evolution happening on the dispatch
|
||||||
|
// mechanism for these.
|
||||||
|
// See: https://numpy.org/neps/nep-0018-array-function-protocol.html
|
||||||
|
// See: https://numpy.org/neps/nep-0037-array-module.html
|
||||||
|
//
|
||||||
|
// Note that operators are in general free to take any arguments, but there
|
||||||
|
// are some conventions that are mirrored here:
|
||||||
|
//
|
||||||
|
// - `out` arguments indicate that the operation should perform a mutation
|
||||||
|
// of a specific array. This is not modeled at the individual op level,
|
||||||
|
// instead producing IR constructs to map the intent.
|
||||||
|
//----------------------------------------------------------------------------//
|
||||||
|
|
||||||
|
def Numpy_DotOp : Numpy_Op<"dot", []> {
|
||||||
|
let summary = "Represents the `numpy.dot` operator";
|
||||||
let description = [{
|
let description = [{
|
||||||
During tracing, specific data types are often unknown. This op generically
|
See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
|
||||||
narrows from an unknown to a known data type at boundaries.
|
|
||||||
}];
|
}];
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Numpy_AnyArray:$operand
|
Numpy_AnyArray:$a,
|
||||||
|
Numpy_AnyArray:$b
|
||||||
);
|
);
|
||||||
let results = (outs
|
let results = (outs
|
||||||
Numpy_AnyArray:$result
|
Numpy_AnyArray:$output
|
||||||
);
|
);
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$operand attr-dict `:` functional-type($operand, $result)
|
operands attr-dict `:` functional-type(operands, $output)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,4 @@
|
||||||
|
[style]
|
||||||
|
based_on_style = google
|
||||||
|
column_limit = 80
|
||||||
|
indent_width = 2
|
|
@ -104,7 +104,7 @@ _BUILTIN_MODULE_ASM = r"""
|
||||||
numpy.ufunc_return %0 : f32
|
numpy.ufunc_return %0 : f32
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
numpy.generic_ufunc @numpy.multiple (
|
numpy.generic_ufunc @numpy.multiply (
|
||||||
overload(%arg0: i32, %arg1: i32) -> i32 {
|
overload(%arg0: i32, %arg1: i32) -> i32 {
|
||||||
%0 = muli %arg0, %arg1 : i32
|
%0 = muli %arg0, %arg1 : i32
|
||||||
numpy.ufunc_return %0 : i32
|
numpy.ufunc_return %0 : i32
|
||||||
|
|
|
@ -0,0 +1,270 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class Protocol(Enum):
|
||||||
|
UFUNC = 1
|
||||||
|
ARRAY_FUNC = 2
|
||||||
|
|
||||||
|
|
||||||
|
class TraceValueType(Enum):
|
||||||
|
NDARRAY = 1
|
||||||
|
|
||||||
|
|
||||||
|
class TraceValue(
|
||||||
|
namedtuple("TraceValue", ["value", "type"],
|
||||||
|
defaults=(TraceValueType.NDARRAY,))):
|
||||||
|
__slots__ = ()
|
||||||
|
"""A Python value and the trace type that it should correspond to."""
|
||||||
|
|
||||||
|
|
||||||
|
class TraceInvocation(
|
||||||
|
namedtuple("TraceInvocation", ["inputs", "kwargs", "protocol", "method"],
|
||||||
|
defaults=(Protocol.ARRAY_FUNC, "__call__"))):
|
||||||
|
"""An invocation of a single functions.
|
||||||
|
|
||||||
|
This abstracts over both ufuncs and array_funcs, differentiating by the
|
||||||
|
protocol and method.
|
||||||
|
"""
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
|
||||||
|
class EmissionRequest(
|
||||||
|
namedtuple("EmissionRequest", ["input_ssa_values", "ops", "types", "extra"],
|
||||||
|
defaults=(None,))):
|
||||||
|
"""Represents the result of processing inputs from an invocation.
|
||||||
|
|
||||||
|
The `input_ssa_values` are mlir.ir.Value instances corresponding to
|
||||||
|
input_trace_values in TraceValueMap.
|
||||||
|
|
||||||
|
The `extra` value is only relevant to the producer and can be used as a
|
||||||
|
blackbox mechanism to transfer un-tracked state from an invocation to
|
||||||
|
emission.
|
||||||
|
|
||||||
|
The `ops` and `types` fields correspond to mlir.ir.Ops and mlir.ir.Types
|
||||||
|
instances respectively.
|
||||||
|
"""
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
|
||||||
|
class TraceValueMap(
|
||||||
|
namedtuple("TraceValueMap",
|
||||||
|
["input_trace_values", "result_trace_value_types", "extra"],
|
||||||
|
defaults=(None,))):
|
||||||
|
"""The result of mapping an invocation to corresponding op structure.
|
||||||
|
|
||||||
|
This type associates:
|
||||||
|
- Python (object, TraceValueType) representing invocation inputs that
|
||||||
|
correspond to SSA values in the IR.
|
||||||
|
- TraceValueTypes that are the expected logical result types from the
|
||||||
|
invocation.
|
||||||
|
- 'extra' object that is passed to followon Emitter methods.
|
||||||
|
"""
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
|
||||||
|
class FuncEmitter:
|
||||||
|
"""An emitter for an op-like function invocation."""
|
||||||
|
|
||||||
|
def map_invocation(self, trace_invocation: TraceInvocation) -> TraceValueMap:
|
||||||
|
"""Maps from an invocation to EmissionRequest.
|
||||||
|
|
||||||
|
This hook is also responsible for validating the invocation and should
|
||||||
|
raise appropriate user-visible exceptions (i.e. when invoked with incorrect
|
||||||
|
arguments).
|
||||||
|
|
||||||
|
This hook is used to prepare for emission in a define-by-run scenario.
|
||||||
|
Static emission from an AST needs to be prepared via another mechanism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trace_invocation: An Invocation instance to map.
|
||||||
|
Returns:
|
||||||
|
A TraceValueMap describing the structure of the invocation as mapped
|
||||||
|
to/from IR.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def map_results(self, py_results, extra):
|
||||||
|
"""Maps a list of python results to actual function return values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
py_results: List of python results corresponding to the emitted op
|
||||||
|
results.
|
||||||
|
extra: The extra object returned by map_invocation.
|
||||||
|
Returns:
|
||||||
|
Actual function result. Typically this requires special handling to
|
||||||
|
unpack the result of functions that return 1 item.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def emit(self, request: EmissionRequest):
|
||||||
|
"""Emits IR using the provided ops and types factories.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
emission_inputs: An EmissionRequest produced by tracing each TraceValue
|
||||||
|
from a previous call to map_invocation and the corresponding extra
|
||||||
|
value.
|
||||||
|
Returns:
|
||||||
|
An iterable of mlir.ir.Value instances representing the outputs of the
|
||||||
|
operation. The `builder` on `ops` must be positioned to consume these
|
||||||
|
values.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class GenericCallUfuncEmitter(FuncEmitter):
|
||||||
|
"""A FuncEmitter for generic ufuncs requiring no special behavior.
|
||||||
|
|
||||||
|
Representation:
|
||||||
|
>>> emitter = GenericCallUfuncEmitter("numpy.add")
|
||||||
|
>>> emitter
|
||||||
|
<ufunc emitter 'numpy.add'>
|
||||||
|
>>> inv = TraceInvocation([1, 2], {}, protocol=Protocol.UFUNC)
|
||||||
|
>>> inputs = emitter.map_invocation(inv)
|
||||||
|
>>> inputs
|
||||||
|
TraceValueMap(input_trace_values=[TraceValue(value=1, type=<TraceValueType.NDARRAY: 1>), TraceValue(value=2, type=<TraceValueType.NDARRAY: 1>)], result_trace_value_types=[<TraceValueType.NDARRAY: 1>], extra=None)
|
||||||
|
|
||||||
|
Error on unsupported kwargs:
|
||||||
|
>>> inv = TraceInvocation([1, 2], {"foobar": 1}, protocol=Protocol.UFUNC)
|
||||||
|
>>> emitter.map_invocation(inv)
|
||||||
|
Traceback (most recent call last):
|
||||||
|
...
|
||||||
|
ValueError: Unexpected keyword args for ufunc numpy.add: foobar
|
||||||
|
|
||||||
|
"""
|
||||||
|
__slots__ = ("_ufunc_name")
|
||||||
|
|
||||||
|
def __init__(self, ufunc_name: str):
|
||||||
|
self._ufunc_name = ufunc_name
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<ufunc emitter '%s'>" % self._ufunc_name
|
||||||
|
|
||||||
|
def map_invocation(self,
|
||||||
|
trace_invocation: TraceInvocation) -> EmissionRequest:
|
||||||
|
assert trace_invocation.protocol == Protocol.UFUNC
|
||||||
|
assert trace_invocation.method == "__call__"
|
||||||
|
if trace_invocation.kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
"Unexpected keyword args for ufunc %s: %s" %
|
||||||
|
(self._ufunc_name, ", ".join(trace_invocation.kwargs.keys())))
|
||||||
|
# Without above special cases, any positional args map to emission
|
||||||
|
# inputs.
|
||||||
|
return TraceValueMap([
|
||||||
|
TraceValue(i, TraceValueType.NDARRAY) for i in trace_invocation.inputs
|
||||||
|
], [TraceValueType.NDARRAY],
|
||||||
|
extra=None)
|
||||||
|
|
||||||
|
def map_results(self, py_results, extra):
|
||||||
|
# Ufuncs always return one result, so just unpack it.
|
||||||
|
return py_results[0]
|
||||||
|
|
||||||
|
def emit(self, request: EmissionRequest):
|
||||||
|
op_result_type = request.types.tensor(request.types.numpy_any_dtype)
|
||||||
|
call_op = request.ops.numpy_ufunc_call_op(self._ufunc_name, op_result_type,
|
||||||
|
*request.input_ssa_values)
|
||||||
|
return call_op.results
|
||||||
|
|
||||||
|
|
||||||
|
class GenericArrayFuncEmitter(FuncEmitter):
|
||||||
|
"""Emitter for array funcs that don't do anything 'special'."""
|
||||||
|
__slots__ = ("_op_name", "_nresults")
|
||||||
|
|
||||||
|
def __init__(self, op_name: str, nresults: int = 1):
|
||||||
|
self._op_name = op_name
|
||||||
|
self._nresults = nresults
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<array_func emitter '%s'>" % self._op_name
|
||||||
|
|
||||||
|
def map_invocation(self,
|
||||||
|
trace_invocation: TraceInvocation) -> EmissionRequest:
|
||||||
|
assert trace_invocation.protocol == Protocol.ARRAY_FUNC
|
||||||
|
if trace_invocation.method != "__call__":
|
||||||
|
raise NotImplementedError("Only __call__ is supported for %s (got '%s')" %
|
||||||
|
(
|
||||||
|
self._op_name,
|
||||||
|
trace_invocation.method,
|
||||||
|
))
|
||||||
|
if trace_invocation.kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
"Unexpected keyword args for %s: %s" %
|
||||||
|
(self._op_name, ", ".join(trace_invocation.kwargs.keys())))
|
||||||
|
# Without above special cases, any positional args map to emission
|
||||||
|
# inputs.
|
||||||
|
return TraceValueMap([
|
||||||
|
TraceValue(i, TraceValueType.NDARRAY) for i in trace_invocation.inputs
|
||||||
|
], [TraceValueType.NDARRAY] * self._nresults,
|
||||||
|
extra=None)
|
||||||
|
|
||||||
|
def map_results(self, py_results, extra):
|
||||||
|
if self._nresults == 1:
|
||||||
|
return py_results[0]
|
||||||
|
else:
|
||||||
|
return tuple(py_results)
|
||||||
|
|
||||||
|
def emit(self, request: EmissionRequest):
|
||||||
|
op_result_types = [request.types.tensor(request.types.numpy_any_dtype)
|
||||||
|
] * self._nresults
|
||||||
|
op = request.ops.op(self._op_name, op_result_types,
|
||||||
|
request.input_ssa_values)
|
||||||
|
return op.results
|
||||||
|
|
||||||
|
|
||||||
|
class EmitterRegistry:
|
||||||
|
"""Registry of known Emitter instances mapped to source function.
|
||||||
|
|
||||||
|
>>> r = EmitterRegistry.create_default()
|
||||||
|
>>> r.lookup_ufunc(np.add, "__call__")
|
||||||
|
<ufunc emitter 'numpy.add'>
|
||||||
|
>>> r.lookup_array_func(np.dot)
|
||||||
|
<array_func emitter 'numpy.dot'>
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._ufunc_map = {} # Dictionary of (f, method) -> Emitter
|
||||||
|
self._arrayfunc_map = {} # Dictionary of f -> Emitter
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_default(cls):
|
||||||
|
registry = cls()
|
||||||
|
registry.register_defaults()
|
||||||
|
return registry
|
||||||
|
|
||||||
|
def register_ufunc(self, ufunc, method, emitter):
|
||||||
|
# Last registration wins.
|
||||||
|
self._ufunc_map[(ufunc, method)] = emitter
|
||||||
|
|
||||||
|
def register_array_func(self, f, emitter):
|
||||||
|
# Last registration wins.
|
||||||
|
self._arrayfunc_map[f] = emitter
|
||||||
|
|
||||||
|
def lookup_ufunc(self, ufunc, method):
|
||||||
|
return self._ufunc_map.get((ufunc, method))
|
||||||
|
|
||||||
|
def lookup_array_func(self, f):
|
||||||
|
return self._arrayfunc_map.get(f)
|
||||||
|
|
||||||
|
def register_defaults(self):
|
||||||
|
# Find all ufuncs in the numpy module and register by name.
|
||||||
|
for member in sorted(dir(np)):
|
||||||
|
ufunc = getattr(np, member)
|
||||||
|
if isinstance(ufunc, np.ufunc):
|
||||||
|
self.register_ufunc(ufunc, "__call__",
|
||||||
|
GenericCallUfuncEmitter("numpy." + member))
|
||||||
|
# Register generic 1-result array funcs.
|
||||||
|
for f, op_name in ((np.inner, "numpy.inner"), (np.outer, "numpy.outer"),
|
||||||
|
(np.dot, "numpy.dot"), (np.vdot, "numpy.vdot"),
|
||||||
|
(np.linalg.det, "numpy.linalg.det")):
|
||||||
|
self.register_array_func(f, GenericArrayFuncEmitter(op_name))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import doctest
|
||||||
|
doctest.testmod()
|
|
@ -3,26 +3,30 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
from typing import Iterable
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ..dialect import Numpy
|
from ..dialect import Numpy
|
||||||
from ..native.mlir import ir
|
from ..native.mlir import ir
|
||||||
|
|
||||||
from .context import *
|
from .context import *
|
||||||
|
from .emitters import *
|
||||||
from ..exporter import *
|
from ..exporter import *
|
||||||
from ..types import *
|
from ..types import *
|
||||||
|
|
||||||
|
|
||||||
class ModuleBuilder:
|
class ModuleBuilder:
|
||||||
"""Builds an MLIR module by tracing functions."""
|
"""Builds an MLIR module by tracing functions."""
|
||||||
def __init__(self, mlir_context=None):
|
|
||||||
self.context = context if mlir_context else ir.MLIRContext()
|
def __init__(self, mlir_context=None, emitter_registry=None):
|
||||||
|
self.context = mlir_context if mlir_context else ir.MLIRContext()
|
||||||
# TODO: Instead of bootstrapping a large module, populate imports
|
# TODO: Instead of bootstrapping a large module, populate imports
|
||||||
# dynamically.
|
# dynamically.
|
||||||
self.module = Numpy.load_builtin_module(self.context)
|
self.module = Numpy.load_builtin_module(self.context)
|
||||||
self.ops = Numpy.Ops(self.context)
|
self.ops = Numpy.Ops(self.context)
|
||||||
self.types = Numpy.Types(self.context)
|
self.types = Numpy.Types(self.context)
|
||||||
|
self.emitters = (emitter_registry
|
||||||
|
if emitter_registry else EmitterRegistry.create_default())
|
||||||
|
|
||||||
def trace(self, export_py_func: ExportPyFunction):
|
def trace(self, export_py_func: ExportPyFunction):
|
||||||
"""Traces and exported python function."""
|
"""Traces and exported python function."""
|
||||||
|
@ -49,6 +53,7 @@ class FunctionTracer(TraceContext):
|
||||||
"_traced_arrays",
|
"_traced_arrays",
|
||||||
"_types",
|
"_types",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction):
|
def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction):
|
||||||
super().__init__(desc="[trace of %s]" % epf.__name__)
|
super().__init__(desc="[trace of %s]" % epf.__name__)
|
||||||
self.module_builder = module_builder
|
self.module_builder = module_builder
|
||||||
|
@ -65,7 +70,8 @@ class FunctionTracer(TraceContext):
|
||||||
# Extract ArrayParams for all args and results.
|
# Extract ArrayParams for all args and results.
|
||||||
self._args_array_params = [
|
self._args_array_params = [
|
||||||
ArrayParams.from_constraints(arg.constraints)
|
ArrayParams.from_constraints(arg.constraints)
|
||||||
for arg in self.epf.sig.args]
|
for arg in self.epf.sig.args
|
||||||
|
]
|
||||||
self._python_args = [None] * len(self._args_array_params)
|
self._python_args = [None] * len(self._args_array_params)
|
||||||
self._result_array_params = ArrayParams.from_constraints(
|
self._result_array_params = ArrayParams.from_constraints(
|
||||||
self.epf.sig.result.constraints)
|
self.epf.sig.result.constraints)
|
||||||
|
@ -82,9 +88,10 @@ class FunctionTracer(TraceContext):
|
||||||
ops = self._ops
|
ops = self._ops
|
||||||
py_results = (self.epf.pyfunc(*self._python_args),)
|
py_results = (self.epf.pyfunc(*self._python_args),)
|
||||||
if len(py_results) != len(self._f_types):
|
if len(py_results) != len(self._f_types):
|
||||||
raise TracingError(
|
raise TracingError("Traced function returned != %d results: %r" % (
|
||||||
"Traced function returned != %d results: %r" % (
|
len(self._f_types),
|
||||||
len(self._f_types), py_results,))
|
py_results,
|
||||||
|
))
|
||||||
|
|
||||||
# Narrow all results to the declared return types.
|
# Narrow all results to the declared return types.
|
||||||
return_operands = []
|
return_operands = []
|
||||||
|
@ -106,12 +113,12 @@ class FunctionTracer(TraceContext):
|
||||||
return self._traced_arrays.get(traced_array)
|
return self._traced_arrays.get(traced_array)
|
||||||
|
|
||||||
def _validate(self):
|
def _validate(self):
|
||||||
if not all(arg.type_class == TypeClass.NdArray
|
if not all(
|
||||||
for arg in self.epf.sig.args):
|
arg.type_class == TypeClass.NdArray for arg in self.epf.sig.args):
|
||||||
raise NotImplementedError("Non NdArray args: %r" % (self.epf.sig.args,))
|
raise NotImplementedError("Non NdArray args: %r" % (self.epf.sig.args,))
|
||||||
if not self.epf.sig.result.type_class == TypeClass.NdArray:
|
if not self.epf.sig.result.type_class == TypeClass.NdArray:
|
||||||
raise NotImplementedError("Non NdArray result: %r" % (
|
raise NotImplementedError("Non NdArray result: %r" %
|
||||||
self.epf.sig.result,))
|
(self.epf.sig.result,))
|
||||||
|
|
||||||
def _create_mlir_function(self):
|
def _create_mlir_function(self):
|
||||||
mlir_c = self._mlir_c
|
mlir_c = self._mlir_c
|
||||||
|
@ -119,10 +126,13 @@ class FunctionTracer(TraceContext):
|
||||||
ops = self._ops
|
ops = self._ops
|
||||||
types = self._types
|
types = self._types
|
||||||
epf = self.epf
|
epf = self.epf
|
||||||
f_args = [mlir_c.parse_type(ap.mlir_tensor_type_asm)
|
f_args = [
|
||||||
for ap in self._args_array_params]
|
mlir_c.parse_type(ap.mlir_tensor_type_asm)
|
||||||
f_types = [mlir_c.parse_type(
|
for ap in self._args_array_params
|
||||||
self._result_array_params.mlir_tensor_type_asm)]
|
]
|
||||||
|
f_types = [
|
||||||
|
mlir_c.parse_type(self._result_array_params.mlir_tensor_type_asm)
|
||||||
|
]
|
||||||
ops.builder.insert_before_terminator(mlir_m.first_block)
|
ops.builder.insert_before_terminator(mlir_m.first_block)
|
||||||
f_type = types.function(f_args, f_types)
|
f_type = types.function(f_args, f_types)
|
||||||
f = ops.func_op(epf.__name__, f_type, create_entry_block=True)
|
f = ops.func_op(epf.__name__, f_type, create_entry_block=True)
|
||||||
|
@ -136,54 +146,59 @@ class FunctionTracer(TraceContext):
|
||||||
self.set_traced_array(ta, entry_block.args[index])
|
self.set_traced_array(ta, entry_block.args[index])
|
||||||
self._python_args[index] = ta
|
self._python_args[index] = ta
|
||||||
|
|
||||||
|
def _resolve_input_ssa_values(self, trace_values: Iterable[TraceValue]):
|
||||||
|
"""Resolves input python values to SSA values."""
|
||||||
|
ssa_values = []
|
||||||
|
for tv in trace_values:
|
||||||
|
assert tv.type == TraceValueType.NDARRAY, (
|
||||||
|
"Unsupported TraceValueType: %r" % tv.type)
|
||||||
|
ssa_value = self.get_traced_array_value(tv.value)
|
||||||
|
if ssa_value is None:
|
||||||
|
raise TracingError(
|
||||||
|
"Required a traced python NDARRAY but not found: %r" % (tv,))
|
||||||
|
ssa_values.append(ssa_value)
|
||||||
|
return ssa_values
|
||||||
|
|
||||||
|
def _resolve_result_py_values(self,
|
||||||
|
trace_value_types: Iterable[TraceValueType],
|
||||||
|
ssa_values):
|
||||||
|
"""Resolves result SSA values to runtime python values."""
|
||||||
|
assert len(trace_value_types) == len(ssa_values), (
|
||||||
|
"Mismatched emitter declared result types and results")
|
||||||
|
py_values = []
|
||||||
|
for trace_value_type, ssa_value in zip(trace_value_types, ssa_values):
|
||||||
|
assert trace_value_type == TraceValueType.NDARRAY, (
|
||||||
|
"Unsupported TraceValueType: %r" % trace_value_type)
|
||||||
|
py_value = TracedArray(self)
|
||||||
|
self.set_traced_array(py_value, ssa_value)
|
||||||
|
py_values.append(py_value)
|
||||||
|
return py_values
|
||||||
|
|
||||||
|
def _emit_invocation(self, emitter: FuncEmitter, invocation: TraceInvocation):
|
||||||
|
tv_map = emitter.map_invocation(invocation)
|
||||||
|
input_ssa_values = self._resolve_input_ssa_values(tv_map.input_trace_values)
|
||||||
|
request = EmissionRequest(input_ssa_values,
|
||||||
|
ops=self._ops,
|
||||||
|
types=self._types,
|
||||||
|
extra=tv_map.extra)
|
||||||
|
result_ssa_values = emitter.emit(request)
|
||||||
|
py_values = self._resolve_result_py_values(tv_map.result_trace_value_types,
|
||||||
|
result_ssa_values)
|
||||||
|
return emitter.map_results(py_values, tv_map.extra)
|
||||||
|
|
||||||
def _handle_ufunc(self, ufunc, method, inputs, kwargs):
|
def _handle_ufunc(self, ufunc, method, inputs, kwargs):
|
||||||
if method == "__call__":
|
emitter = self.module_builder.emitters.lookup_ufunc(ufunc, method)
|
||||||
if kwargs:
|
if not emitter:
|
||||||
raise TracingError("Generic ufunc with kwargs not supported %r" % (
|
return NotImplemented
|
||||||
ufunc,))
|
invocation = TraceInvocation(inputs, kwargs, Protocol.UFUNC, method)
|
||||||
|
return self._emit_invocation(emitter, invocation)
|
||||||
|
|
||||||
# Map inputs to TracedArrays.
|
def _handle_array_func(self, func, types, inputs, kwargs):
|
||||||
# TODO: Process captures, promotions, etc.
|
emitter = self.module_builder.emitters.lookup_array_func(func)
|
||||||
op_inputs = []
|
if not emitter:
|
||||||
for py_input in inputs:
|
return NotImplemented
|
||||||
if not isinstance(py_input, TracedArray):
|
invocation = TraceInvocation(inputs, kwargs, Protocol.ARRAY_FUNC)
|
||||||
raise TracingError("Unsupported ufunc input: %r", (py_input,))
|
return self._emit_invocation(emitter, invocation)
|
||||||
op_input = self.get_traced_array_value(py_input)
|
|
||||||
if op_input is None:
|
|
||||||
raise TracingError("Unregistered traced array: %r", (py_input,))
|
|
||||||
op_inputs.append(op_input)
|
|
||||||
|
|
||||||
# Emit op.
|
|
||||||
types = self._types
|
|
||||||
mlir_m = self._mlir_m
|
|
||||||
callee_symbol = _UFUNC_SYMBOL_MAP.get(ufunc)
|
|
||||||
if not callee_symbol:
|
|
||||||
raise TracingError("Unsupported ufunc: %r" % ufunc)
|
|
||||||
op_result_type = types.tensor(types.numpy_any_dtype)
|
|
||||||
call_op = self._ops.numpy_ufunc_call_op(
|
|
||||||
callee_symbol, op_result_type, *op_inputs)
|
|
||||||
op_result = call_op.results[0]
|
|
||||||
|
|
||||||
# Wrap returns.
|
|
||||||
return_array = TracedArray(self)
|
|
||||||
self.set_traced_array(return_array, op_result)
|
|
||||||
return return_array
|
|
||||||
|
|
||||||
# Unsupported method.
|
|
||||||
raise TracingError("Unsupported ufunc method %r:%r" % (ufunc, method,))
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: There should be an open registry of ufuncs. But for now, just map
|
|
||||||
# introspect the numpy package and record them.
|
|
||||||
def _build_ufunc_symbol_map():
|
|
||||||
d = {}
|
|
||||||
for member in dir(np):
|
|
||||||
ufunc = getattr(np, member)
|
|
||||||
if isinstance(ufunc, np.ufunc):
|
|
||||||
d[ufunc] = "numpy." + member
|
|
||||||
return d
|
|
||||||
|
|
||||||
_UFUNC_SYMBOL_MAP = _build_ufunc_symbol_map()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue