mirror of https://github.com/llvm/torch-mlir
Add implicit constant capture.
We want more sophisticated capture later, but this allows basics to function.pull/1/head
parent
8ae71a9551
commit
f2985e0901
|
@ -51,6 +51,7 @@ class FunctionTracer(TraceContext):
|
|||
"_python_args",
|
||||
"_result_array_params",
|
||||
"_traced_arrays",
|
||||
"_external_arrays",
|
||||
]
|
||||
|
||||
def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction):
|
||||
|
@ -58,6 +59,7 @@ class FunctionTracer(TraceContext):
|
|||
self.module_builder = module_builder
|
||||
self.epf = epf
|
||||
self._traced_arrays = {} # Mapping of TracedArray to current consumer value
|
||||
self._external_arrays = {} # Mapping of id to (ndarray, ir.Value)
|
||||
self._validate()
|
||||
|
||||
# Alias some parent members for convenience.
|
||||
|
@ -106,11 +108,29 @@ class FunctionTracer(TraceContext):
|
|||
self._traced_arrays[traced_array] = value
|
||||
|
||||
def get_traced_array_value(self, traced_array):
|
||||
if not isinstance(traced_array, TracedArray):
|
||||
# Generic import of external value. For now, we just treat these as
|
||||
# local consts.
|
||||
return self._get_external_array_value(traced_array)
|
||||
|
||||
traced_value = self._traced_arrays.get(traced_array)
|
||||
if traced_value is None:
|
||||
raise TracingError("Unregistered traced array: %r", (traced_array,))
|
||||
return traced_value
|
||||
|
||||
def _get_external_array_value(self, external_array):
|
||||
h = self._helper
|
||||
if not isinstance(external_array, np.ndarray):
|
||||
raise TracingError("Expected ndarray but got: %r" % (external_array,))
|
||||
found_it = self._external_arrays.get(id(external_array))
|
||||
if found_it:
|
||||
return found_it[1]
|
||||
# Import it.
|
||||
dense_attr = h.context.dense_elements_attr(external_array)
|
||||
const_value = h.constant_op(dense_attr.type, dense_attr).result
|
||||
self._external_arrays[id(external_array)] = (external_array, const_value)
|
||||
return const_value
|
||||
|
||||
def _validate(self):
|
||||
if not all(
|
||||
arg.type_class == TypeClass.NdArray for arg in self.epf.sig.args):
|
||||
|
@ -212,12 +232,14 @@ class FunctionTracer(TraceContext):
|
|||
|
||||
def _emit_slice_object(self, slice_object: slice):
|
||||
h = self._helper
|
||||
|
||||
def emit_index(index):
|
||||
if index is None:
|
||||
return h.basicpy_singleton_op(h.basicpy_NoneType).result
|
||||
else:
|
||||
return h.constant_op(h.index_type,
|
||||
h.context.index_attr(int(index))).result
|
||||
h.context.index_attr(int(index))).result
|
||||
|
||||
start = emit_index(slice_object.start)
|
||||
stop = emit_index(slice_object.stop)
|
||||
step = emit_index(slice_object.step)
|
||||
|
@ -230,8 +252,8 @@ class FunctionTracer(TraceContext):
|
|||
slice_tuple = key if isinstance(key, tuple) else (key,)
|
||||
# Resolve and emit each slice element.
|
||||
slice_values = [self._emit_slice_value(elt) for elt in slice_tuple]
|
||||
result_value = h.numpy_get_slice_op(
|
||||
h.unknown_array_type, array_value, *slice_values).result
|
||||
result_value = h.numpy_get_slice_op(h.unknown_array_type, array_value,
|
||||
*slice_values).result
|
||||
result_array = TracedArray(self)
|
||||
self.set_traced_array(result_array, result_value)
|
||||
return result_array
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import numpy as np
|
||||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
weights = np.random.uniform(size=(16, 4)).astype(np.float32)
|
||||
bias = np.random.uniform(size=(4,)).astype(np.float32)
|
||||
|
||||
def constants(a: np.ndarray) -> np.ndarray:
|
||||
return np.dot(a, weights) + bias
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.constants = constants
|
||||
|
||||
mb = npc.tracing.ModuleBuilder()
|
||||
mb.trace(exp.constants)
|
||||
print(mb.module.to_asm())
|
|
@ -767,6 +767,9 @@ void PyValue::bind(py::module m) {
|
|||
|
||||
void PyAttribute::bind(py::module m) {
|
||||
py::class_<PyAttribute>(m, "Attribute")
|
||||
.def_property_readonly("type", [](PyAttribute &self) -> PyType {
|
||||
return self.attr.getType();
|
||||
})
|
||||
.def("__repr__", [](PyAttribute &self) {
|
||||
std::string res;
|
||||
llvm::raw_string_ostream os(res);
|
||||
|
|
Loading…
Reference in New Issue