torch-mlir/python/torch_mlir_e2e_test/annotations.py

71 lines
2.9 KiB
Python

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from typing import List, Optional, Tuple, NamedTuple
import torch
# Decorators
# Currently, these decorators are very low-level and map 1:1 with
# methods on `torch_mlir.ClassAnnotator`. Eventually, we expect there to
# be a more elaborate Python layer which allows all the different annotations
# to be expressed conveniently and gives clearer error reports when
# the annotations aren't acceptable.
# This module is kept separate from torch_mlir.torchscript_annotations so that
# we can use this from code without C++ dependencies, which prevent us from
# interfacing the test framework across environments.
# Names of the attributes that the decorators in this module set on the
# functions they annotate.
# Keep these in sync with their use in
# `torch_mlir/torchscript_annotations.py`.
TORCH_MLIR_EXPORT_ATTR_NAME = "_torch_mlir_export"
TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME = "_torch_mlir_arg_annotations"
def export(fn):
    """Mark a method as exported for the torch-mlir compiler.

    Nothing is exported unless it is explicitly marked. That default matters
    a great deal to the compiler: without it, most Torch programs would
    appear as a sea of tiny exported functions carrying no rank or dtype
    information (see `annotate_args`), which the compiler cannot do much
    with.

    This is not the same thing as `torch.jit.export`, which decides which
    methods get scripted in the first place. A non-`forward` method that
    uses this decorator usually needs `torch.jit.export` as well.

    Conceptually the annotation belongs to the scripted module; it is
    applied to the original `torch.nn.Module` purely for convenience.

    Args:
        fn: The method to mark. It is tagged in place.

    Returns:
        The same `fn` object, so this can be used as a plain decorator.
    """
    is_exported = True
    # Tag the function object itself with the export marker attribute.
    setattr(fn, TORCH_MLIR_EXPORT_ATTR_NAME, is_exported)
    return fn
# Type of a single argument annotation accepted by `annotate_args`.
ArgAnnotation = Tuple[
    List[int],    # shape
    torch.dtype,  # dtype
]
# TODO: Replace with py3 extended argument annotations (`typing.Annotated`)
# when available.
# See https://www.python.org/dev/peps/pep-0593/
def annotate_args(annotations: List[Optional[ArgAnnotation]]):
    """Build a decorator that records argument information for torch-mlir.

    `annotations` should be a list with one entry per argument of the
    decorated method, `self` included. Each entry is one of:

    - `None`, meaning the compiler is given no information about that
      argument.
    - A 2-tuple of a shape and a dtype, for example
      `([2, 3, 4], torch.float32)`. Use `-1` for a dimension whose size is
      unknown. This is a guarantee to the compiler that the argument will
      always dynamically have the described shape and dtype.

    Args:
        annotations: The per-argument annotations described above. The list
            object itself is stored on the decorated function (no copy is
            made).

    Returns:
        A decorator that attaches `annotations` to the function it is
        applied to and returns that same function.
    """

    # TODO: Check the number of arguments matches the number of arg annotations.
    def _record_annotations(fn):
        # Stash the annotations on the function object under the
        # arg-annotations attribute name.
        setattr(fn, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME, annotations)
        return fn

    return _record_annotations