mirror of https://github.com/llvm/torch-mlir
Add implicit constant capture.
We want more sophisticated capture later, but this allows basics to function.pull/1/head
parent
8ae71a9551
commit
f2985e0901
|
@ -51,6 +51,7 @@ class FunctionTracer(TraceContext):
|
||||||
"_python_args",
|
"_python_args",
|
||||||
"_result_array_params",
|
"_result_array_params",
|
||||||
"_traced_arrays",
|
"_traced_arrays",
|
||||||
|
"_external_arrays",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction):
|
def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction):
|
||||||
|
@ -58,6 +59,7 @@ class FunctionTracer(TraceContext):
|
||||||
self.module_builder = module_builder
|
self.module_builder = module_builder
|
||||||
self.epf = epf
|
self.epf = epf
|
||||||
self._traced_arrays = {} # Mapping of TracedArray to current consumer value
|
self._traced_arrays = {} # Mapping of TracedArray to current consumer value
|
||||||
|
self._external_arrays = {} # Mapping of id to (ndarray, ir.Value)
|
||||||
self._validate()
|
self._validate()
|
||||||
|
|
||||||
# Alias some parent members for convenience.
|
# Alias some parent members for convenience.
|
||||||
|
@ -106,11 +108,29 @@ class FunctionTracer(TraceContext):
|
||||||
self._traced_arrays[traced_array] = value
|
self._traced_arrays[traced_array] = value
|
||||||
|
|
||||||
def get_traced_array_value(self, traced_array):
|
def get_traced_array_value(self, traced_array):
|
||||||
|
if not isinstance(traced_array, TracedArray):
|
||||||
|
# Generic import of external value. For now, we just treat these as
|
||||||
|
# local consts.
|
||||||
|
return self._get_external_array_value(traced_array)
|
||||||
|
|
||||||
traced_value = self._traced_arrays.get(traced_array)
|
traced_value = self._traced_arrays.get(traced_array)
|
||||||
if traced_value is None:
|
if traced_value is None:
|
||||||
raise TracingError("Unregistered traced array: %r", (traced_array,))
|
raise TracingError("Unregistered traced array: %r", (traced_array,))
|
||||||
return traced_value
|
return traced_value
|
||||||
|
|
||||||
|
def _get_external_array_value(self, external_array):
|
||||||
|
h = self._helper
|
||||||
|
if not isinstance(external_array, np.ndarray):
|
||||||
|
raise TracingError("Expected ndarray but got: %r" % (external_array,))
|
||||||
|
found_it = self._external_arrays.get(id(external_array))
|
||||||
|
if found_it:
|
||||||
|
return found_it[1]
|
||||||
|
# Import it.
|
||||||
|
dense_attr = h.context.dense_elements_attr(external_array)
|
||||||
|
const_value = h.constant_op(dense_attr.type, dense_attr).result
|
||||||
|
self._external_arrays[id(external_array)] = (external_array, const_value)
|
||||||
|
return const_value
|
||||||
|
|
||||||
def _validate(self):
|
def _validate(self):
|
||||||
if not all(
|
if not all(
|
||||||
arg.type_class == TypeClass.NdArray for arg in self.epf.sig.args):
|
arg.type_class == TypeClass.NdArray for arg in self.epf.sig.args):
|
||||||
|
@ -212,12 +232,14 @@ class FunctionTracer(TraceContext):
|
||||||
|
|
||||||
def _emit_slice_object(self, slice_object: slice):
|
def _emit_slice_object(self, slice_object: slice):
|
||||||
h = self._helper
|
h = self._helper
|
||||||
|
|
||||||
def emit_index(index):
|
def emit_index(index):
|
||||||
if index is None:
|
if index is None:
|
||||||
return h.basicpy_singleton_op(h.basicpy_NoneType).result
|
return h.basicpy_singleton_op(h.basicpy_NoneType).result
|
||||||
else:
|
else:
|
||||||
return h.constant_op(h.index_type,
|
return h.constant_op(h.index_type,
|
||||||
h.context.index_attr(int(index))).result
|
h.context.index_attr(int(index))).result
|
||||||
|
|
||||||
start = emit_index(slice_object.start)
|
start = emit_index(slice_object.start)
|
||||||
stop = emit_index(slice_object.stop)
|
stop = emit_index(slice_object.stop)
|
||||||
step = emit_index(slice_object.step)
|
step = emit_index(slice_object.step)
|
||||||
|
@ -230,8 +252,8 @@ class FunctionTracer(TraceContext):
|
||||||
slice_tuple = key if isinstance(key, tuple) else (key,)
|
slice_tuple = key if isinstance(key, tuple) else (key,)
|
||||||
# Resolve and emit each slice element.
|
# Resolve and emit each slice element.
|
||||||
slice_values = [self._emit_slice_value(elt) for elt in slice_tuple]
|
slice_values = [self._emit_slice_value(elt) for elt in slice_tuple]
|
||||||
result_value = h.numpy_get_slice_op(
|
result_value = h.numpy_get_slice_op(h.unknown_array_type, array_value,
|
||||||
h.unknown_array_type, array_value, *slice_values).result
|
*slice_values).result
|
||||||
result_array = TracedArray(self)
|
result_array = TracedArray(self)
|
||||||
self.set_traced_array(result_array, result_value)
|
self.set_traced_array(result_array, result_value)
|
||||||
return result_array
|
return result_array
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import npcomp as npc
|
||||||
|
from npcomp.types import *
|
||||||
|
|
||||||
|
weights = np.random.uniform(size=(16, 4)).astype(np.float32)
|
||||||
|
bias = np.random.uniform(size=(4,)).astype(np.float32)
|
||||||
|
|
||||||
|
def constants(a: np.ndarray) -> np.ndarray:
|
||||||
|
return np.dot(a, weights) + bias
|
||||||
|
|
||||||
|
# TODO: Implement subclassing and deriving constraints by run
|
||||||
|
exp = npc.Exporter()
|
||||||
|
exp.constants = constants
|
||||||
|
|
||||||
|
mb = npc.tracing.ModuleBuilder()
|
||||||
|
mb.trace(exp.constants)
|
||||||
|
print(mb.module.to_asm())
|
|
@ -767,6 +767,9 @@ void PyValue::bind(py::module m) {
|
||||||
|
|
||||||
void PyAttribute::bind(py::module m) {
|
void PyAttribute::bind(py::module m) {
|
||||||
py::class_<PyAttribute>(m, "Attribute")
|
py::class_<PyAttribute>(m, "Attribute")
|
||||||
|
.def_property_readonly("type", [](PyAttribute &self) -> PyType {
|
||||||
|
return self.attr.getType();
|
||||||
|
})
|
||||||
.def("__repr__", [](PyAttribute &self) {
|
.def("__repr__", [](PyAttribute &self) {
|
||||||
std::string res;
|
std::string res;
|
||||||
llvm::raw_string_ostream os(res);
|
llvm::raw_string_ostream os(res);
|
||||||
|
|
Loading…
Reference in New Issue