mirror of https://github.com/llvm/torch-mlir
[fx] Fix type hint for fx importer (#3066)
Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>pull/3093/head
parent
ec4cb8be44
commit
5325d3e6e6
|
@ -27,6 +27,7 @@ from typing import (
|
|||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
Iterable,
|
||||
)
|
||||
import weakref
|
||||
|
||||
|
@ -1173,7 +1174,7 @@ class GraphNodeImporter:
|
|||
func_dialect.ReturnOp(operands, loc=loc)
|
||||
|
||||
def import_nodes(
|
||||
self, nodes: Sequence[Node], *, skip_placeholders_outputs: bool = False
|
||||
self, nodes: Iterable[Node], *, skip_placeholders_outputs: bool = False
|
||||
):
|
||||
with InsertionPoint(self._b):
|
||||
loc = Location.unknown()
|
||||
|
@ -1266,7 +1267,7 @@ class GraphNodeImporter:
|
|||
(arg.meta["val"].node.pytype if isinstance(arg, Node) else type(arg))
|
||||
for arg in node.args
|
||||
]
|
||||
is_int = [item == int for item in arg_types]
|
||||
is_int = [item is int for item in arg_types]
|
||||
if all(is_int):
|
||||
op_overload = "int"
|
||||
elif any(is_int):
|
||||
|
@ -1546,7 +1547,7 @@ class GraphNodeImporter:
|
|||
).result
|
||||
|
||||
def _import_list_argument(
|
||||
self, loc: Location, arg: NodeArgument, expected_jit_type
|
||||
self, loc: Location, arg: Sequence[NodeArgument], expected_jit_type
|
||||
) -> Value:
|
||||
assert (
|
||||
isinstance(expected_jit_type, torch.ListType)
|
||||
|
@ -1554,7 +1555,7 @@ class GraphNodeImporter:
|
|||
isinstance(expected_jit_type, torch.OptionalType)
|
||||
and isinstance(expected_jit_type.getElementType(), torch.ListType)
|
||||
)
|
||||
or isinstance(expected_jit_type, NoneType)
|
||||
or (expected_jit_type is None)
|
||||
), f"Unexpected jit type as list argument: {arg} of type {expected_jit_type}"
|
||||
|
||||
# parse list type
|
||||
|
@ -1630,7 +1631,7 @@ class GraphNodeImporter:
|
|||
with loc:
|
||||
return cvt(arg, self, self._cc)
|
||||
|
||||
def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema):
|
||||
def _unpack_node_result_types(self, node: torch.fx.Node, schema: FunctionSchema) -> List[IrType]:
|
||||
return_count = len(schema.returns)
|
||||
if return_count == 1:
|
||||
# Unary return directly maps a single meta["val"] and cannot be subscripted.
|
||||
|
@ -1649,7 +1650,6 @@ class GraphNodeImporter:
|
|||
result_types = []
|
||||
for v in node.meta["val"]:
|
||||
result_types.append(self._cc.value_info_to_type(v))
|
||||
result_types = tuple(result_types)
|
||||
return result_types
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue