mirror of https://github.com/llvm/torch-mlir
Handle np.transpose() and ndarray.T shortcut.
* Just the form without explicit permutation for now.pull/1/head
parent
a5f755d406
commit
ebb5bcf6af
|
@ -115,7 +115,7 @@ def Numpy_UfuncCallOp : Numpy_Op<"ufunc_call", []> {
|
|||
def Numpy_DotOp : Numpy_Op<"dot", []> {
|
||||
let summary = "Represents the `numpy.dot` operator";
|
||||
let description = [{
|
||||
See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
|
||||
See: https://docs.scipy.org/doc/numpy/reference/generated/numpy.dot.html
|
||||
}];
|
||||
let arguments = (ins
|
||||
Numpy_AnyArray:$a,
|
||||
|
@ -129,4 +129,25 @@ def Numpy_DotOp : Numpy_Op<"dot", []> {
|
|||
}];
|
||||
}
|
||||
|
||||
def Numpy_TransposeOp : Numpy_Op<"transpose", []> {
|
||||
let summary = "Represents the `numpy.transpose` op with no permutation specified";
|
||||
let description = [{
|
||||
This op is equivalent to calling `numpy.transpose(arr)`, which reverses
|
||||
the axes of the array. It is separate from the explicit form because it
|
||||
is not always possible to locallly infer an appropriate axis transform
|
||||
at the point of declaration.
|
||||
|
||||
See: https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html
|
||||
}];
|
||||
let arguments = (ins
|
||||
Numpy_AnyArray:$a
|
||||
);
|
||||
let results = (outs
|
||||
Numpy_AnyArray:$output
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
operands attr-dict `:` functional-type(operands, $output)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS
|
||||
|
|
|
@ -154,6 +154,13 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
|
|||
_assert_active(tc)
|
||||
return tc._handle_array_func(func, types, args, kwargs)
|
||||
|
||||
@property
|
||||
def T(self):
|
||||
"""Shortcut for transpose."""
|
||||
tc = self._tc
|
||||
_assert_active(tc)
|
||||
return tc._handle_array_func(np.transpose, [TracedArray], [self], {})
|
||||
|
||||
|
||||
def _check_numpy_version():
|
||||
version = np.lib.NumpyVersion(np.__version__)
|
||||
|
|
|
@ -259,9 +259,17 @@ class EmitterRegistry:
|
|||
self.register_ufunc(ufunc, "__call__",
|
||||
GenericCallUfuncEmitter("numpy." + member))
|
||||
# Register generic 1-result array funcs.
|
||||
for f, op_name in ((np.inner, "numpy.inner"), (np.outer, "numpy.outer"),
|
||||
(np.dot, "numpy.dot"), (np.vdot, "numpy.vdot"),
|
||||
(np.linalg.det, "numpy.linalg.det")):
|
||||
GENERIC_FUNCS = (
|
||||
(np.inner, "numpy.inner"),
|
||||
(np.outer, "numpy.outer"),
|
||||
(np.dot, "numpy.dot"),
|
||||
(np.vdot, "numpy.vdot"),
|
||||
(np.linalg.det, "numpy.linalg.det"),
|
||||
# TODO: This needs a custom implementation to differentiate when
|
||||
# axes is specified (this version will fail).
|
||||
(np.transpose, "numpy.transpose"),
|
||||
)
|
||||
for f, op_name in GENERIC_FUNCS:
|
||||
self.register_array_func(f, GenericArrayFuncEmitter(op_name))
|
||||
|
||||
|
||||
|
|
|
@ -28,8 +28,9 @@ class ModuleBuilder:
|
|||
self.emitters = (emitter_registry
|
||||
if emitter_registry else EmitterRegistry.create_default())
|
||||
|
||||
def trace(self, export_py_func: ExportPyFunction):
|
||||
"""Traces and exported python function."""
|
||||
def trace(self, *export_py_funcs: ExportPyFunction):
|
||||
"""Traces exported py functions."""
|
||||
for export_py_func in export_py_funcs:
|
||||
assert isinstance(export_py_func, ExportPyFunction), (
|
||||
"Expected an exported python function (from the Exporter class)")
|
||||
tracer = FunctionTracer(self, export_py_func)
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import numpy as np
|
||||
import npcomp as npc
|
||||
from npcomp.types import *
|
||||
|
||||
def transpose_attribute(a: np.ndarray) -> np.ndarray:
|
||||
return a.T
|
||||
|
||||
def transpose(a: np.ndarray) -> np.ndarray:
|
||||
return np.transpose(a)
|
||||
|
||||
# TODO: Implement subclassing and deriving constraints by run
|
||||
exp = npc.Exporter()
|
||||
exp.transpose_attribute = transpose_attribute
|
||||
exp.transpose = transpose
|
||||
|
||||
mb = npc.tracing.ModuleBuilder()
|
||||
mb.trace(exp.transpose_attribute, exp.transpose)
|
||||
print(mb.module.to_asm())
|
Loading…
Reference in New Issue