mirror of https://github.com/llvm/torch-mlir
Add hook for __array_function__ and (failing) np.dot sample.
parent
a38a1e2850
commit
1f54838d2e
|
@ -5,6 +5,7 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -52,13 +53,18 @@ class TraceContext:
|
||||||
"active",
|
"active",
|
||||||
]
|
]
|
||||||
def __init__(self, desc=None):
|
def __init__(self, desc=None):
|
||||||
|
_check_numpy_version()
|
||||||
self._desc = desc
|
self._desc = desc
|
||||||
self._next_id = 1
|
self._next_id = 1
|
||||||
self.active = False
|
self.active = False
|
||||||
|
|
||||||
def _handle_ufunc(self, ufunc, method, inputs, kwargs):
|
def _handle_ufunc(self, ufunc, method, inputs, kwargs):
|
||||||
"""Handles a ufunc invocation involving at least one TracedArray."""
|
"""Handles a ufunc invocation involving at least one TracedArray."""
|
||||||
raise NotImplementedError()
|
return NotImplemented
|
||||||
|
|
||||||
|
def _handle_array_func(self, func, types, inputs, kwargs):
|
||||||
|
"""Handles an __array_func__ hook involving at least on TracedArray."""
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
def get_next_id(self):
|
def get_next_id(self):
|
||||||
"""Gets the next unique id for the context."""
|
"""Gets the next unique id for the context."""
|
||||||
|
@ -143,6 +149,22 @@ class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
|
||||||
_assert_active(tc)
|
_assert_active(tc)
|
||||||
return tc._handle_ufunc(ufunc, method, inputs, kwargs)
|
return tc._handle_ufunc(ufunc, method, inputs, kwargs)
|
||||||
|
|
||||||
|
def __array_function__(self, func, types, args, kwargs):
|
||||||
|
tc = self._tc
|
||||||
|
_assert_active(tc)
|
||||||
|
return tc._handle_array_func(func, types, args, kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_numpy_version():
|
||||||
|
version = np.lib.NumpyVersion(np.__version__)
|
||||||
|
if version < "1.16.0":
|
||||||
|
raise RuntimeError("Numpy version >= 1.16 is required")
|
||||||
|
if version > "1.17.0": return
|
||||||
|
if os.environ.get("NUMPY_EXPERIMENTAL_ARRAY_FUNCTION") != "1":
|
||||||
|
raise RuntimeError(
|
||||||
|
"For numpy 1.16, the environment variable "
|
||||||
|
"NUMPY_EXPERIMENTAL_ARRAY_FUNCTION must equal 1")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import doctest
|
import doctest
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import npcomp as npc
|
||||||
|
from npcomp.types import *
|
||||||
|
|
||||||
|
def dot2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||||
|
return np.dot(a, b)
|
||||||
|
|
||||||
|
# TODO: Implement subclassing and deriving constraints by run
|
||||||
|
exp = npc.Exporter()
|
||||||
|
exp.dot2d = dot2d
|
||||||
|
exp.dot2d.sig.args["a"] += Shape(4, 16)
|
||||||
|
exp.dot2d.sig.args["a"] += DynamicDim(0)
|
||||||
|
exp.dot2d.sig.args["a"] += DType(np.float32)
|
||||||
|
exp.dot2d.sig.args["b"] += Shape(16,32)
|
||||||
|
exp.dot2d.sig.args["b"] += DType(np.float32)
|
||||||
|
exp.dot2d.sig.result += Shape(4, 32)
|
||||||
|
exp.dot2d.sig.result += DynamicDim(0)
|
||||||
|
exp.dot2d.sig.result += DType(np.float32)
|
||||||
|
|
||||||
|
mb = npc.tracing.ModuleBuilder()
|
||||||
|
mb.trace(exp.dot2d)
|
||||||
|
print(mb.module.to_asm())
|
Loading…
Reference in New Issue