From c89a35f97f807771f5517d0f9261d2122835cc02 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 2 May 2020 19:52:21 -0700 Subject: [PATCH] Rework the poc tracer to be structured how intended. --- include/npcomp/Dialect/Numpy/NumpyOps.td | 17 ++ python/npcomp/dialect/Numpy.py | 28 ++- python/npcomp/exp/extractor.py | 199 ------------------ python/npcomp/tracing/context.py | 34 ++- python/npcomp/tracing/mlir_trace.py | 250 +++++++++++++++-------- python/npcomp/tracing/mlir_trace_test.py | 38 ++++ python/run_tests.py | 2 +- 7 files changed, 269 insertions(+), 299 deletions(-) delete mode 100644 python/npcomp/exp/extractor.py create mode 100644 python/npcomp/tracing/mlir_trace_test.py diff --git a/include/npcomp/Dialect/Numpy/NumpyOps.td b/include/npcomp/Dialect/Numpy/NumpyOps.td index 3a4ef40db..be128e28c 100644 --- a/include/npcomp/Dialect/Numpy/NumpyOps.td +++ b/include/npcomp/Dialect/Numpy/NumpyOps.td @@ -69,4 +69,21 @@ def Numpy_UfuncCallOp : Numpy_Op<"ufunc_call", []> { }]; } +def Numpy_Narrow : Numpy_Op<"narrow", []> { + let summary = "Narrows an array to a known type at boundaries."; + let description = [{ + During tracing, specific data types are often unknown. This op generically + narrows from an unknown to a known data type at boundaries. + }]; + let arguments = (ins + Numpy_AnyArray:$operand + ); + let results = (outs + Numpy_AnyArray:$result + ); + let assemblyFormat = [{ + $operand attr-dict `:` functional-type($operand, $result) + }]; +} + #endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS diff --git a/python/npcomp/dialect/Numpy.py b/python/npcomp/dialect/Numpy.py index d19a32c0a..3dfd2589a 100644 --- a/python/npcomp/dialect/Numpy.py +++ b/python/npcomp/dialect/Numpy.py @@ -43,18 +43,22 @@ class Ops(ir.Ops): }) return self.op("numpy.ufunc_call", [result_type], args, attrs) + def numpy_narrow(self, result_type, operand): + """Creates a numpy.narrow op.""" + return self.op("numpy.narrow", [result_type], [operand]) + class Types(ir.Types): """Container/factory for dialect types. >>> t = Types(ir.MLIRContext()) - >>> t.any_dtype + >>> t.numpy_any_dtype !numpy.any_dtype - >>> t.tensor(t.any_dtype, [1, 2, 3]) + >>> t.tensor(t.numpy_any_dtype, [1, 2, 3]) tensor<1x2x3x!numpy.any_dtype> - >>> t.tensor(t.any_dtype) + >>> t.tensor(t.numpy_any_dtype) tensor<*x!numpy.any_dtype> - >>> t.tensor(t.any_dtype, [-1, 2]) + >>> t.tensor(t.numpy_any_dtype, [-1, 2]) tensor >>> t.tensor(t.f32) tensor<*xf32> @@ -64,7 +68,7 @@ class Types(ir.Types): """ def __init__(self, context): super().__init__(context) - self.any_dtype = context.parse_type("!numpy.any_dtype") + self.numpy_any_dtype = context.parse_type("!numpy.any_dtype") def load_builtin_module(context=None): @@ -91,19 +95,25 @@ def load_builtin_module(context=None): _BUILTIN_MODULE_ASM = r""" numpy.generic_ufunc @numpy.add ( - // CHECK-SAME: overload(%arg0: i32, %arg1: i32) -> i32 { overload(%arg0: i32, %arg1: i32) -> i32 { - // CHECK: addi %0 = addi %arg0, %arg1 : i32 numpy.ufunc_return %0 : i32 }, - // CHECK: overload(%arg0: f32, %arg1: f32) -> f32 { overload(%arg0: f32, %arg1: f32) -> f32 { - // CHECK: addf %0 = addf %arg0, %arg1 : f32 numpy.ufunc_return %0 : f32 } ) + numpy.generic_ufunc @numpy.multiple ( + overload(%arg0: i32, %arg1: i32) -> i32 { + %0 = muli %arg0, %arg1 : i32 + numpy.ufunc_return %0 : i32 + }, + overload(%arg0: f32, %arg1: f32) -> f32 { + %0 = mulf %arg0, %arg1 : f32 + numpy.ufunc_return %0 : f32 + } + ) """ if __name__ == "__main__": diff --git a/python/npcomp/exp/extractor.py b/python/npcomp/exp/extractor.py deleted file mode 100644 index b3d7a4dff..000000000 --- a/python/npcomp/exp/extractor.py +++ /dev/null @@ -1,199 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import re -import numpy as np - -from ..native.mlir import edsc -from ..exporter import * -from ..types import * - - -class TracingError(Exception): - pass - - -class EmitterRegistry: - def __init__(self): - self._func_emitters = {} - - def register(self, func, emitter): - self._func_emitters[func] = emitter - - def lookup(self, func): - return self._func_emitters.get(func) - - def register_ufunc(self, ufunc, function_name): - def emitter(pft, method, *inputs, **kwargs): - if method == "__call__": - if kwargs: - raise TracingError("Generic ufunc with kwargs not supported %r" % ( - ufunc,)) - - # Map inputs to TracedArrays. - # TODO: Process captures, promotions, etc. - op_inputs = [] - for py_input in inputs: - if not isinstance(py_input, TracedArray): - raise TracingError("Unsupported ufunc input: %r", (py_input,)) - op_input = pft.get_traced_array_value(py_input) - if op_input is None: - raise TracingError("Unregistered traced array: %r", (py_input,)) - op_inputs.append(op_input) - - # Emit op. - mlir_m = pft.mlir_module - op_result_types = [mlir_m.make_type("tensor<*x!numpy.any_dtype>")] - op_result = edsc.op("numpy.tmp_generic_ufunc", op_inputs, op_result_types, - ufunc_name=mlir_m.stringAttr(function_name)) - - # Wrap returns. - return_array = TracedArray(pft) - pft.set_traced_array(return_array, op_result) - return return_array - - raise TracingError("Unsupported ufunc method %r:%r" % (ufunc, method,)) - - self.register(ufunc, emitter) - - -EMITTER_REGISTRY = EmitterRegistry() -EMITTER_REGISTRY.register_ufunc(np.multiply, "numpy.multiply") -EMITTER_REGISTRY.register_ufunc(np.add, "numpy.add") - - -class TracedArray(np.lib.mixins.NDArrayOperatorsMixin): - """An array that traces its operations.""" - def __init__(self, pft: "PyFuncTrace"): - self._pft = pft - - def __hash__(self): - return id(self) - - def __repr__(self): - return "" % id(self) - - def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): - emitter = EMITTER_REGISTRY.lookup(ufunc) - if emitter is None: - return NotImplemented - result = emitter(self._pft, method, *inputs, **kwargs) - return result - - -class PyFuncTrace: - r"""Creates an MLIR function from an unwrapped python function. - - # TODO: These constraints are too verbose and should be coming in by - # example. - >>> def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray: - ... return a * b + a - >>> exp = Exporter() - >>> exp.simple_mul = simple_mul - >>> exp.simple_mul.sig.args["a"] += Shape(1, 4) - >>> exp.simple_mul.sig.args["a"] += DynamicDim(0) - >>> exp.simple_mul.sig.args["a"] += DType(np.float32) - >>> exp.simple_mul.sig.args["b"] += Shape(1) - >>> exp.simple_mul.sig.args["b"] += DType(np.float32) - >>> exp.simple_mul.sig.result += Shape(1, 4) - >>> exp.simple_mul.sig.result += DynamicDim(0) - >>> exp.simple_mul.sig.result += DType(np.float32) - >>> pft = PyFuncTrace(exp.simple_mul) - >>> pft.trace() - >>> print(pft.mlir_module.get_ir().strip()) - module { - func @simple_mul(%arg0: tensor, %arg1: tensor<1xf32>) -> tensor { - %0 = "numpy.tmp_generic_ufunc"(%arg0, %arg1) {ufunc_name = "numpy.multiply"} : (tensor, tensor<1xf32>) -> tensor<*x!numpy.any_dtype> - %1 = "numpy.tmp_generic_ufunc"(%0, %arg0) {ufunc_name = "numpy.add"} : (tensor<*x!numpy.any_dtype>, tensor) -> tensor<*x!numpy.any_dtype> - %2 = "numpy.narrow"(%1) : (tensor<*x!numpy.any_dtype>) -> tensor - return %2 : tensor - } - } - """ - __slots__ = [ - "epf", - "mlir_ctx", - "mlir_fun", - "mlir_module", - "mlir_result_types", - "_args_array_params", - "_traced_arrays", - "_python_args", - "_result_array_params", - ] - def __init__(self, epf: ExportPyFunction): - self.mlir_module = edsc.MLIRModule() - self.epf = epf - self._traced_arrays = {} # Mapping of TracedArray to current consumer value - self._validate() - - # Extract ArrayParams for all args and results. - self._args_array_params = [ - ArrayParams.from_constraints(arg.constraints) - for arg in self.epf.sig.args] - self._python_args = [None] * len(self._args_array_params) - self._result_array_params = ArrayParams.from_constraints( - self.epf.sig.result.constraints) - - # Create the MLIR function. - self.mlir_fun, self.mlir_result_types = self._create_mlir_function() - self.mlir_ctx = self.mlir_module.function_context(self.mlir_fun) - self._create_trace_roots() - - def set_traced_array(self, traced_array, value_handle): - """Sets the current SSA value for a traced_array.""" - assert isinstance(traced_array, TracedArray) - self._traced_arrays[traced_array] = value_handle - - def get_traced_array_value(self, traced_array): - return self._traced_arrays.get(traced_array) - - def trace(self): - # TODO: General argument merging - with self.mlir_ctx: - py_results = (self.epf.pyfunc(*self._python_args),) - if len(py_results) != len(self.mlir_result_types): - raise TracingError( - "Traced function returned != %d results: %r" % ( - len(self.mlir_result_types), py_results,)) - - # Narrow all results to the declared return types. - return_operands = [] - for py_result, mlir_result_type in zip(py_results, self.mlir_result_types): - mlir_result = self.get_traced_array_value(py_result) - if mlir_result is None: - raise TracingError("Unregistered traced array: %r", (py_input,)) - # narrow to declared result type. - return_operands.append(edsc.op( - "numpy.narrow", [mlir_result], [mlir_result_type])) - edsc.ret(return_operands) - - def _validate(self): - if not all(arg.type_class == TypeClass.NdArray - for arg in self.epf.sig.args): - raise NotImplementedError("Non NdArray args: %r" % (self.epf.sig.args,)) - if not self.epf.sig.result.type_class == TypeClass.NdArray: - raise NotImplementedError("Non NdArray result: %r" % ( - self.epf.sig.result,)) - - def _create_mlir_function(self): - mlir_m = self.mlir_module - epf = self.epf - f_args = [mlir_m.make_type(ap.mlir_tensor_type_asm) - for ap in self._args_array_params] - f_results = [mlir_m.make_type( - self._result_array_params.mlir_tensor_type_asm)] - return mlir_m.make_function(epf.__name__, f_args, f_results), f_results - - def _create_trace_roots(self): - for index, ap in enumerate(self._args_array_params): - if ap is not None: - ta = TracedArray(self) - self.set_traced_array(ta, self.mlir_fun.arg(index)) - self._python_args[index] = ta - - -if __name__ == "__main__": - import doctest - doctest.testmod() diff --git a/python/npcomp/tracing/context.py b/python/npcomp/tracing/context.py index f93d02de4..5626b6a54 100644 --- a/python/npcomp/tracing/context.py +++ b/python/npcomp/tracing/context.py @@ -10,6 +10,10 @@ import threading import numpy as np +class TracingError(Exception): + pass + + class TraceContext: """Context for intercepting array traces. @@ -42,10 +46,19 @@ class TraceContext: """ _local = threading.local() - + __slots__ = [ + "_desc", + "_next_id", + "active", + ] def __init__(self, desc=None): self._desc = desc self._next_id = 1 + self.active = False + + def _handle_ufunc(self, ufunc, method, inputs, kwargs): + """Handles a ufunc invocation involving at least one TracedArray.""" + raise NotImplementedError() def get_next_id(self): """Gets the next unique id for the context.""" @@ -78,16 +91,28 @@ class TraceContext: def __enter__(self): s = self._get_context_stack() + if s: + s[-1].active = False s.append(self) + self.active = True return self def __exit__(self, exc_type, exc_value, traceback): s = self._get_context_stack() s.pop() + self.active = False + if s: + s[-1].active = True def __repr__(self): return "" % self._desc + +def _assert_active(tc: TraceContext): + assert tc.active, ( + "Attempt to trace an action on an inactive trace context: %r" % tc) + + class TracedArray(np.lib.mixins.NDArrayOperatorsMixin): """An array that traces its operations. @@ -103,6 +128,9 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin): self._tc = tc if tc is not None else TraceContext.current() self._uid = self._tc.get_next_id() + def __hash__(self): + return id(self) + @property def uid(self): return self._uid @@ -111,7 +139,9 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin): return "" % self._uid def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): - return NotImplemented + tc = self._tc + _assert_active(tc) + return tc._handle_ufunc(ufunc, method, inputs, kwargs) if __name__ == "__main__": diff --git a/python/npcomp/tracing/mlir_trace.py b/python/npcomp/tracing/mlir_trace.py index bf8aa066f..05356a3ff 100644 --- a/python/npcomp/tracing/mlir_trace.py +++ b/python/npcomp/tracing/mlir_trace.py @@ -6,110 +6,184 @@ import re import numpy as np -from . import context -from ..native.mlir import edsc +from ..dialect import Numpy +from ..native.mlir import ir + +from .context import * +from ..exporter import * +from ..types import * -def _map_typing_to_mlir_type(mlir_m, typing_annot): - """Maps a typing annotation to an MLIR type. +class ModuleBuilder: + """Builds an MLIR module by tracing functions.""" + def __init__(self, mlir_context=None): + self.context = context if mlir_context else ir.MLIRContext() + # TODO: Instead of bootstrapping a large module, populate imports + # dynamically. + self.module = Numpy.load_builtin_module(self.context) + self.ops = Numpy.Ops(self.context) + self.types = Numpy.Types(self.context) - Args: - mlir_m: MLIRModule. - typing_annot: Value for an __annotations__ entry. - Returns: - MLIR type or None if not mappable. - """ - if typing_annot is np.ndarray: - return mlir_m.make_type("tensor<*x!numpy.any_dtype>") - return None + def trace(self, export_py_func: ExportPyFunction): + """Traces and exported python function.""" + assert isinstance(export_py_func, ExportPyFunction), ( + "Expected an exported python function (from the Exporter class)") + tracer = FunctionTracer(self, export_py_func) + with tracer: + tracer.trace() -class GenericFunctionTrace: - """Represents a trace of a 'generic' python function in progress.""" +class FunctionTracer(TraceContext): + """A trace of a single function.""" + __slots__ = [ + "module_builder", + "epf", + "_args_array_params", + "_f", + "_f_types", + "_mlir_m", + "_mlir_c", + "_python_args", + "_ops", + "_result_array_params", + "_traced_arrays", + "_types", + ] + def __init__(self, module_builder: ModuleBuilder, epf: ExportPyFunction): + super().__init__(desc="[trace of %s]" % epf.__name__) + self.module_builder = module_builder + self.epf = epf + self._traced_arrays = {} # Mapping of TracedArray to current consumer value + self._validate() - def __init__(self, mlir_m, mlir_f): - self._mlir_m = mlir_m - self._mlir_f = mlir_f + # Alias some parent members for convenience. + self._mlir_m = module_builder.module + self._mlir_c = module_builder.context + self._ops = module_builder.ops + self._types = module_builder.types - @property - def mlir_module(self): - return self._mlir_m + # Extract ArrayParams for all args and results. + self._args_array_params = [ + ArrayParams.from_constraints(arg.constraints) + for arg in self.epf.sig.args] + self._python_args = [None] * len(self._args_array_params) + self._result_array_params = ArrayParams.from_constraints( + self.epf.sig.result.constraints) - @property - def mlir_function(self): - return self._mlir_f + # Create the MLIR function. + self._f, self._f_types = self._create_mlir_function() + self._create_trace_roots() - @classmethod - def from_typed_pyfunc(cls, mlir_m, pyfunc, name_in_module=None): - """Creates a generic function trace from a pyfunc with type annotations. + def trace(self): + # Invoke the python function with placeholders. + # TODO: More sophisticated signature merging + # TODO: Multiple results + # TODO: Error reporting + ops = self._ops + py_results = (self.epf.pyfunc(*self._python_args),) + if len(py_results) != len(self._f_types): + raise TracingError( + "Traced function returned != %d results: %r" % ( + len(self._f_types), py_results,)) + + # Narrow all results to the declared return types. + return_operands = [] + for py_result, mlir_result_type in zip(py_results, self._f_types): + mlir_result = self.get_traced_array_value(py_result) + if mlir_result is None: + raise TracingError("Unregistered traced array: %r", (py_result,)) + # narrow to declared result type. + return_operands.extend( + ops.numpy_narrow(mlir_result_type, mlir_result).results) + ops.return_op(return_operands) - This is a relatively limited mechanism which relies on typing annotations - for arguments and results and supports a relatively limited amount of - variation. + def set_traced_array(self, traced_array, value): + """Sets the current SSA value for a traced_array.""" + assert isinstance(traced_array, TracedArray) + self._traced_arrays[traced_array] = value - Examples: + def get_traced_array_value(self, traced_array): + return self._traced_arrays.get(traced_array) - * Generic ndarrays: - >>> m = edsc.MLIRModule() - >>> def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray: - ... return a * b - >>> gft = GenericFunctionTrace.from_typed_pyfunc(m, simple_mul) - >>> ir = gft.mlir_module.get_ir() - >>> print(re.findall("func @simple_mul.+", ir)[0]) - func @simple_mul$$generic(%arg0: tensor<*x!numpy.any_dtype> {py_name = "a"}, %arg1: tensor<*x!numpy.any_dtype> {py_name = "b"}) -> tensor<*x!numpy.any_dtype> attributes {py_ftype = "generic_trace", py_name = "simple_mul"} { + def _validate(self): + if not all(arg.type_class == TypeClass.NdArray + for arg in self.epf.sig.args): + raise NotImplementedError("Non NdArray args: %r" % (self.epf.sig.args,)) + if not self.epf.sig.result.type_class == TypeClass.NdArray: + raise NotImplementedError("Non NdArray result: %r" % ( + self.epf.sig.result,)) - * None types must be annotated: - >>> m = edsc.MLIRModule() - >>> def simple_mul(a: np.ndarray, b: np.ndarray) -> None: - ... return a * b - >>> gft = GenericFunctionTrace.from_typed_pyfunc(m, simple_mul) - >>> ir = gft.mlir_module.get_ir() - >>> print(re.findall("func @simple_mul.+", ir)[0]) - func @simple_mul$$generic(%arg0: tensor<*x!numpy.any_dtype> {py_name = "a"}, %arg1: tensor<*x!numpy.any_dtype> {py_name = "b"}) attributes {py_ftype = "generic_trace", py_name = "simple_mul"} { + def _create_mlir_function(self): + mlir_c = self._mlir_c + mlir_m = self._mlir_m + ops = self._ops + types = self._types + epf = self.epf + f_args = [mlir_c.parse_type(ap.mlir_tensor_type_asm) + for ap in self._args_array_params] + f_types = [mlir_c.parse_type( + self._result_array_params.mlir_tensor_type_asm)] + ops.builder.insert_before_terminator(mlir_m.first_block) + f_type = types.function(f_args, f_types) + f = ops.func_op(epf.__name__, f_type, create_entry_block=True) + return f, f_types - Args: - mlir_m: An MLIRModule. - pyfunc: A python function to transform. - Returns: - A new GenericFunctionTrace. - """ - if name_in_module is None: - name_in_module = pyfunc.__name__ + "$$generic" - code = pyfunc.__code__ - # Process arguments. - f_args = [] - for i in range(code.co_argcount): - arg_name = code.co_varnames[i] - arg_annot = pyfunc.__annotations__.get(arg_name) - if arg_annot is None: - raise ValueError("Function %s arg %d is missing a typing annotation" % ( - pyfunc.__name__, i)) - arg_type = _map_typing_to_mlir_type(mlir_m, arg_annot) - if arg_type is None: - raise ValueError("Function %s arg %d is not a supported type" % ( - pyfunc.__name__, i)) - arg_type = arg_type({ - "py_name": mlir_m.stringAttr(arg_name), - }) - f_args.append(arg_type) + def _create_trace_roots(self): + entry_block = self._f.first_block + for index, ap in enumerate(self._args_array_params): + if ap is not None: + ta = TracedArray(self) + self.set_traced_array(ta, entry_block.args[index]) + self._python_args[index] = ta - # Process results. - f_results = [] - if "return" not in pyfunc.__annotations__: - raise ValueError("Un-annotated function returns not yet supported") - return_annot = pyfunc.__annotations__["return"] - if return_annot is not None: - return_type = _map_typing_to_mlir_type(mlir_m, return_annot) - if return_type is None: - raise ValueError("Function %s return type %r is not supported" % ( - pyfunc.__name__, return_annot)) - f_results.append(return_type) + def _handle_ufunc(self, ufunc, method, inputs, kwargs): + if method == "__call__": + if kwargs: + raise TracingError("Generic ufunc with kwargs not supported %r" % ( + ufunc,)) + + # Map inputs to TracedArrays. + # TODO: Process captures, promotions, etc. + op_inputs = [] + for py_input in inputs: + if not isinstance(py_input, TracedArray): + raise TracingError("Unsupported ufunc input: %r", (py_input,)) + op_input = self.get_traced_array_value(py_input) + if op_input is None: + raise TracingError("Unregistered traced array: %r", (py_input,)) + op_inputs.append(op_input) + + # Emit op. + types = self._types + mlir_m = self._mlir_m + callee_symbol = _UFUNC_SYMBOL_MAP.get(ufunc) + if not callee_symbol: + raise TracingError("Unsupported ufunc: %r" % ufunc) + op_result_type = types.tensor(types.numpy_any_dtype) + call_op = self._ops.numpy_ufunc_call_op( + callee_symbol, op_result_type, *op_inputs) + op_result = call_op.results[0] + + # Wrap returns. + return_array = TracedArray(self) + self.set_traced_array(return_array, op_result) + return return_array - mlir_f = mlir_m.make_function( - name_in_module, f_args, f_results, - py_ftype=mlir_m.stringAttr("generic_trace"), - py_name=mlir_m.stringAttr(pyfunc.__name__)) - return GenericFunctionTrace(mlir_m, mlir_f) + # Unsupported method. + raise TracingError("Unsupported ufunc method %r:%r" % (ufunc, method,)) + + +# TODO: There should be an open registry of ufuncs. But for now, just map +# introspect the numpy package and record them. +def _build_ufunc_symbol_map(): + d = {} + for member in dir(np): + ufunc = getattr(np, member) + if isinstance(ufunc, np.ufunc): + d[ufunc] = "numpy." + member + return d + +_UFUNC_SYMBOL_MAP = _build_ufunc_symbol_map() if __name__ == "__main__": diff --git a/python/npcomp/tracing/mlir_trace_test.py b/python/npcomp/tracing/mlir_trace_test.py new file mode 100644 index 000000000..6ad348a3f --- /dev/null +++ b/python/npcomp/tracing/mlir_trace_test.py @@ -0,0 +1,38 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ..types import * +from ..exporter import * +from .mlir_trace import * +from ..utils import test_utils + +test_utils.start_filecheck_test() + +def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray: + return a * b + a + b + +# TODO: Implement subclassing and deriving constraints by run +exp = Exporter() +exp.simple_mul = simple_mul +exp.simple_mul.sig.args["a"] += Shape(1, 4) +exp.simple_mul.sig.args["a"] += DynamicDim(0) +exp.simple_mul.sig.args["a"] += DType(np.float32) +exp.simple_mul.sig.args["b"] += Shape(1) +exp.simple_mul.sig.args["b"] += DType(np.float32) +exp.simple_mul.sig.result += Shape(1, 4) +exp.simple_mul.sig.result += DynamicDim(0) +exp.simple_mul.sig.result += DType(np.float32) + +mb = ModuleBuilder() +mb.trace(exp.simple_mul) +# CHECK: func @simple_mul(%arg0: tensor, %arg1: tensor<1xf32>) -> tensor { +# CHECK: %0 = numpy.ufunc_call @numpy.multiply(%arg0, %arg1) : (tensor, tensor<1xf32>) -> tensor<*x!numpy.any_dtype> +# CHECK: %1 = numpy.ufunc_call @numpy.add(%0, %arg0) : (tensor<*x!numpy.any_dtype>, tensor) -> tensor<*x!numpy.any_dtype> +# CHECK: %2 = numpy.ufunc_call @numpy.add(%1, %arg1) : (tensor<*x!numpy.any_dtype>, tensor<1xf32>) -> tensor<*x!numpy.any_dtype> +# CHECK: %3 = numpy.narrow %2 : (tensor<*x!numpy.any_dtype>) -> tensor +# CHECK: return %3 : tensor +# CHECK: } +print(mb.module.to_asm()) + +test_utils.end_filecheck_test(__file__) diff --git a/python/run_tests.py b/python/run_tests.py index 50d8affe6..fbabb7588 100755 --- a/python/run_tests.py +++ b/python/run_tests.py @@ -13,7 +13,7 @@ TEST_MODULES = ( "npcomp.tracing.mlir_trace", "npcomp.types", "npcomp.exporter", - "npcomp.exp.extractor", + "npcomp.tracing.mlir_trace_test", ) # Compute PYTHONPATH for sub processes.