Normalize type hints to be compatible with multiple Python versions (#3028)

Although we provide a wheel package for Python 3.8, importing it on
3.8 can actually throw the following exception:
`TypeError: 'type' object is not subscriptable`

Subscripting the built-in container types (`list[int]`, `dict[str, int]`,
and so on) is only supported at runtime from Python 3.9 (PEP 585), so
annotations that Python evaluates at import time fail on 3.8. This change
normalizes such hints to the `typing` aliases (`List`, `Dict`, `Tuple`,
`Set`), which work on all supported versions, and declares
`python_requires=">=3.8"` in setup.py.
penguin_wwy 2024-03-15 23:29:48 +08:00 committed by GitHub
parent 4282eb9e76
commit f34c187ac4
5 changed files with 38 additions and 37 deletions
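
For context, here is a minimal reproduction of the failure mode. It is
not taken from the commit; the names are illustrative:

    from typing import List, Optional

    # On Python 3.8 the commented-out `def` below raises immediately,
    # because parameter annotations are evaluated when the function is
    # defined, and PEP 585 generics (`list[int]`, `dict[str, int]`, ...)
    # require Python >= 3.9:
    #
    #     def f(xs: Optional[list[int]] = None): ...
    #     TypeError: 'type' object is not subscriptable
    #
    # The `typing` aliases behave identically and work on 3.8 and later:
    def f(xs: Optional[List[int]] = None) -> int:
        return len(xs or [])

    print(f([1, 2, 3]))  # prints 3

`from __future__ import annotations` would also defer evaluation of
annotations, but it has to be added to every affected file and does not
help where subscripted builtins appear outside annotations.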

View File

@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
-from typing import List, Optional, Any, Tuple, Union
+from typing import List, Optional, Any, Tuple, Union, Dict, Set
 import argparse
 import os
@@ -1767,9 +1767,9 @@ _SORTED_TORCH_TYPES = [
 def _check_tensors_with_the_same_dtype(
     num_of_tensors: Optional[int] = None,
-    tensor_shapes: Optional[list[tuple[int]]] = None,
+    tensor_shapes: Optional[List[Tuple[int]]] = None,
     tensor_device: Optional[torch.device] = None,
-    error_types: Optional[set[int]] = None, *args, **kwargs):
+    error_types: Optional[Set[int]] = None, *args, **kwargs):
     """Create invocations where all tensors have the same dtype.
     This function generates invocations with `num_of_tensors` tensors
@@ -1801,10 +1801,10 @@ def _check_tensors_with_the_same_dtype(
     return invocations
 def _check_two_tensor_op(
-    tensor_shapes: Optional[list[tuple[int]]] = None,
+    tensor_shapes: Optional[List[Tuple[int]]] = None,
     tensor_device: Optional[torch.device] = None,
-    input_error_types: Optional[set[int]] = None,
-    output_error_types: Optional[set[int]] = None, **kwargs):
+    input_error_types: Optional[Set[int]] = None,
+    output_error_types: Optional[Set[int]] = None, **kwargs):
     """Generate invocations for basic two-tensor dtype functions.
     This helper function is meant to be used to check dtype functions that

View File

@@ -276,7 +276,7 @@ class SparsityMeta:
     batch_dim: int
     sparse_dim: int
     dense_dim: int
-    blocksize: Optional[tuple[int, int]]
+    blocksize: Optional[Tuple[int, int]]
     pos_dtype: torch.dtype
     crd_dtype: torch.dtype
@@ -489,8 +489,8 @@ class FxImporter:
         default policy is to capture them as frozen values.
         """
         # Create lookaside table of placeholders/outputs.
-        placeholder_nodes: dict[str, Node] = {}
-        all_producer_nodes: dict[str, Node] = {}
+        placeholder_nodes: Dict[str, Node] = {}
+        all_producer_nodes: Dict[str, Node] = {}
         loc: Optional[Location] = None
         for node in prog.graph.nodes:
             if loc is None:
@@ -522,15 +522,15 @@
         }
         # Additional bindings that we need to set up after the function is created.
-        mutable_buffer_target_producers: dict[str, str] = {}
-        constant_tensors: dict[Node, torch.Tensor] = {}
-        parameter_bindings: dict[Node, tuple[Any, InputInfo]] = {}
-        buffer_bindings: dict[Node, tuple[Any, InputInfo]] = {}
+        mutable_buffer_target_producers: Dict[str, str] = {}
+        constant_tensors: Dict[Node, torch.Tensor] = {}
+        parameter_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
+        buffer_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
         # Derive user outputs that we preserve. These will be nodes of the
         # producer for the output.
-        user_outputs: list[Node] = []
-        user_output_types: list[IrType] = []
+        user_outputs: List[Node] = []
+        user_output_types: List[IrType] = []
         for output_spec in sig.output_specs:
             kind = output_spec.kind
             arg = output_spec.arg
@@ -548,8 +548,8 @@
             mutable_buffer_target_producers[output_spec.target] = arg.name
         # Derive user inputs. These will be op=='placeholder' nodes.
-        user_inputs: list[Node] = []
-        user_input_types: list[IrType] = []
+        user_inputs: List[Node] = []
+        user_input_types: List[IrType] = []
         for input_spec in sig.input_specs:
             arg = input_spec.arg
             if input_spec.kind == InputKind.USER_INPUT:
@@ -700,7 +700,7 @@
         """
         sig = prog.graph_signature
         state_dict = prog.state_dict
-        arg_replacements: dict[str, Any] = {}
+        arg_replacements: Dict[str, Any] = {}
         # If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
         # at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
@@ -1003,7 +1003,7 @@ class GraphNodeImporter:
         # constructs and returns a value.
         self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {}
         # Map of node name to hook that should be called when it is produced.
-        self._on_node_produced: dict[str, Callable[[Value], None]] = {}
+        self._on_node_produced: Dict[str, Callable[[Value], None]] = {}
         # Statically multi-result nodes which we have de-tupled are noted here.
         # They will have their getitem calls short-circuited.
         self._multi_result_nodes: Set[torch_fx.Node] = set()
@@ -1118,7 +1118,7 @@
         self._on_node_produced[info.mutable_producer_node_name] = on_produced
-    def return_node_values(self, loc, nodes: list[Node]):
+    def return_node_values(self, loc, nodes: List[Node]):
         with loc, InsertionPoint(self._b):
             operands = [self.resolve_node_value(n) for n in nodes]
             func_dialect.ReturnOp(operands, loc=loc)
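
An aside that is not part of the commit: only annotations that CPython
actually evaluates can trigger this TypeError. Function signatures and
class- or module-level annotations are evaluated; local variable
annotations are not (PEP 526), so lines like
`placeholder_nodes: dict[str, Node] = {}` inside a method were accepted
even on 3.8 and are presumably converted here for uniformity. A quick
illustration:

    # Class-level annotations are evaluated at class-creation time, so
    # on Python 3.8 this raises TypeError:
    #
    #     class Broken:
    #         table: dict[str, int]
    #
    # Local variable annotations are never evaluated (PEP 526), so this
    # compiles and runs even on 3.8:
    def ok():
        table: dict[str, int] = {}
        return table

    print(ok())  # prints {}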

View File

@@ -33,7 +33,7 @@ except ModuleNotFoundError as e:
         "The onnx package (`pip install onnx`) is required to use the onnx importer"
     ) from e
-from typing import Optional
+from typing import Optional, List, Dict, Tuple
 from dataclasses import dataclass
@@ -113,16 +113,16 @@ class GraphInfo:
     def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto):
         self.model_info = model_info
         self.graph_proto = graph_proto
-        self.initializer_map: dict[str, onnx.TensorProto] = {
+        self.initializer_map: Dict[str, onnx.TensorProto] = {
             n.name: n for n in graph_proto.initializer
         }
-        self.value_info_map: dict[str, onnx.ValueInfoProto] = {
+        self.value_info_map: Dict[str, onnx.ValueInfoProto] = {
             n.name: n for n in graph_proto.value_info
         }
-        self.declared_input_map: dict[str, onnx.ValueInfoProto] = {
+        self.declared_input_map: Dict[str, onnx.ValueInfoProto] = {
             n.name: n for n in graph_proto.input
         }
-        self.output_map: dict[str, onnx.ValueInfoProto] = {
+        self.output_map: Dict[str, onnx.ValueInfoProto] = {
             n.name: n for n in graph_proto.output
         }
@@ -191,7 +191,7 @@ class NodeImporter:
         self._gi = graph_info
         self._p = parent_op
         self._b = block
-        self._nv_map: dict[str, Value] = {}
+        self._nv_map: Dict[str, Value] = {}
     @classmethod
     def define_function(
@@ -225,7 +225,7 @@
         with container_op.context:
             i64_type = IntegerType.get_signed(64)
             default_opset_version = 0
-            opset_versions: dict[str, IntegerAttr] = {}
+            opset_versions: Dict[str, IntegerAttr] = {}
             for opset_import in m.opset_import:
                 if opset_import.domain:
                     opset_versions[opset_import.domain] = IntegerAttr.get(
@@ -335,7 +335,7 @@
         for output_name, output_value in zip(output_names, custom_op.results):
             self._nv_map[output_name] = output_value
-    def import_attributes(self, onnx_attrs: list[onnx.AttributeProto]):
+    def import_attributes(self, onnx_attrs: List[onnx.AttributeProto]):
         attrs = {}
         for onnx_attr in onnx_attrs:
             attr_type = onnx_attr.type
@@ -358,14 +358,14 @@
             attrs[f"torch.onnx.{onnx_attr.name}"] = result
         return attrs
-    def count_regions(self, onnx_attrs: list[onnx.AttributeProto]):
+    def count_regions(self, onnx_attrs: List[onnx.AttributeProto]):
         count = 0
         for onnx_attr in onnx_attrs:
             if onnx_attr.type == onnx.AttributeProto.AttributeType.GRAPH:
                 count += 1
         return count
-    def import_regions(self, onnx_attrs: list[onnx.AttributeProto], op):
+    def import_regions(self, onnx_attrs: List[onnx.AttributeProto], op):
         attr_map = {}
         for onnx_attr in onnx_attrs:
             attr_type = onnx_attr.type
@@ -458,10 +458,10 @@ class ContextCache:
     def __init__(self, context: Context):
         self._c = context
-        self._elem_type_map: dict[int, IrType] = {}
-        self._list_type_map:dict[str, IrType] = {}
-        self._optional_type_map:dict[str, IrType] = {}
-        self._vtensor_type_map: dict[tuple[tuple[Optional[int]], IrType], IrType] = {}
+        self._elem_type_map: Dict[int, IrType] = {}
+        self._list_type_map:Dict[str, IrType] = {}
+        self._optional_type_map:Dict[str, IrType] = {}
+        self._vtensor_type_map: Dict[Tuple[Tuple[Optional[int]], IrType], IrType] = {}
     def tensor_element_type(self, elem_type: int) -> IrType:
         t = self._elem_type_map.get(elem_type)
@@ -539,7 +539,7 @@
                 f"Unsupport optional element type")
     def get_vtensor_type(
-        self, dims: tuple[Optional[int]], element_type: IrType
+        self, dims: Tuple[Optional[int]], element_type: IrType
     ) -> IrType:
         key = (dims, element_type)
         t = self._vtensor_type_map.get(key)

View File

@@ -229,6 +229,7 @@ setup(
         "build_py": CMakeBuild,
     },
     ext_modules=EXT_MODULES,
+    python_requires=">=3.8",
    install_requires=INSTALL_REQUIRES,
    extras_require={
        "onnx": [

View File

@@ -5,7 +5,7 @@
 # RUN: %PYTHON %s | FileCheck %s
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Tuple, Dict
 import torch
 import torch.export
@@ -80,7 +80,7 @@ def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
 def sparse_export(
-    f: Callable, args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None
+    f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
 ) -> torch.export.ExportedProgram:
     """
     This is a ***temporary*** wrapper around `torch.export.export`