diff --git a/python/npcomp/__init__.py b/python/npcomp/__init__.py index e69de29bb..f5a28b832 100644 --- a/python/npcomp/__init__.py +++ b/python/npcomp/__init__.py @@ -0,0 +1,6 @@ +# Top-level symbols. +from .exporter import * +from .types import * + +from . import tracing +from . import utils diff --git a/python/npcomp/tracing/__init__.py b/python/npcomp/tracing/__init__.py index e69de29bb..b0bc4d378 100644 --- a/python/npcomp/tracing/__init__.py +++ b/python/npcomp/tracing/__init__.py @@ -0,0 +1,3 @@ +# Module level symbols. +from .context import * +from .mlir_trace import * diff --git a/python/npcomp/utils/__init__.py b/python/npcomp/utils/__init__.py new file mode 100644 index 000000000..8ba4ae08d --- /dev/null +++ b/python/npcomp/utils/__init__.py @@ -0,0 +1 @@ +from . import test_utils as test diff --git a/python/npcomp/utils/test_utils.py b/python/npcomp/utils/test_utils.py index 7427c3585..19d43f7cf 100644 --- a/python/npcomp/utils/test_utils.py +++ b/python/npcomp/utils/test_utils.py @@ -10,6 +10,8 @@ import sys _disable_var = "NPCOMP_DISABLE_FILECHECK" _filecheck_binary_var = "FILECHECK_BINARY" +_redirect_io = None +_redirect_context = None def is_filecheck_disabled(): return _disable_var in os.environ diff --git a/python/samples/simple_ufunc.py b/python/samples/simple_ufunc.py new file mode 100644 index 000000000..8c3db3347 --- /dev/null +++ b/python/samples/simple_ufunc.py @@ -0,0 +1,33 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import numpy as np +import npcomp as npc +from npcomp.types import * + +def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray: + return a * b + a + b + +# TODO: Implement subclassing and deriving constraints by run +exp = npc.Exporter() +exp.simple_mul = simple_mul +exp.simple_mul.sig.args["a"] += Shape(1, 4) +exp.simple_mul.sig.args["a"] += DynamicDim(0) +exp.simple_mul.sig.args["a"] += DType(np.float32) +exp.simple_mul.sig.args["b"] += Shape(1) +exp.simple_mul.sig.args["b"] += DType(np.float32) +exp.simple_mul.sig.result += Shape(1, 4) +exp.simple_mul.sig.result += DynamicDim(0) +exp.simple_mul.sig.result += DType(np.float32) + +mb = npc.tracing.ModuleBuilder() +mb.trace(exp.simple_mul) +# CHECK: func @simple_mul(%arg0: tensor, %arg1: tensor<1xf32>) -> tensor { +# CHECK: %0 = numpy.ufunc_call @numpy.multiply(%arg0, %arg1) : (tensor, tensor<1xf32>) -> tensor<*x!numpy.any_dtype> +# CHECK: %1 = numpy.ufunc_call @numpy.add(%0, %arg0) : (tensor<*x!numpy.any_dtype>, tensor) -> tensor<*x!numpy.any_dtype> +# CHECK: %2 = numpy.ufunc_call @numpy.add(%1, %arg1) : (tensor<*x!numpy.any_dtype>, tensor<1xf32>) -> tensor<*x!numpy.any_dtype> +# CHECK: %3 = numpy.narrow %2 : (tensor<*x!numpy.any_dtype>) -> tensor +# CHECK: return %3 : tensor +# CHECK: } +print(mb.module.to_asm())