2021-09-29 08:25:06 +08:00
|
|
|
# -*- Python -*-
|
|
|
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
2021-09-30 00:03:40 +08:00
|
|
|
# Also available under a BSD-style license. See LICENSE.
|
2021-09-29 08:25:06 +08:00
|
|
|
#
|
|
|
|
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
|
|
|
|
|
|
|
from typing import Iterable, Union
|
|
|
|
from torch.fx import GraphModule
|
|
|
|
from torch_mlir import ir
|
|
|
|
from torch_mlir.dialects import builtin
|
|
|
|
from .torch_mlir_types import TorchTensorType, PythonType
|
|
|
|
|
|
|
|
class Annotation:
    """Ordered collection of argument types.

    Each entry given to the constructor is stored as-is, except that plain
    Python classes (instances of ``type``) are wrapped in ``PythonType``.
    """

    def __init__(self, types: Iterable[Union[TorchTensorType, type]]):
        # Wrap bare classes; leave every other entry untouched.
        self.types = [
            PythonType(entry) if isinstance(entry, type) else entry
            for entry in types
        ]

    def __str__(self):
        """Return a header line followed by one line per stored type."""
        header = f'Annotation instance with {len(self.types)} types\n'
        # Argument positions are reported 1-based.
        per_type = [
            f' Type of argument {position}: {entry!s}\n'
            for position, entry in enumerate(self.types, start=1)
        ]
        return header + ''.join(per_type)

    def __iter__(self):
        """Iterate over the stored types in argument order."""
        return iter(self.types)
|
|
|
|
|
|
|
|
|
|
|
|
class AnnotationConverter:
    """Namespace for converting an ``Annotation`` into MLIR attributes."""

    @staticmethod
    def to_mlir_array_attr(annotation: Annotation,
                           context: ir.Context) -> ir.ArrayAttr:
        """Build an array attribute holding one dictionary per type.

        Entries that are ``TorchTensorType`` instances yield a dictionary
        attribute carrying their MLIR type under the ``torch.type_bound``
        key; every other entry yields an empty dictionary attribute.
        """

        def as_dict_attr(entry):
            # Non-tensor entries carry no type bound.
            if not isinstance(entry, TorchTensorType):
                return ir.DictAttr.get({}, context=context)

            mlir_type = entry.to_mlir(context)
            # The attribute getters below rely on the active context.
            with context:
                bound = ir.TypeAttr.get(mlir_type)
                return ir.DictAttr.get({'torch.type_bound': bound})

        per_argument = [as_dict_attr(entry) for entry in annotation]
        return ir.ArrayAttr.get(per_argument, context=context)
|
|
|
|
|
|
|
|
|
|
|
|
def annotate_forward_args(module: GraphModule,
                          types: Iterable[Union[TorchTensorType, type]]
                          ) -> GraphModule:
    """Attach a ``torch_mlir_type`` kwarg to the graph's placeholder nodes.

    Placeholder nodes are paired with ``types`` in order, and pairing stops
    as soon as either sequence is exhausted. Plain Python classes are
    wrapped in ``PythonType`` before being stored. The module passed in is
    the one returned.
    """
    placeholders = (node for node in module.graph.nodes
                    if node.op == 'placeholder')

    for placeholder, requested in zip(placeholders, types):
        stored = (PythonType(requested) if isinstance(requested, type)
                  else requested)
        placeholder.update_kwarg('torch_mlir_type', stored)

    return module
|