mirror of https://github.com/llvm/torch-mlir
21 lines
768 B
Python
21 lines
768 B
Python
|
# -*- Python -*-
|
||
|
# This file is licensed under a pytorch-style license
|
||
|
# See frontends/pytorch/LICENSE for license information.
|
||
|
#
|
||
|
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
||
|
|
||
|
from typing import Iterable, Union
|
||
|
from torch.fx import GraphModule
|
||
|
from .torch_mlir_types import TorchTensorType, PythonType
|
||
|
|
||
|
def annotate_forward_args(module: GraphModule,
|
||
|
types: Iterable[Union[TorchTensorType, type]]
|
||
|
) -> GraphModule:
|
||
|
operands = filter(lambda node: node.op == 'placeholder', module.graph.nodes)
|
||
|
for operand, type_ in zip(operands, types):
|
||
|
if isinstance(type_, type):
|
||
|
type_ = PythonType(type_)
|
||
|
operand.update_kwarg('torch_mlir_type', type_)
|
||
|
|
||
|
return module
|