Normalize type hints to be compatible with multiple Python versions (#3028)

Although we provide a wheel package for Python 3.8, it may actually
throw the following exception:
`TypeError: 'type' object is not subscriptable`
pull/3032/head
penguin_wwy 2024-03-15 23:29:48 +08:00 committed by GitHub
parent 4282eb9e76
commit f34c187ac4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 38 additions and 37 deletions

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE. # Also available under a BSD-style license. See LICENSE.
from typing import List, Optional, Any, Tuple, Union from typing import List, Optional, Any, Tuple, Union, Dict, Set
import argparse import argparse
import os import os
@ -1767,9 +1767,9 @@ _SORTED_TORCH_TYPES = [
def _check_tensors_with_the_same_dtype( def _check_tensors_with_the_same_dtype(
num_of_tensors: Optional[int] = None, num_of_tensors: Optional[int] = None,
tensor_shapes: Optional[list[tuple[int]]] = None, tensor_shapes: Optional[List[Tuple[int]]] = None,
tensor_device: Optional[torch.device] = None, tensor_device: Optional[torch.device] = None,
error_types: Optional[set[int]] = None, *args, **kwargs): error_types: Optional[Set[int]] = None, *args, **kwargs):
"""Create invocations where all tensors have the same dtype. """Create invocations where all tensors have the same dtype.
This function generates invocations with `num_of_tensors` tensors This function generates invocations with `num_of_tensors` tensors
@ -1801,10 +1801,10 @@ def _check_tensors_with_the_same_dtype(
return invocations return invocations
def _check_two_tensor_op( def _check_two_tensor_op(
tensor_shapes: Optional[list[tuple[int]]] = None, tensor_shapes: Optional[List[Tuple[int]]] = None,
tensor_device: Optional[torch.device] = None, tensor_device: Optional[torch.device] = None,
input_error_types: Optional[set[int]] = None, input_error_types: Optional[Set[int]] = None,
output_error_types: Optional[set[int]] = None, **kwargs): output_error_types: Optional[Set[int]] = None, **kwargs):
"""Generate invocations for basic two-tensor dtype functions. """Generate invocations for basic two-tensor dtype functions.
This helper function is meant to be used to check dtype functions that This helper function is meant to be used to check dtype functions that

View File

@ -276,7 +276,7 @@ class SparsityMeta:
batch_dim: int batch_dim: int
sparse_dim: int sparse_dim: int
dense_dim: int dense_dim: int
blocksize: Optional[tuple[int, int]] blocksize: Optional[Tuple[int, int]]
pos_dtype: torch.dtype pos_dtype: torch.dtype
crd_dtype: torch.dtype crd_dtype: torch.dtype
@ -489,8 +489,8 @@ class FxImporter:
default policy is to capture them as frozen values. default policy is to capture them as frozen values.
""" """
# Create lookaside table of placeholders/outputs. # Create lookaside table of placeholders/outputs.
placeholder_nodes: dict[str, Node] = {} placeholder_nodes: Dict[str, Node] = {}
all_producer_nodes: dict[str, Node] = {} all_producer_nodes: Dict[str, Node] = {}
loc: Optional[Location] = None loc: Optional[Location] = None
for node in prog.graph.nodes: for node in prog.graph.nodes:
if loc is None: if loc is None:
@ -522,15 +522,15 @@ class FxImporter:
} }
# Additional bindings that we need to set up after the function is created. # Additional bindings that we need to set up after the function is created.
mutable_buffer_target_producers: dict[str, str] = {} mutable_buffer_target_producers: Dict[str, str] = {}
constant_tensors: dict[Node, torch.Tensor] = {} constant_tensors: Dict[Node, torch.Tensor] = {}
parameter_bindings: dict[Node, tuple[Any, InputInfo]] = {} parameter_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
buffer_bindings: dict[Node, tuple[Any, InputInfo]] = {} buffer_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
# Derive user outputs that we preserve. These will be nodes of the # Derive user outputs that we preserve. These will be nodes of the
# producer for the output. # producer for the output.
user_outputs: list[Node] = [] user_outputs: List[Node] = []
user_output_types: list[IrType] = [] user_output_types: List[IrType] = []
for output_spec in sig.output_specs: for output_spec in sig.output_specs:
kind = output_spec.kind kind = output_spec.kind
arg = output_spec.arg arg = output_spec.arg
@ -548,8 +548,8 @@ class FxImporter:
mutable_buffer_target_producers[output_spec.target] = arg.name mutable_buffer_target_producers[output_spec.target] = arg.name
# Derive user inputs. These will be op=='placeholder' nodes. # Derive user inputs. These will be op=='placeholder' nodes.
user_inputs: list[Node] = [] user_inputs: List[Node] = []
user_input_types: list[IrType] = [] user_input_types: List[IrType] = []
for input_spec in sig.input_specs: for input_spec in sig.input_specs:
arg = input_spec.arg arg = input_spec.arg
if input_spec.kind == InputKind.USER_INPUT: if input_spec.kind == InputKind.USER_INPUT:
@ -700,7 +700,7 @@ class FxImporter:
""" """
sig = prog.graph_signature sig = prog.graph_signature
state_dict = prog.state_dict state_dict = prog.state_dict
arg_replacements: dict[str, Any] = {} arg_replacements: Dict[str, Any] = {}
# If there is no "constants" attribute, consult the "state_dict". Otherwise, only look # If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
# at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969 # at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
@ -1003,7 +1003,7 @@ class GraphNodeImporter:
# constructs and returns a value. # constructs and returns a value.
self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {} self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {}
# Map of node name to hook that should be called when it is produced. # Map of node name to hook that should be called when it is produced.
self._on_node_produced: dict[str, Callable[[Value], None]] = {} self._on_node_produced: Dict[str, Callable[[Value], None]] = {}
# Statically multi-result nodes which we have de-tupled are noted here. # Statically multi-result nodes which we have de-tupled are noted here.
# They will have their getitem calls short-circuited. # They will have their getitem calls short-circuited.
self._multi_result_nodes: Set[torch_fx.Node] = set() self._multi_result_nodes: Set[torch_fx.Node] = set()
@ -1118,7 +1118,7 @@ class GraphNodeImporter:
self._on_node_produced[info.mutable_producer_node_name] = on_produced self._on_node_produced[info.mutable_producer_node_name] = on_produced
def return_node_values(self, loc, nodes: list[Node]): def return_node_values(self, loc, nodes: List[Node]):
with loc, InsertionPoint(self._b): with loc, InsertionPoint(self._b):
operands = [self.resolve_node_value(n) for n in nodes] operands = [self.resolve_node_value(n) for n in nodes]
func_dialect.ReturnOp(operands, loc=loc) func_dialect.ReturnOp(operands, loc=loc)

View File

@ -33,7 +33,7 @@ except ModuleNotFoundError as e:
"The onnx package (`pip install onnx`) is required to use the onnx importer" "The onnx package (`pip install onnx`) is required to use the onnx importer"
) from e ) from e
from typing import Optional from typing import Optional, List, Dict, Tuple
from dataclasses import dataclass from dataclasses import dataclass
@ -113,16 +113,16 @@ class GraphInfo:
def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto): def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto):
self.model_info = model_info self.model_info = model_info
self.graph_proto = graph_proto self.graph_proto = graph_proto
self.initializer_map: dict[str, onnx.TensorProto] = { self.initializer_map: Dict[str, onnx.TensorProto] = {
n.name: n for n in graph_proto.initializer n.name: n for n in graph_proto.initializer
} }
self.value_info_map: dict[str, onnx.ValueInfoProto] = { self.value_info_map: Dict[str, onnx.ValueInfoProto] = {
n.name: n for n in graph_proto.value_info n.name: n for n in graph_proto.value_info
} }
self.declared_input_map: dict[str, onnx.ValueInfoProto] = { self.declared_input_map: Dict[str, onnx.ValueInfoProto] = {
n.name: n for n in graph_proto.input n.name: n for n in graph_proto.input
} }
self.output_map: dict[str, onnx.ValueInfoProto] = { self.output_map: Dict[str, onnx.ValueInfoProto] = {
n.name: n for n in graph_proto.output n.name: n for n in graph_proto.output
} }
@ -191,7 +191,7 @@ class NodeImporter:
self._gi = graph_info self._gi = graph_info
self._p = parent_op self._p = parent_op
self._b = block self._b = block
self._nv_map: dict[str, Value] = {} self._nv_map: Dict[str, Value] = {}
@classmethod @classmethod
def define_function( def define_function(
@ -225,7 +225,7 @@ class NodeImporter:
with container_op.context: with container_op.context:
i64_type = IntegerType.get_signed(64) i64_type = IntegerType.get_signed(64)
default_opset_version = 0 default_opset_version = 0
opset_versions: dict[str, IntegerAttr] = {} opset_versions: Dict[str, IntegerAttr] = {}
for opset_import in m.opset_import: for opset_import in m.opset_import:
if opset_import.domain: if opset_import.domain:
opset_versions[opset_import.domain] = IntegerAttr.get( opset_versions[opset_import.domain] = IntegerAttr.get(
@ -335,7 +335,7 @@ class NodeImporter:
for output_name, output_value in zip(output_names, custom_op.results): for output_name, output_value in zip(output_names, custom_op.results):
self._nv_map[output_name] = output_value self._nv_map[output_name] = output_value
def import_attributes(self, onnx_attrs: list[onnx.AttributeProto]): def import_attributes(self, onnx_attrs: List[onnx.AttributeProto]):
attrs = {} attrs = {}
for onnx_attr in onnx_attrs: for onnx_attr in onnx_attrs:
attr_type = onnx_attr.type attr_type = onnx_attr.type
@ -358,14 +358,14 @@ class NodeImporter:
attrs[f"torch.onnx.{onnx_attr.name}"] = result attrs[f"torch.onnx.{onnx_attr.name}"] = result
return attrs return attrs
def count_regions(self, onnx_attrs: list[onnx.AttributeProto]): def count_regions(self, onnx_attrs: List[onnx.AttributeProto]):
count = 0 count = 0
for onnx_attr in onnx_attrs: for onnx_attr in onnx_attrs:
if onnx_attr.type == onnx.AttributeProto.AttributeType.GRAPH: if onnx_attr.type == onnx.AttributeProto.AttributeType.GRAPH:
count += 1 count += 1
return count return count
def import_regions(self, onnx_attrs: list[onnx.AttributeProto], op): def import_regions(self, onnx_attrs: List[onnx.AttributeProto], op):
attr_map = {} attr_map = {}
for onnx_attr in onnx_attrs: for onnx_attr in onnx_attrs:
attr_type = onnx_attr.type attr_type = onnx_attr.type
@ -458,10 +458,10 @@ class ContextCache:
def __init__(self, context: Context): def __init__(self, context: Context):
self._c = context self._c = context
self._elem_type_map: dict[int, IrType] = {} self._elem_type_map: Dict[int, IrType] = {}
self._list_type_map:dict[str, IrType] = {} self._list_type_map:Dict[str, IrType] = {}
self._optional_type_map:dict[str, IrType] = {} self._optional_type_map:Dict[str, IrType] = {}
self._vtensor_type_map: dict[tuple[tuple[Optional[int]], IrType], IrType] = {} self._vtensor_type_map: Dict[Tuple[Tuple[Optional[int]], IrType], IrType] = {}
def tensor_element_type(self, elem_type: int) -> IrType: def tensor_element_type(self, elem_type: int) -> IrType:
t = self._elem_type_map.get(elem_type) t = self._elem_type_map.get(elem_type)
@ -539,7 +539,7 @@ class ContextCache:
f"Unsupport optional element type") f"Unsupport optional element type")
def get_vtensor_type( def get_vtensor_type(
self, dims: tuple[Optional[int]], element_type: IrType self, dims: Tuple[Optional[int]], element_type: IrType
) -> IrType: ) -> IrType:
key = (dims, element_type) key = (dims, element_type)
t = self._vtensor_type_map.get(key) t = self._vtensor_type_map.get(key)

View File

@ -229,6 +229,7 @@ setup(
"build_py": CMakeBuild, "build_py": CMakeBuild,
}, },
ext_modules=EXT_MODULES, ext_modules=EXT_MODULES,
python_requires=">=3.8",
install_requires=INSTALL_REQUIRES, install_requires=INSTALL_REQUIRES,
extras_require={ extras_require={
"onnx": [ "onnx": [

View File

@ -5,7 +5,7 @@
# RUN: %PYTHON %s | FileCheck %s # RUN: %PYTHON %s | FileCheck %s
from typing import Any, Callable, Optional from typing import Any, Callable, Optional, Tuple, Dict
import torch import torch
import torch.export import torch.export
@ -80,7 +80,7 @@ def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
def sparse_export( def sparse_export(
f: Callable, args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
) -> torch.export.ExportedProgram: ) -> torch.export.ExportedProgram:
""" """
This is a ***temporary*** wrapper around `torch.export.export` This is a ***temporary*** wrapper around `torch.export.export`