From f2985e090177867f612152db0bd424fd27e295f0 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 8 May 2020 17:53:30 -0700 Subject: [PATCH] Add implicit constant capture. We want more sophisticated capture later, but this allows basics to function. --- python/npcomp/tracing/mlir_trace.py | 28 +++++++++++++++++++++++++--- python/samples/const.py | 21 +++++++++++++++++++++ python_native/MlirIr.cpp | 3 +++ 3 files changed, 49 insertions(+), 3 deletions(-) create mode 100644 python/samples/const.py diff --git a/python/npcomp/tracing/mlir_trace.py b/python/npcomp/tracing/mlir_trace.py index 3c2986ff6..55ee8cb7a 100644 --- a/python/npcomp/tracing/mlir_trace.py +++ b/python/npcomp/tracing/mlir_trace.py @@ -51,6 +51,7 @@ class FunctionTracer(TraceContext): "_python_args", "_result_array_params", "_traced_arrays", + "_external_arrays", ] def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction): @@ -58,6 +59,7 @@ class FunctionTracer(TraceContext): self.module_builder = module_builder self.epf = epf self._traced_arrays = {} # Mapping of TracedArray to current consumer value + self._external_arrays = {} # Mapping of id to (ndarray, ir.Value) self._validate() # Alias some parent members for convenience. @@ -106,11 +108,29 @@ class FunctionTracer(TraceContext): self._traced_arrays[traced_array] = value def get_traced_array_value(self, traced_array): + if not isinstance(traced_array, TracedArray): + # Generic import of external value. For now, we just treat these as + # local consts. + return self._get_external_array_value(traced_array) + traced_value = self._traced_arrays.get(traced_array) if traced_value is None: raise TracingError("Unregistered traced array: %r", (traced_array,)) return traced_value + def _get_external_array_value(self, external_array): + h = self._helper + if not isinstance(external_array, np.ndarray): + raise TracingError("Expected ndarray but got: %r" % (external_array,)) + found_it = self._external_arrays.get(id(external_array)) + if found_it: + return found_it[1] + # Import it. + dense_attr = h.context.dense_elements_attr(external_array) + const_value = h.constant_op(dense_attr.type, dense_attr).result + self._external_arrays[id(external_array)] = (external_array, const_value) + return const_value + def _validate(self): if not all( arg.type_class == TypeClass.NdArray for arg in self.epf.sig.args): @@ -212,12 +232,14 @@ class FunctionTracer(TraceContext): def _emit_slice_object(self, slice_object: slice): h = self._helper + def emit_index(index): if index is None: return h.basicpy_singleton_op(h.basicpy_NoneType).result else: return h.constant_op(h.index_type, - h.context.index_attr(int(index))).result + h.context.index_attr(int(index))).result + start = emit_index(slice_object.start) stop = emit_index(slice_object.stop) step = emit_index(slice_object.step) @@ -230,8 +252,8 @@ class FunctionTracer(TraceContext): slice_tuple = key if isinstance(key, tuple) else (key,) # Resolve and emit each slice element. slice_values = [self._emit_slice_value(elt) for elt in slice_tuple] - result_value = h.numpy_get_slice_op( - h.unknown_array_type, array_value, *slice_values).result + result_value = h.numpy_get_slice_op(h.unknown_array_type, array_value, + *slice_values).result result_array = TracedArray(self) self.set_traced_array(result_array, result_value) return result_array diff --git a/python/samples/const.py b/python/samples/const.py new file mode 100644 index 000000000..fa2256636 --- /dev/null +++ b/python/samples/const.py @@ -0,0 +1,21 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import numpy as np +import npcomp as npc +from npcomp.types import * + +weights = np.random.uniform(size=(16, 4)).astype(np.float32) +bias = np.random.uniform(size=(4,)).astype(np.float32) + +def constants(a: np.ndarray) -> np.ndarray: + return np.dot(a, weights) + bias + +# TODO: Implement subclassing and deriving constraints by run +exp = npc.Exporter() +exp.constants = constants + +mb = npc.tracing.ModuleBuilder() +mb.trace(exp.constants) +print(mb.module.to_asm()) diff --git a/python_native/MlirIr.cpp b/python_native/MlirIr.cpp index c1ed51fe2..870a8c650 100644 --- a/python_native/MlirIr.cpp +++ b/python_native/MlirIr.cpp @@ -767,6 +767,9 @@ void PyValue::bind(py::module m) { void PyAttribute::bind(py::module m) { py::class_(m, "Attribute") + .def_property_readonly("type", [](PyAttribute &self) -> PyType { + return self.attr.getType(); + }) .def("__repr__", [](PyAttribute &self) { std::string res; llvm::raw_string_ostream os(res);