mirror of https://github.com/llvm/torch-mlir
118 lines
4.0 KiB
Python
118 lines
4.0 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
import re
|
|
|
|
import numpy as np
|
|
|
|
from . import context
|
|
from ..native.mlir import edsc
|
|
|
|
|
|
def _map_typing_to_mlir_type(mlir_m, typing_annot):
|
|
"""Maps a typing annotation to an MLIR type.
|
|
|
|
Args:
|
|
mlir_m: MLIRModule.
|
|
typing_annot: Value for an __annotations__ entry.
|
|
Returns:
|
|
MLIR type or None if not mappable.
|
|
"""
|
|
if typing_annot is np.ndarray:
|
|
return mlir_m.make_type("tensor<*x!numpy.any_dtype>")
|
|
return None
|
|
|
|
|
|
class GenericFunctionTrace:
    """Represents a trace of a 'generic' python function in progress.

    Holds the MLIR module being emitted into and the MLIR function within it
    that corresponds to the traced python function.
    """

    def __init__(self, mlir_m, mlir_f):
        """Initializes the trace.

        Args:
          mlir_m: The MLIRModule that owns the traced function.
          mlir_f: The MLIR function created for the trace (as produced by
            `mlir_m.make_function` in `from_typed_pyfunc`).
        """
        self._mlir_m = mlir_m
        self._mlir_f = mlir_f

    @property
    def mlir_module(self):
        """The MLIRModule this trace emits into."""
        return self._mlir_m

    @property
    def mlir_function(self):
        """The MLIR function corresponding to the traced python function."""
        return self._mlir_f

    @classmethod
    def from_typed_pyfunc(cls, mlir_m, pyfunc, name_in_module=None):
        """Creates a generic function trace from a pyfunc with type annotations.

        This is a relatively limited mechanism which relies on typing annotations
        for arguments and results and supports a relatively limited amount of
        variation.

        Examples:

          * Generic ndarrays:
          >>> m = edsc.MLIRModule()
          >>> def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
          ...   return a * b
          >>> gft = GenericFunctionTrace.from_typed_pyfunc(m, simple_mul)
          >>> ir = gft.mlir_module.get_ir()
          >>> print(re.findall("func @simple_mul.+", ir)[0])
          func @simple_mul$$generic(%arg0: tensor<*x!numpy.any_dtype> {py_name = "a"}, %arg1: tensor<*x!numpy.any_dtype> {py_name = "b"}) -> tensor<*x!numpy.any_dtype> attributes {py_ftype = "generic_trace", py_name = "simple_mul"} {

          * None types must be annotated:
          >>> m = edsc.MLIRModule()
          >>> def simple_mul(a: np.ndarray, b: np.ndarray) -> None:
          ...   return a * b
          >>> gft = GenericFunctionTrace.from_typed_pyfunc(m, simple_mul)
          >>> ir = gft.mlir_module.get_ir()
          >>> print(re.findall("func @simple_mul.+", ir)[0])
          func @simple_mul$$generic(%arg0: tensor<*x!numpy.any_dtype> {py_name = "a"}, %arg1: tensor<*x!numpy.any_dtype> {py_name = "b"}) attributes {py_ftype = "generic_trace", py_name = "simple_mul"} {

        Args:
          mlir_m: An MLIRModule.
          pyfunc: A python function to transform.
          name_in_module: Symbol name for the created MLIR function. Defaults
            to `pyfunc.__name__ + "$$generic"`.

        Returns:
          A new instance of `cls` (GenericFunctionTrace or a subclass).

        Raises:
          ValueError: If a positional argument is unannotated or has an
            unsupported annotation, or if the return annotation is missing or
            unsupported.
        """
        if name_in_module is None:
            name_in_module = pyfunc.__name__ + "$$generic"
        annotations = pyfunc.__annotations__
        # Arguments are validated before results, so argument errors win.
        f_args = cls._map_arg_types(mlir_m, pyfunc, annotations)
        f_results = cls._map_result_types(mlir_m, pyfunc, annotations)

        mlir_f = mlir_m.make_function(
            name_in_module, f_args, f_results,
            py_ftype=mlir_m.stringAttr("generic_trace"),
            py_name=mlir_m.stringAttr(pyfunc.__name__))
        # Construct via `cls` so subclasses get instances of themselves.
        return cls(mlir_m, mlir_f)

    @staticmethod
    def _map_arg_types(mlir_m, pyfunc, annotations):
        """Maps each positional parameter of `pyfunc` to an MLIR argument type."""
        code = pyfunc.__code__
        f_args = []
        # NOTE: only the first co_argcount names in co_varnames (the positional
        # parameters) are traced; keyword-only parameters and *args/**kwargs
        # are not included in the generated signature.
        for i, arg_name in enumerate(code.co_varnames[:code.co_argcount]):
            # Test membership rather than `.get(...) is None` so that an
            # explicit `None` annotation is reported as unsupported, not missing.
            if arg_name not in annotations:
                raise ValueError("Function %s arg %d is missing a typing annotation" % (
                    pyfunc.__name__, i))
            arg_type = _map_typing_to_mlir_type(mlir_m, annotations[arg_name])
            if arg_type is None:
                raise ValueError("Function %s arg %d is not a supported type" % (
                    pyfunc.__name__, i))
            # Calling the type with an attribute dict attaches the python
            # parameter name to the MLIR argument as `py_name`.
            f_args.append(arg_type({
                "py_name": mlir_m.stringAttr(arg_name),
            }))
        return f_args

    @staticmethod
    def _map_result_types(mlir_m, pyfunc, annotations):
        """Maps the return annotation of `pyfunc` to a list of MLIR result types."""
        if "return" not in annotations:
            raise ValueError("Un-annotated function returns not yet supported")
        return_annot = annotations["return"]
        # An explicit `-> None` annotation means the function has no results.
        if return_annot is None:
            return []
        return_type = _map_typing_to_mlir_type(mlir_m, return_annot)
        if return_type is None:
            raise ValueError("Function %s return type %r is not supported" % (
                pyfunc.__name__, return_annot))
        return [return_type]
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the examples embedded in this module's docstrings as doctests.
    from doctest import testmod
    testmod()
|