Implement __array_function__ hook and use it to trace np.dot.

* Creates an abstraction/registry around emitters (intended to generalize to AST compilation as well).
* Reworks ufuncs to use the same mechanism as array funcs.
* Adds the numpy.dot op.
pull/1/head
Stella Laurenzo 2020-05-04 15:47:01 -07:00
parent 1f54838d2e
commit a5f755d406
5 changed files with 420 additions and 88 deletions

View File

@ -13,6 +13,32 @@ include "NumpyDialect.td"
include "mlir/Interfaces/SideEffects.td" include "mlir/Interfaces/SideEffects.td"
include "mlir/IR/SymbolInterfaces.td" include "mlir/IR/SymbolInterfaces.td"
//----------------------------------------------------------------------------//
// IR casting and conversions
//----------------------------------------------------------------------------//
// Op definition for `narrow`, declared with an empty trait list.
def Numpy_NarrowOp : Numpy_Op<"narrow", []> {
  let summary = "Narrows an array to a known type at boundaries.";
  let description = [{
    During tracing, specific data types are often unknown. This op generically
    narrows from an unknown to a known data type at boundaries.
  }];
  // One operand and one result, both constrained to Numpy_AnyArray.
  let arguments = (ins
    Numpy_AnyArray:$operand
  );
  let results = (outs
    Numpy_AnyArray:$result
  );
  // Textual form: the operand, the attribute dictionary, then `:` followed by
  // the functional type built from the operand and result types.
  let assemblyFormat = [{
    $operand attr-dict `:` functional-type($operand, $result)
  }];
}
//----------------------------------------------------------------------------//
// Universal function ops (ufunc)
// See: https://docs.scipy.org/doc/numpy/reference/ufuncs.html
//----------------------------------------------------------------------------//
def Numpy_BuiltinUfuncOp : Numpy_Op<"builtin_ufunc", [Symbol]> { def Numpy_BuiltinUfuncOp : Numpy_Op<"builtin_ufunc", [Symbol]> {
let summary = "References a built-in universal function"; let summary = "References a built-in universal function";
let description = [{ let description = [{
@ -69,20 +95,37 @@ def Numpy_UfuncCallOp : Numpy_Op<"ufunc_call", []> {
}]; }];
} }
def Numpy_Narrow : Numpy_Op<"narrow", []> { //----------------------------------------------------------------------------//
let summary = "Narrows an array to a known type at boundaries."; // Built-in array functions
//
// These are ops that mirror supported array functions in numpy or related
// libraries. Note that there is some evolution happening on the dispatch
// mechanism for these.
// See: https://numpy.org/neps/nep-0018-array-function-protocol.html
// See: https://numpy.org/neps/nep-0037-array-module.html
//
// Note that operators are in general free to take any arguments, but there
// are some conventions that are mirrored here:
//
// - `out` arguments indicate that the operation should perform a mutation
// of a specific array. This is not modeled at the individual op level,
// instead producing IR constructs to map the intent.
//----------------------------------------------------------------------------//
def Numpy_DotOp : Numpy_Op<"dot", []> {
let summary = "Represents the `numpy.dot` operator";
let description = [{ let description = [{
During tracing, specific data types are often unknown. This op generically See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
narrows from an unknown to a known data type at boundaries.
}]; }];
let arguments = (ins let arguments = (ins
Numpy_AnyArray:$operand Numpy_AnyArray:$a,
Numpy_AnyArray:$b
); );
let results = (outs let results = (outs
Numpy_AnyArray:$result Numpy_AnyArray:$output
); );
let assemblyFormat = [{ let assemblyFormat = [{
$operand attr-dict `:` functional-type($operand, $result) operands attr-dict `:` functional-type(operands, $output)
}]; }];
} }

View File

@ -0,0 +1,4 @@
[style]
based_on_style = google
column_limit = 80
indent_width = 2

View File

@ -104,7 +104,7 @@ _BUILTIN_MODULE_ASM = r"""
numpy.ufunc_return %0 : f32 numpy.ufunc_return %0 : f32
} }
) )
numpy.generic_ufunc @numpy.multiple ( numpy.generic_ufunc @numpy.multiply (
overload(%arg0: i32, %arg1: i32) -> i32 { overload(%arg0: i32, %arg1: i32) -> i32 {
%0 = muli %arg0, %arg1 : i32 %0 = muli %arg0, %arg1 : i32
numpy.ufunc_return %0 : i32 numpy.ufunc_return %0 : i32

View File

@ -0,0 +1,270 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
from collections import namedtuple
from enum import Enum
class Protocol(Enum):
  """Distinguishes the dispatch protocol that produced an invocation."""

  # The invocation came in through the ufunc protocol.
  UFUNC = 1
  # The invocation came in through the array function protocol.
  ARRAY_FUNC = 2
class TraceValueType(Enum):
  """Enumerates the kinds of values that can take part in a trace."""

  # The only kind declared here: an ndarray value.
  NDARRAY = 1
class TraceValue(
    namedtuple("TraceValue", ["value", "type"],
               defaults=(TraceValueType.NDARRAY,))):
  """A Python value and the trace type that it should correspond to.

  Fields:
    value: The Python object participating in the trace.
    type: The TraceValueType describing how `value` is treated. Defaults to
      TraceValueType.NDARRAY when omitted.
  """
  # The docstring must be the first statement in the class body; placed after
  # `__slots__` it is a discarded expression and `__doc__` stays unset.
  __slots__ = ()
class TraceInvocation(
    namedtuple("TraceInvocation",
               "inputs kwargs protocol method",
               defaults=[Protocol.ARRAY_FUNC, "__call__"])):
  """Describes one call of a single function observed while tracing.

  Both ufunc and array_func calls are represented by this one type; the
  `protocol` and `method` fields tell them apart.

  Fields:
    inputs: The positional inputs of the call.
    kwargs: The keyword arguments of the call.
    protocol: The Protocol the call arrived through. Defaults to
      Protocol.ARRAY_FUNC.
    method: The name of the invoked method. Defaults to "__call__".
  """
  __slots__ = ()
class EmissionRequest(
    namedtuple("EmissionRequest",
               "input_ssa_values ops types extra",
               defaults=[None])):
  """Bundles everything an emitter needs once invocation inputs are processed.

  Fields:
    input_ssa_values: mlir.ir.Value instances corresponding to the
      input_trace_values of a TraceValueMap.
    ops: The mlir.ir.Ops instance to emit with.
    types: The mlir.ir.Types instance to build types with.
    extra: Meaningful only to the producer; a blackbox channel for carrying
      un-tracked state from an invocation through to emission. Defaults to
      None.
  """
  __slots__ = ()
class TraceValueMap(
    namedtuple("TraceValueMap",
               "input_trace_values result_trace_value_types extra",
               defaults=[None])):
  """Describes how an invocation lines up with the op structure to be emitted.

  Fields:
    input_trace_values: Python (object, TraceValueType) pairs for the
      invocation inputs that correspond to SSA values in the IR.
    result_trace_value_types: The TraceValueTypes that are the expected
      logical result types of the invocation.
    extra: An object handed on to the follow-on Emitter methods. Defaults to
      None.
  """
  __slots__ = ()
class FuncEmitter:
  """Interface for emitting IR for an op-like function invocation.

  Every method here raises NotImplementedError; concrete emitters override
  all three.
  """

  def map_invocation(self, trace_invocation: TraceInvocation) -> TraceValueMap:
    """Describes how an invocation maps onto IR inputs and results.

    Implementations also validate the invocation here and raise appropriate
    user-visible exceptions (e.g. when called with incorrect arguments).

    This is the preparation step for emission in a define-by-run scenario.
    Static emission from an AST has to be prepared by some other mechanism.

    Args:
      trace_invocation: The TraceInvocation to map.
    Returns:
      A TraceValueMap describing the structure of the invocation as mapped
      to/from IR.
    """
    raise NotImplementedError

  def map_results(self, py_results, extra):
    """Converts the list of python results into the function's return value.

    Args:
      py_results: List of python results corresponding to the emitted op
        results.
      extra: The extra object returned by map_invocation.
    Returns:
      The actual function result. Functions that return a single item
      typically need to unpack it from the list.
    """
    raise NotImplementedError

  def emit(self, request: EmissionRequest):
    """Emits IR using the ops and types factories carried by the request.

    Args:
      request: An EmissionRequest produced by tracing each TraceValue from a
        previous call to map_invocation, together with the corresponding
        extra value.
    Returns:
      An iterable of mlir.ir.Value instances representing the outputs of the
      operation. The `builder` on `ops` must be positioned to consume these
      values.
    """
    raise NotImplementedError
class GenericCallUfuncEmitter(FuncEmitter):
  """A FuncEmitter for generic ufuncs requiring no special behavior.

  Every positional input is mapped to an NDARRAY trace value and exactly one
  NDARRAY result is declared. Keyword arguments are rejected.

  Representation:
    >>> emitter = GenericCallUfuncEmitter("numpy.add")
    >>> emitter
    <ufunc emitter 'numpy.add'>
    >>> inv = TraceInvocation([1, 2], {}, protocol=Protocol.UFUNC)
    >>> inputs = emitter.map_invocation(inv)
    >>> inputs
    TraceValueMap(input_trace_values=[TraceValue(value=1, type=<TraceValueType.NDARRAY: 1>), TraceValue(value=2, type=<TraceValueType.NDARRAY: 1>)], result_trace_value_types=[<TraceValueType.NDARRAY: 1>], extra=None)

  Error on unsupported kwargs:
    >>> inv = TraceInvocation([1, 2], {"foobar": 1}, protocol=Protocol.UFUNC)
    >>> emitter.map_invocation(inv)
    Traceback (most recent call last):
    ...
    ValueError: Unexpected keyword args for ufunc numpy.add: foobar
  """
  # A one-element tuple needs the trailing comma; without it this is a bare
  # string rather than a sequence of slot names.
  __slots__ = ("_ufunc_name",)

  def __init__(self, ufunc_name: str):
    """Creates an emitter for the ufunc referenced as `ufunc_name`."""
    self._ufunc_name = ufunc_name

  def __repr__(self):
    return "<ufunc emitter '%s'>" % self._ufunc_name

  def map_invocation(self, trace_invocation: TraceInvocation) -> TraceValueMap:
    """Maps a ufunc `__call__` invocation to its trace inputs and results.

    Args:
      trace_invocation: The invocation to map. Must use Protocol.UFUNC and
        the "__call__" method.
    Returns:
      A TraceValueMap with one NDARRAY TraceValue per positional input, a
      single NDARRAY result type and no extra state.
    Raises:
      ValueError: If any keyword arguments were supplied.
    """
    # Internal invariants: this emitter is only registered for ufunc
    # `__call__` dispatch.
    assert trace_invocation.protocol == Protocol.UFUNC
    assert trace_invocation.method == "__call__"
    if trace_invocation.kwargs:
      raise ValueError(
          "Unexpected keyword args for ufunc %s: %s" %
          (self._ufunc_name, ", ".join(trace_invocation.kwargs.keys())))

    # Without above special cases, any positional args map to emission
    # inputs.
    return TraceValueMap([
        TraceValue(i, TraceValueType.NDARRAY) for i in trace_invocation.inputs
    ], [TraceValueType.NDARRAY],
                         extra=None)

  def map_results(self, py_results, extra):
    """Returns the first entry of `py_results`; `extra` is unused."""
    # Ufuncs always return one result, so just unpack it.
    return py_results[0]

  def emit(self, request: EmissionRequest):
    """Emits a ufunc call op and returns its results.

    The result type is a tensor of `numpy_any_dtype`, built from the request's
    types factory.
    """
    op_result_type = request.types.tensor(request.types.numpy_any_dtype)
    call_op = request.ops.numpy_ufunc_call_op(self._ufunc_name, op_result_type,
                                              *request.input_ssa_values)
    return call_op.results
class GenericArrayFuncEmitter(FuncEmitter):
  """Emitter for array funcs that don't do anything 'special'.

  Every positional input is mapped to an NDARRAY trace value and `nresults`
  NDARRAY results are declared. Keyword arguments are rejected.
  """
  __slots__ = ("_op_name", "_nresults")

  def __init__(self, op_name: str, nresults: int = 1):
    """Creates an emitter.

    Args:
      op_name: Name of the op to emit.
      nresults: Number of results the op produces.
    """
    self._op_name = op_name
    self._nresults = nresults

  def __repr__(self):
    return "<array_func emitter '%s'>" % self._op_name

  def map_invocation(self, trace_invocation: TraceInvocation) -> TraceValueMap:
    """Maps an array function invocation to its trace inputs and results.

    Args:
      trace_invocation: The invocation to map. Must use Protocol.ARRAY_FUNC.
    Returns:
      A TraceValueMap with one NDARRAY TraceValue per positional input,
      `nresults` NDARRAY result types and no extra state.
    Raises:
      NotImplementedError: If the invoked method is not "__call__".
      ValueError: If any keyword arguments were supplied.
    """
    # Internal invariant: this emitter is only used for array function
    # dispatch.
    assert trace_invocation.protocol == Protocol.ARRAY_FUNC
    if trace_invocation.method != "__call__":
      raise NotImplementedError("Only __call__ is supported for %s (got '%s')" %
                                (
                                    self._op_name,
                                    trace_invocation.method,
                                ))
    if trace_invocation.kwargs:
      raise ValueError(
          "Unexpected keyword args for %s: %s" %
          (self._op_name, ", ".join(trace_invocation.kwargs.keys())))

    # Without above special cases, any positional args map to emission
    # inputs.
    return TraceValueMap([
        TraceValue(i, TraceValueType.NDARRAY) for i in trace_invocation.inputs
    ], [TraceValueType.NDARRAY] * self._nresults,
                         extra=None)

  def map_results(self, py_results, extra):
    """Returns the sole result, or a tuple of results; `extra` is unused."""
    if self._nresults == 1:
      return py_results[0]
    else:
      return tuple(py_results)

  def emit(self, request: EmissionRequest):
    """Emits the named op and returns its results.

    Each result type is a tensor of `numpy_any_dtype`; the same type object
    is repeated `nresults` times.
    """
    op_result_types = [request.types.tensor(request.types.numpy_any_dtype)
                      ] * self._nresults
    op = request.ops.op(self._op_name, op_result_types,
                        request.input_ssa_values)
    return op.results
class EmitterRegistry:
  """Maps source functions to the Emitter instances that handle them.

  >>> r = EmitterRegistry.create_default()
  >>> r.lookup_ufunc(np.add, "__call__")
  <ufunc emitter 'numpy.add'>
  >>> r.lookup_array_func(np.dot)
  <array_func emitter 'numpy.dot'>
  """

  def __init__(self):
    # Keyed by (ufunc, method name).
    self._ufunc_map = {}
    # Keyed by the array function itself.
    self._arrayfunc_map = {}

  @classmethod
  def create_default(cls):
    """Returns a new registry with the default emitters already registered."""
    instance = cls()
    instance.register_defaults()
    return instance

  def register_ufunc(self, ufunc, method, emitter):
    """Associates `emitter` with `(ufunc, method)`, replacing any prior entry."""
    key = (ufunc, method)
    self._ufunc_map[key] = emitter

  def register_array_func(self, f, emitter):
    """Associates `emitter` with array function `f`, replacing any prior entry."""
    self._arrayfunc_map[f] = emitter

  def lookup_ufunc(self, ufunc, method):
    """Returns the emitter for `(ufunc, method)`, or None if unregistered."""
    key = (ufunc, method)
    return self._ufunc_map.get(key)

  def lookup_array_func(self, f):
    """Returns the emitter for array function `f`, or None if unregistered."""
    return self._arrayfunc_map.get(f)

  def register_defaults(self):
    """Registers emitters for numpy's ufuncs and a fixed set of array funcs."""
    # Walk the numpy namespace in sorted order; every attribute that is a
    # ufunc is registered for `__call__` under its attribute name.
    for attr_name in sorted(dir(np)):
      candidate = getattr(np, attr_name)
      if not isinstance(candidate, np.ufunc):
        continue
      self.register_ufunc(candidate, "__call__",
                          GenericCallUfuncEmitter("numpy." + attr_name))

    # Array funcs that each produce a single result.
    single_result_funcs = (
        (np.inner, "numpy.inner"),
        (np.outer, "numpy.outer"),
        (np.dot, "numpy.dot"),
        (np.vdot, "numpy.vdot"),
        (np.linalg.det, "numpy.linalg.det"),
    )
    for array_func, op_name in single_result_funcs:
      self.register_array_func(array_func, GenericArrayFuncEmitter(op_name))
# When this file is executed as a script, run the doctests embedded in the
# docstrings of this module.
if __name__ == "__main__":
  import doctest
  doctest.testmod()

View File

@ -3,31 +3,35 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import re import re
from typing import Iterable
import numpy as np import numpy as np
from ..dialect import Numpy from ..dialect import Numpy
from ..native.mlir import ir from ..native.mlir import ir
from .context import * from .context import *
from .emitters import *
from ..exporter import * from ..exporter import *
from ..types import * from ..types import *
class ModuleBuilder: class ModuleBuilder:
"""Builds an MLIR module by tracing functions.""" """Builds an MLIR module by tracing functions."""
def __init__(self, mlir_context=None):
self.context = context if mlir_context else ir.MLIRContext() def __init__(self, mlir_context=None, emitter_registry=None):
self.context = mlir_context if mlir_context else ir.MLIRContext()
# TODO: Instead of bootstrapping a large module, populate imports # TODO: Instead of bootstrapping a large module, populate imports
# dynamically. # dynamically.
self.module = Numpy.load_builtin_module(self.context) self.module = Numpy.load_builtin_module(self.context)
self.ops = Numpy.Ops(self.context) self.ops = Numpy.Ops(self.context)
self.types = Numpy.Types(self.context) self.types = Numpy.Types(self.context)
self.emitters = (emitter_registry
if emitter_registry else EmitterRegistry.create_default())
def trace(self, export_py_func: ExportPyFunction): def trace(self, export_py_func: ExportPyFunction):
"""Traces and exported python function.""" """Traces and exported python function."""
assert isinstance(export_py_func, ExportPyFunction), ( assert isinstance(export_py_func, ExportPyFunction), (
"Expected an exported python function (from the Exporter class)") "Expected an exported python function (from the Exporter class)")
tracer = FunctionTracer(self, export_py_func) tracer = FunctionTracer(self, export_py_func)
with tracer: with tracer:
tracer.trace() tracer.trace()
@ -36,19 +40,20 @@ class ModuleBuilder:
class FunctionTracer(TraceContext): class FunctionTracer(TraceContext):
"""A trace of a single function.""" """A trace of a single function."""
__slots__ = [ __slots__ = [
"module_builder", "module_builder",
"epf", "epf",
"_args_array_params", "_args_array_params",
"_f", "_f",
"_f_types", "_f_types",
"_mlir_m", "_mlir_m",
"_mlir_c", "_mlir_c",
"_python_args", "_python_args",
"_ops", "_ops",
"_result_array_params", "_result_array_params",
"_traced_arrays", "_traced_arrays",
"_types", "_types",
] ]
def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction): def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction):
super().__init__(desc="[trace of %s]" % epf.__name__) super().__init__(desc="[trace of %s]" % epf.__name__)
self.module_builder = module_builder self.module_builder = module_builder
@ -64,11 +69,12 @@ class FunctionTracer(TraceContext):
# Extract ArrayParams for all args and results. # Extract ArrayParams for all args and results.
self._args_array_params = [ self._args_array_params = [
ArrayParams.from_constraints(arg.constraints) ArrayParams.from_constraints(arg.constraints)
for arg in self.epf.sig.args] for arg in self.epf.sig.args
]
self._python_args = [None] * len(self._args_array_params) self._python_args = [None] * len(self._args_array_params)
self._result_array_params = ArrayParams.from_constraints( self._result_array_params = ArrayParams.from_constraints(
self.epf.sig.result.constraints) self.epf.sig.result.constraints)
# Create the MLIR function. # Create the MLIR function.
self._f, self._f_types = self._create_mlir_function() self._f, self._f_types = self._create_mlir_function()
@ -82,10 +88,11 @@ class FunctionTracer(TraceContext):
ops = self._ops ops = self._ops
py_results = (self.epf.pyfunc(*self._python_args),) py_results = (self.epf.pyfunc(*self._python_args),)
if len(py_results) != len(self._f_types): if len(py_results) != len(self._f_types):
raise TracingError( raise TracingError("Traced function returned != %d results: %r" % (
"Traced function returned != %d results: %r" % ( len(self._f_types),
len(self._f_types), py_results,)) py_results,
))
# Narrow all results to the declared return types. # Narrow all results to the declared return types.
return_operands = [] return_operands = []
for py_result, mlir_result_type in zip(py_results, self._f_types): for py_result, mlir_result_type in zip(py_results, self._f_types):
@ -94,7 +101,7 @@ class FunctionTracer(TraceContext):
raise TracingError("Unregistered traced array: %r", (py_result,)) raise TracingError("Unregistered traced array: %r", (py_result,))
# narrow to declared result type. # narrow to declared result type.
return_operands.extend( return_operands.extend(
ops.numpy_narrow(mlir_result_type, mlir_result).results) ops.numpy_narrow(mlir_result_type, mlir_result).results)
ops.return_op(return_operands) ops.return_op(return_operands)
def set_traced_array(self, traced_array, value): def set_traced_array(self, traced_array, value):
@ -106,12 +113,12 @@ class FunctionTracer(TraceContext):
return self._traced_arrays.get(traced_array) return self._traced_arrays.get(traced_array)
def _validate(self): def _validate(self):
if not all(arg.type_class == TypeClass.NdArray if not all(
for arg in self.epf.sig.args): arg.type_class == TypeClass.NdArray for arg in self.epf.sig.args):
raise NotImplementedError("Non NdArray args: %r" % (self.epf.sig.args,)) raise NotImplementedError("Non NdArray args: %r" % (self.epf.sig.args,))
if not self.epf.sig.result.type_class == TypeClass.NdArray: if not self.epf.sig.result.type_class == TypeClass.NdArray:
raise NotImplementedError("Non NdArray result: %r" % ( raise NotImplementedError("Non NdArray result: %r" %
self.epf.sig.result,)) (self.epf.sig.result,))
def _create_mlir_function(self): def _create_mlir_function(self):
mlir_c = self._mlir_c mlir_c = self._mlir_c
@ -119,10 +126,13 @@ class FunctionTracer(TraceContext):
ops = self._ops ops = self._ops
types = self._types types = self._types
epf = self.epf epf = self.epf
f_args = [mlir_c.parse_type(ap.mlir_tensor_type_asm) f_args = [
for ap in self._args_array_params] mlir_c.parse_type(ap.mlir_tensor_type_asm)
f_types = [mlir_c.parse_type( for ap in self._args_array_params
self._result_array_params.mlir_tensor_type_asm)] ]
f_types = [
mlir_c.parse_type(self._result_array_params.mlir_tensor_type_asm)
]
ops.builder.insert_before_terminator(mlir_m.first_block) ops.builder.insert_before_terminator(mlir_m.first_block)
f_type = types.function(f_args, f_types) f_type = types.function(f_args, f_types)
f = ops.func_op(epf.__name__, f_type, create_entry_block=True) f = ops.func_op(epf.__name__, f_type, create_entry_block=True)
@ -136,56 +146,61 @@ class FunctionTracer(TraceContext):
self.set_traced_array(ta, entry_block.args[index]) self.set_traced_array(ta, entry_block.args[index])
self._python_args[index] = ta self._python_args[index] = ta
def _resolve_input_ssa_values(self, trace_values: Iterable[TraceValue]):
  """Looks up the SSA value recorded for each input trace value.

  Args:
    trace_values: The TraceValues to resolve. Only NDARRAY typed values are
      supported.
  Returns:
    A list with the value returned by get_traced_array_value for each input,
    in input order.
  Raises:
    TracingError: If the lookup for an input yields None.
  """
  resolved = []
  for trace_value in trace_values:
    assert trace_value.type == TraceValueType.NDARRAY, (
        "Unsupported TraceValueType: %r" % trace_value.type)
    found = self.get_traced_array_value(trace_value.value)
    if found is not None:
      resolved.append(found)
      continue
    # No SSA value has been recorded for this python value.
    raise TracingError(
        "Required a traced python NDARRAY but not found: %r" % (trace_value,))
  return resolved
def _resolve_result_py_values(self,
                              trace_value_types: Iterable[TraceValueType],
                              ssa_values):
  """Resolves result SSA values to runtime python values.

  A fresh TracedArray is created for every result and registered against its
  SSA value via set_traced_array.

  Args:
    trace_value_types: The result types declared by the emitter. Only
      NDARRAY is supported.
    ssa_values: The SSA values produced by the emitter; must support len()
      and have one entry per declared type.
  Returns:
    A list of the newly created TracedArray instances, in result order.
  """
  # The parameter is declared as a general iterable, which need not support
  # len(); materialize it before checking the count.
  trace_value_types = list(trace_value_types)
  assert len(trace_value_types) == len(ssa_values), (
      "Mismatched emitter declared result types and results")
  py_values = []
  for trace_value_type, ssa_value in zip(trace_value_types, ssa_values):
    assert trace_value_type == TraceValueType.NDARRAY, (
        "Unsupported TraceValueType: %r" % trace_value_type)
    py_value = TracedArray(self)
    self.set_traced_array(py_value, ssa_value)
    py_values.append(py_value)
  return py_values
def _emit_invocation(self, emitter: FuncEmitter, invocation: TraceInvocation):
  """Runs one invocation through an emitter and returns its python result.

  Steps: map the invocation, resolve the inputs to SSA values, let the
  emitter emit IR, wrap the emitted results as python values, then let the
  emitter shape the final return value.

  Args:
    emitter: The FuncEmitter handling this invocation.
    invocation: The TraceInvocation to emit.
  Returns:
    Whatever the emitter's map_results returns for the wrapped results.
  """
  value_map = emitter.map_invocation(invocation)
  extra = value_map.extra
  emission_request = EmissionRequest(
      self._resolve_input_ssa_values(value_map.input_trace_values),
      ops=self._ops,
      types=self._types,
      extra=extra)
  emitted_values = emitter.emit(emission_request)
  wrapped_results = self._resolve_result_py_values(
      value_map.result_trace_value_types, emitted_values)
  return emitter.map_results(wrapped_results, extra)
def _handle_ufunc(self, ufunc, method, inputs, kwargs): def _handle_ufunc(self, ufunc, method, inputs, kwargs):
if method == "__call__": emitter = self.module_builder.emitters.lookup_ufunc(ufunc, method)
if kwargs: if not emitter:
raise TracingError("Generic ufunc with kwargs not supported %r" % ( return NotImplemented
ufunc,)) invocation = TraceInvocation(inputs, kwargs, Protocol.UFUNC, method)
return self._emit_invocation(emitter, invocation)
# Map inputs to TracedArrays.
# TODO: Process captures, promotions, etc.
op_inputs = []
for py_input in inputs:
if not isinstance(py_input, TracedArray):
raise TracingError("Unsupported ufunc input: %r", (py_input,))
op_input = self.get_traced_array_value(py_input)
if op_input is None:
raise TracingError("Unregistered traced array: %r", (py_input,))
op_inputs.append(op_input)
# Emit op.
types = self._types
mlir_m = self._mlir_m
callee_symbol = _UFUNC_SYMBOL_MAP.get(ufunc)
if not callee_symbol:
raise TracingError("Unsupported ufunc: %r" % ufunc)
op_result_type = types.tensor(types.numpy_any_dtype)
call_op = self._ops.numpy_ufunc_call_op(
callee_symbol, op_result_type, *op_inputs)
op_result = call_op.results[0]
# Wrap returns.
return_array = TracedArray(self)
self.set_traced_array(return_array, op_result)
return return_array
# Unsupported method. def _handle_array_func(self, func, types, inputs, kwargs):
raise TracingError("Unsupported ufunc method %r:%r" % (ufunc, method,)) emitter = self.module_builder.emitters.lookup_array_func(func)
if not emitter:
return NotImplemented
# TODO: There should be an open registry of ufuncs. But for now, just map invocation = TraceInvocation(inputs, kwargs, Protocol.ARRAY_FUNC)
# introspect the numpy package and record them. return self._emit_invocation(emitter, invocation)
def _build_ufunc_symbol_map():
d = {}
for member in dir(np):
ufunc = getattr(np, member)
if isinstance(ufunc, np.ufunc):
d[ufunc] = "numpy." + member
return d
_UFUNC_SYMBOL_MAP = _build_ufunc_symbol_map()
if __name__ == "__main__": if __name__ == "__main__":
import doctest import doctest
doctest.testmod() doctest.testmod()