Add implicit constant capture.

We want more sophisticated capture later, but this allows basics to function.
pull/1/head
Stella Laurenzo 2020-05-08 17:53:30 -07:00
parent 8ae71a9551
commit f2985e0901
3 changed files with 49 additions and 3 deletions

View File

@ -51,6 +51,7 @@ class FunctionTracer(TraceContext):
"_python_args", "_python_args",
"_result_array_params", "_result_array_params",
"_traced_arrays", "_traced_arrays",
"_external_arrays",
] ]
def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction): def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction):
@ -58,6 +59,7 @@ class FunctionTracer(TraceContext):
self.module_builder = module_builder self.module_builder = module_builder
self.epf = epf self.epf = epf
self._traced_arrays = {} # Mapping of TracedArray to current consumer value self._traced_arrays = {} # Mapping of TracedArray to current consumer value
self._external_arrays = {} # Mapping of id to (ndarray, ir.Value)
self._validate() self._validate()
# Alias some parent members for convenience. # Alias some parent members for convenience.
@ -106,11 +108,29 @@ class FunctionTracer(TraceContext):
self._traced_arrays[traced_array] = value self._traced_arrays[traced_array] = value
def get_traced_array_value(self, traced_array): def get_traced_array_value(self, traced_array):
if not isinstance(traced_array, TracedArray):
# Generic import of external value. For now, we just treat these as
# local consts.
return self._get_external_array_value(traced_array)
traced_value = self._traced_arrays.get(traced_array) traced_value = self._traced_arrays.get(traced_array)
if traced_value is None: if traced_value is None:
raise TracingError("Unregistered traced array: %r", (traced_array,)) raise TracingError("Unregistered traced array: %r", (traced_array,))
return traced_value return traced_value
def _get_external_array_value(self, external_array):
h = self._helper
if not isinstance(external_array, np.ndarray):
raise TracingError("Expected ndarray but got: %r" % (external_array,))
found_it = self._external_arrays.get(id(external_array))
if found_it:
return found_it[1]
# Import it.
dense_attr = h.context.dense_elements_attr(external_array)
const_value = h.constant_op(dense_attr.type, dense_attr).result
self._external_arrays[id(external_array)] = (external_array, const_value)
return const_value
def _validate(self): def _validate(self):
if not all( if not all(
arg.type_class == TypeClass.NdArray for arg in self.epf.sig.args): arg.type_class == TypeClass.NdArray for arg in self.epf.sig.args):
@ -212,12 +232,14 @@ class FunctionTracer(TraceContext):
def _emit_slice_object(self, slice_object: slice): def _emit_slice_object(self, slice_object: slice):
h = self._helper h = self._helper
def emit_index(index): def emit_index(index):
if index is None: if index is None:
return h.basicpy_singleton_op(h.basicpy_NoneType).result return h.basicpy_singleton_op(h.basicpy_NoneType).result
else: else:
return h.constant_op(h.index_type, return h.constant_op(h.index_type,
h.context.index_attr(int(index))).result h.context.index_attr(int(index))).result
start = emit_index(slice_object.start) start = emit_index(slice_object.start)
stop = emit_index(slice_object.stop) stop = emit_index(slice_object.stop)
step = emit_index(slice_object.step) step = emit_index(slice_object.step)
@ -230,8 +252,8 @@ class FunctionTracer(TraceContext):
slice_tuple = key if isinstance(key, tuple) else (key,) slice_tuple = key if isinstance(key, tuple) else (key,)
# Resolve and emit each slice element. # Resolve and emit each slice element.
slice_values = [self._emit_slice_value(elt) for elt in slice_tuple] slice_values = [self._emit_slice_value(elt) for elt in slice_tuple]
result_value = h.numpy_get_slice_op( result_value = h.numpy_get_slice_op(h.unknown_array_type, array_value,
h.unknown_array_type, array_value, *slice_values).result *slice_values).result
result_array = TracedArray(self) result_array = TracedArray(self)
self.set_traced_array(result_array, result_value) self.set_traced_array(result_array, result_value)
return result_array return result_array

View File

@ -0,0 +1,21 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
import npcomp as npc
from npcomp.types import *
weights = np.random.uniform(size=(16, 4)).astype(np.float32)
bias = np.random.uniform(size=(4,)).astype(np.float32)
def constants(a: np.ndarray) -> np.ndarray:
return np.dot(a, weights) + bias
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.constants = constants
mb = npc.tracing.ModuleBuilder()
mb.trace(exp.constants)
print(mb.module.to_asm())

View File

@ -767,6 +767,9 @@ void PyValue::bind(py::module m) {
void PyAttribute::bind(py::module m) { void PyAttribute::bind(py::module m) {
py::class_<PyAttribute>(m, "Attribute") py::class_<PyAttribute>(m, "Attribute")
.def_property_readonly("type", [](PyAttribute &self) -> PyType {
return self.attr.getType();
})
.def("__repr__", [](PyAttribute &self) { .def("__repr__", [](PyAttribute &self) {
std::string res; std::string res;
llvm::raw_string_ostream os(res); llvm::raw_string_ostream os(res);