diff --git a/include/npcomp/Dialect/Numpy/NumpyOps.td b/include/npcomp/Dialect/Numpy/NumpyOps.td index 5bea1de0c..8768366ce 100644 --- a/include/npcomp/Dialect/Numpy/NumpyOps.td +++ b/include/npcomp/Dialect/Numpy/NumpyOps.td @@ -115,7 +115,7 @@ def Numpy_UfuncCallOp : Numpy_Op<"ufunc_call", []> { def Numpy_DotOp : Numpy_Op<"dot", []> { let summary = "Represents the `numpy.dot` operator"; let description = [{ - See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html + See: https://docs.scipy.org/doc/numpy/reference/generated/numpy.dot.html }]; let arguments = (ins Numpy_AnyArray:$a, @@ -129,4 +129,25 @@ def Numpy_DotOp : Numpy_Op<"dot", []> { }]; } +def Numpy_TransposeOp : Numpy_Op<"transpose", []> { + let summary = "Represents the `numpy.transpose` op with no permutation specified"; + let description = [{ + This op is equivalent to calling `numpy.transpose(arr)`, which reverses + the axes of the array. It is separate from the explicit form because it + is not always possible to locallly infer an appropriate axis transform + at the point of declaration. + + See: https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html + }]; + let arguments = (ins + Numpy_AnyArray:$a + ); + let results = (outs + Numpy_AnyArray:$output + ); + let assemblyFormat = [{ + operands attr-dict `:` functional-type(operands, $output) + }]; +} + #endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS diff --git a/python/npcomp/tracing/context.py b/python/npcomp/tracing/context.py index 441637839..4cdf3a967 100644 --- a/python/npcomp/tracing/context.py +++ b/python/npcomp/tracing/context.py @@ -154,6 +154,13 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin): _assert_active(tc) return tc._handle_array_func(func, types, args, kwargs) + @property + def T(self): + """Shortcut for transpose.""" + tc = self._tc + _assert_active(tc) + return tc._handle_array_func(np.transpose, [TracedArray], [self], {}) + def _check_numpy_version(): version = np.lib.NumpyVersion(np.__version__) diff --git a/python/npcomp/tracing/emitters.py b/python/npcomp/tracing/emitters.py index 627fed7eb..9545910fc 100644 --- a/python/npcomp/tracing/emitters.py +++ b/python/npcomp/tracing/emitters.py @@ -259,9 +259,17 @@ class EmitterRegistry: self.register_ufunc(ufunc, "__call__", GenericCallUfuncEmitter("numpy." + member)) # Register generic 1-result array funcs. - for f, op_name in ((np.inner, "numpy.inner"), (np.outer, "numpy.outer"), - (np.dot, "numpy.dot"), (np.vdot, "numpy.vdot"), - (np.linalg.det, "numpy.linalg.det")): + GENERIC_FUNCS = ( + (np.inner, "numpy.inner"), + (np.outer, "numpy.outer"), + (np.dot, "numpy.dot"), + (np.vdot, "numpy.vdot"), + (np.linalg.det, "numpy.linalg.det"), + # TODO: This needs a custom implementation to differentiate when + # axes is specified (this version will fail). + (np.transpose, "numpy.transpose"), + ) + for f, op_name in GENERIC_FUNCS: self.register_array_func(f, GenericArrayFuncEmitter(op_name)) diff --git a/python/npcomp/tracing/mlir_trace.py b/python/npcomp/tracing/mlir_trace.py index 8b18d5915..ea6410276 100644 --- a/python/npcomp/tracing/mlir_trace.py +++ b/python/npcomp/tracing/mlir_trace.py @@ -28,13 +28,14 @@ class ModuleBuilder: self.emitters = (emitter_registry if emitter_registry else EmitterRegistry.create_default()) - def trace(self, export_py_func: ExportPyFunction): - """Traces and exported python function.""" - assert isinstance(export_py_func, ExportPyFunction), ( - "Expected an exported python function (from the Exporter class)") - tracer = FunctionTracer(self, export_py_func) - with tracer: - tracer.trace() + def trace(self, *export_py_funcs: ExportPyFunction): + """Traces exported py functions.""" + for export_py_func in export_py_funcs: + assert isinstance(export_py_func, ExportPyFunction), ( + "Expected an exported python function (from the Exporter class)") + tracer = FunctionTracer(self, export_py_func) + with tracer: + tracer.trace() class FunctionTracer(TraceContext): diff --git a/python/samples/transpose.py b/python/samples/transpose.py new file mode 100644 index 000000000..e17dda5c5 --- /dev/null +++ b/python/samples/transpose.py @@ -0,0 +1,22 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import numpy as np +import npcomp as npc +from npcomp.types import * + +def transpose_attribute(a: np.ndarray) -> np.ndarray: + return a.T + +def transpose(a: np.ndarray) -> np.ndarray: + return np.transpose(a) + +# TODO: Implement subclassing and deriving constraints by run +exp = npc.Exporter() +exp.transpose_attribute = transpose_attribute +exp.transpose = transpose + +mb = npc.tracing.ModuleBuilder() +mb.trace(exp.transpose_attribute, exp.transpose) +print(mb.module.to_asm())