torch-mlir/python/npcomp/tracing/mlir_trace.py

118 lines
4.0 KiB
Python

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import re
import numpy as np
from . import context
from ..native.mlir import edsc
def _map_typing_to_mlir_type(mlir_m, typing_annot):
"""Maps a typing annotation to an MLIR type.
Args:
mlir_m: MLIRModule.
typing_annot: Value for an __annotations__ entry.
Returns:
MLIR type or None if not mappable.
"""
if typing_annot is np.ndarray:
return mlir_m.make_type("tensor<*x!numpy.any_dtype>")
return None
class GenericFunctionTrace:
"""Represents a trace of a 'generic' python function in progress."""
def __init__(self, mlir_m, mlir_f):
self._mlir_m = mlir_m
self._mlir_f = mlir_f
@property
def mlir_module(self):
return self._mlir_m
@property
def mlir_function(self):
return self._mlir_f
@classmethod
def from_typed_pyfunc(cls, mlir_m, pyfunc, name_in_module=None):
"""Creates a generic function trace from a pyfunc with type annotations.
This is a relatively limited mechanism which relies on typing annotations
for arguments and results and supports a relatively limited amount of
variation.
Examples:
* Generic ndarrays:
>>> m = edsc.MLIRModule()
>>> def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
... return a * b
>>> gft = GenericFunctionTrace.from_typed_pyfunc(m, simple_mul)
>>> ir = gft.mlir_module.get_ir()
>>> print(re.findall("func @simple_mul.+", ir)[0])
func @simple_mul$$generic(%arg0: tensor<*x!numpy.any_dtype> {py_name = "a"}, %arg1: tensor<*x!numpy.any_dtype> {py_name = "b"}) -> tensor<*x!numpy.any_dtype> attributes {py_ftype = "generic_trace", py_name = "simple_mul"} {
* None types must be annotated:
>>> m = edsc.MLIRModule()
>>> def simple_mul(a: np.ndarray, b: np.ndarray) -> None:
... return a * b
>>> gft = GenericFunctionTrace.from_typed_pyfunc(m, simple_mul)
>>> ir = gft.mlir_module.get_ir()
>>> print(re.findall("func @simple_mul.+", ir)[0])
func @simple_mul$$generic(%arg0: tensor<*x!numpy.any_dtype> {py_name = "a"}, %arg1: tensor<*x!numpy.any_dtype> {py_name = "b"}) attributes {py_ftype = "generic_trace", py_name = "simple_mul"} {
Args:
mlir_m: An MLIRModule.
pyfunc: A python function to transform.
Returns:
A new GenericFunctionTrace.
"""
if name_in_module is None:
name_in_module = pyfunc.__name__ + "$$generic"
code = pyfunc.__code__
# Process arguments.
f_args = []
for i in range(code.co_argcount):
arg_name = code.co_varnames[i]
arg_annot = pyfunc.__annotations__.get(arg_name)
if arg_annot is None:
raise ValueError("Function %s arg %d is missing a typing annotation" % (
pyfunc.__name__, i))
arg_type = _map_typing_to_mlir_type(mlir_m, arg_annot)
if arg_type is None:
raise ValueError("Function %s arg %d is not a supported type" % (
pyfunc.__name__, i))
arg_type = arg_type({
"py_name": mlir_m.stringAttr(arg_name),
})
f_args.append(arg_type)
# Process results.
f_results = []
if "return" not in pyfunc.__annotations__:
raise ValueError("Un-annotated function returns not yet supported")
return_annot = pyfunc.__annotations__["return"]
if return_annot is not None:
return_type = _map_typing_to_mlir_type(mlir_m, return_annot)
if return_type is None:
raise ValueError("Function %s return type %r is not supported" % (
pyfunc.__name__, return_annot))
f_results.append(return_type)
mlir_f = mlir_m.make_function(
name_in_module, f_args, f_results,
py_ftype=mlir_m.stringAttr("generic_trace"),
py_name=mlir_m.stringAttr(pyfunc.__name__))
return GenericFunctionTrace(mlir_m, mlir_f)
if __name__ == "__main__":
import doctest
doctest.testmod()