[fx] Fix type hint for fx importer (#3066)

Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
pull/3093/head
penguin_wwy 2024-04-02 08:31:43 +08:00 committed by GitHub
parent ec4cb8be44
commit 5325d3e6e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 6 additions and 6 deletions

View File

@ -27,6 +27,7 @@ from typing import (
Tuple,
TYPE_CHECKING,
Union,
Iterable,
)
import weakref
@ -1173,7 +1174,7 @@ class GraphNodeImporter:
func_dialect.ReturnOp(operands, loc=loc)
def import_nodes(
self, nodes: Sequence[Node], *, skip_placeholders_outputs: bool = False
self, nodes: Iterable[Node], *, skip_placeholders_outputs: bool = False
):
with InsertionPoint(self._b):
loc = Location.unknown()
@ -1266,7 +1267,7 @@ class GraphNodeImporter:
(arg.meta["val"].node.pytype if isinstance(arg, Node) else type(arg))
for arg in node.args
]
is_int = [item == int for item in arg_types]
is_int = [item is int for item in arg_types]
if all(is_int):
op_overload = "int"
elif any(is_int):
@ -1546,7 +1547,7 @@ class GraphNodeImporter:
).result
def _import_list_argument(
self, loc: Location, arg: NodeArgument, expected_jit_type
self, loc: Location, arg: Sequence[NodeArgument], expected_jit_type
) -> Value:
assert (
isinstance(expected_jit_type, torch.ListType)
@ -1554,7 +1555,7 @@ class GraphNodeImporter:
isinstance(expected_jit_type, torch.OptionalType)
and isinstance(expected_jit_type.getElementType(), torch.ListType)
)
or isinstance(expected_jit_type, NoneType)
or (expected_jit_type is None)
), f"Unexpected jit type as list argument: {arg} of type {expected_jit_type}"
# parse list type
@ -1630,7 +1631,7 @@ class GraphNodeImporter:
with loc:
return cvt(arg, self, self._cc)
def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema):
def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema) -> List[IrType]:
return_count = len(schema.returns)
if return_count == 1:
# Unary return directly maps a single meta["val"] and cannot be subscripted.
@ -1649,7 +1650,6 @@ class GraphNodeImporter:
result_types = []
for v in node.meta["val"]:
result_types.append(self._cc.value_info_to_type(v))
result_types = tuple(result_types)
return result_types