Handle np.transpose() and ndarray.T shortcut.

* Just the form without explicit permutation for now.
pull/1/head
Stella Laurenzo 2020-05-04 16:20:36 -07:00
parent a5f755d406
commit ebb5bcf6af
5 changed files with 70 additions and 11 deletions

View File

@ -115,7 +115,7 @@ def Numpy_UfuncCallOp : Numpy_Op<"ufunc_call", []> {
def Numpy_DotOp : Numpy_Op<"dot", []> {
let summary = "Represents the `numpy.dot` operator";
let description = [{
See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
See: https://docs.scipy.org/doc/numpy/reference/generated/numpy.dot.html
}];
let arguments = (ins
Numpy_AnyArray:$a,
@ -129,4 +129,25 @@ def Numpy_DotOp : Numpy_Op<"dot", []> {
}];
}
def Numpy_TransposeOp : Numpy_Op<"transpose", []> {
let summary = "Represents the `numpy.transpose` op with no permutation specified";
let description = [{
This op is equivalent to calling `numpy.transpose(arr)`, which reverses
the axes of the array. It is separate from the explicit form because it
is not always possible to locallly infer an appropriate axis transform
at the point of declaration.
See: https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html
}];
let arguments = (ins
Numpy_AnyArray:$a
);
let results = (outs
Numpy_AnyArray:$output
);
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, $output)
}];
}
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS

View File

@ -154,6 +154,13 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
_assert_active(tc)
return tc._handle_array_func(func, types, args, kwargs)
@property
def T(self):
"""Shortcut for transpose."""
tc = self._tc
_assert_active(tc)
return tc._handle_array_func(np.transpose, [TracedArray], [self], {})
def _check_numpy_version():
version = np.lib.NumpyVersion(np.__version__)

View File

@ -259,9 +259,17 @@ class EmitterRegistry:
self.register_ufunc(ufunc, "__call__",
GenericCallUfuncEmitter("numpy." + member))
# Register generic 1-result array funcs.
for f, op_name in ((np.inner, "numpy.inner"), (np.outer, "numpy.outer"),
(np.dot, "numpy.dot"), (np.vdot, "numpy.vdot"),
(np.linalg.det, "numpy.linalg.det")):
GENERIC_FUNCS = (
(np.inner, "numpy.inner"),
(np.outer, "numpy.outer"),
(np.dot, "numpy.dot"),
(np.vdot, "numpy.vdot"),
(np.linalg.det, "numpy.linalg.det"),
# TODO: This needs a custom implementation to differentiate when
# axes is specified (this version will fail).
(np.transpose, "numpy.transpose"),
)
for f, op_name in GENERIC_FUNCS:
self.register_array_func(f, GenericArrayFuncEmitter(op_name))

View File

@ -28,13 +28,14 @@ class ModuleBuilder:
self.emitters = (emitter_registry
if emitter_registry else EmitterRegistry.create_default())
def trace(self, export_py_func: ExportPyFunction):
"""Traces and exported python function."""
assert isinstance(export_py_func, ExportPyFunction), (
"Expected an exported python function (from the Exporter class)")
tracer = FunctionTracer(self, export_py_func)
with tracer:
tracer.trace()
def trace(self, *export_py_funcs: ExportPyFunction):
"""Traces exported py functions."""
for export_py_func in export_py_funcs:
assert isinstance(export_py_func, ExportPyFunction), (
"Expected an exported python function (from the Exporter class)")
tracer = FunctionTracer(self, export_py_func)
with tracer:
tracer.trace()
class FunctionTracer(TraceContext):

View File

@ -0,0 +1,22 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import numpy as np
import npcomp as npc
from npcomp.types import *
def transpose_attribute(a: np.ndarray) -> np.ndarray:
return a.T
def transpose(a: np.ndarray) -> np.ndarray:
return np.transpose(a)
# TODO: Implement subclassing and deriving constraints by run
exp = npc.Exporter()
exp.transpose_attribute = transpose_attribute
exp.transpose = transpose
mb = npc.tracing.ModuleBuilder()
mb.trace(exp.transpose_attribute, exp.transpose)
print(mb.module.to_asm())