# Mirror of https://github.com/llvm/torch-mlir (71 lines, 2.9 KiB, Python)
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
from typing import List, Optional, Tuple, NamedTuple
|
|
|
|
import torch
|
|
|
|
# Decorators

# Currently, these decorators are very low-level and map 1:1 with
# methods on `torch_mlir.ClassAnnotator`. Eventually, we expect there to
# be a more elaborate Python layer which allows all the different annotations
# to be expressed conveniently and gives clearer error reports when
# the annotations aren't acceptable.

# This module is kept separate from torch_mlir.torchscript_annotations so that
# we can use this from code without C++ dependencies, which prevent us from
# interfacing the test framework across environments.
|
|
|
|
# Attribute names used for annotations.
# These should be kept in sync with their use in
# `torch_mlir/torchscript_annotations.py`.
|
|
# Name of the attribute that `export` sets to True on a decorated function.
TORCH_MLIR_EXPORT_ATTR_NAME = '_torch_mlir_export'
# Name of the attribute under which `annotate_args` stores its annotations.
TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME = '_torch_mlir_arg_annotations'
|
|
|
|
|
|
def export(fn):
    """Decorator that tells the torch-mlir compiler that a method is exported.

    By default, no methods are exported, which is very important for
    the compiler, because otherwise most Torch programs consist of a sea
    of tiny exported functions with no rank or dtype information
    (see `annotate_args`), which the compiler cannot do much with.

    Note that this is different from `torch.jit.export`, which controls
    which methods are scripted in the first place. For non-`forward` methods,
    using this decorator usually means you also need `torch.jit.export`.
    Conceptually, this decorator is annotating the scripted module, but is
    applied to the original `torch.nn.Module` for convenience.

    Args:
        fn: The method to mark as exported.

    Returns:
        The same `fn` object, with the attribute named by
        `TORCH_MLIR_EXPORT_ATTR_NAME` set to True.
    """
    # The attribute name lives in a module-level constant, hence `setattr`
    # rather than a plain attribute assignment.
    setattr(fn, TORCH_MLIR_EXPORT_ATTR_NAME, True)
    return fn
|
|
|
|
|
|
# Type of a single non-None entry accepted by `annotate_args`: a
# `(shape, dtype)` pair such as `([2, 3, 4], torch.float32)`.
ArgAnnotation = Tuple[
    List[int],    # shape, one int per dimension
    torch.dtype,  # element dtype
]
|
|
|
|
|
|
# TODO: Replace with py3 extended argument annotations when available.
# See https://www.python.org/dev/peps/pep-0593/
|
|
def annotate_args(annotations: List[Optional[ArgAnnotation]]):
    """Decorator that tells the torch-mlir compiler information about arguments.

    The `annotations` should be a list of the same length as the number of
    arguments to the method (including `self`). Each list entry is either:
    - None, corresponding to providing the compiler with no information.
    - A 2-tuple consisting of a shape and a dtype, such as
      `([2, 3, 4], torch.float32)`. A dimension with an unknown size can be
      indicated by using `-1` as the size. This provides the compiler a
      guarantee that the argument will always dynamically have the described
      shape and dtype.

    Args:
        annotations: One entry per method argument, as described above.

    Returns:
        A decorator that stores `annotations` on the decorated function under
        the attribute named by `TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME` and
        returns that same function.
    """

    # TODO: Check the number of arguments matches the number of arg annotations.
    def decorator(fn):
        # The caller's list object is attached as-is (no copy is made).
        setattr(fn, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME, annotations)
        return fn

    return decorator
|