mirror of https://github.com/llvm/torch-mlir
Normalize type hints to be compatible with multiple Python versions (#3028)
Although we provide a wheel package for Python 3.8, it may actually throw the following exception: `TypeError: 'type' object is not subscriptable`pull/3032/head
parent
4282eb9e76
commit
f34c187ac4
|
@ -3,7 +3,7 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from typing import List, Optional, Any, Tuple, Union
|
||||
from typing import List, Optional, Any, Tuple, Union, Dict, Set
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
@ -1767,9 +1767,9 @@ _SORTED_TORCH_TYPES = [
|
|||
|
||||
def _check_tensors_with_the_same_dtype(
|
||||
num_of_tensors: Optional[int] = None,
|
||||
tensor_shapes: Optional[list[tuple[int]]] = None,
|
||||
tensor_shapes: Optional[List[Tuple[int]]] = None,
|
||||
tensor_device: Optional[torch.device] = None,
|
||||
error_types: Optional[set[int]] = None, *args, **kwargs):
|
||||
error_types: Optional[Set[int]] = None, *args, **kwargs):
|
||||
"""Create invocations where all tensors have the same dtype.
|
||||
|
||||
This function generates invocations with `num_of_tensors` tensors
|
||||
|
@ -1801,10 +1801,10 @@ def _check_tensors_with_the_same_dtype(
|
|||
return invocations
|
||||
|
||||
def _check_two_tensor_op(
|
||||
tensor_shapes: Optional[list[tuple[int]]] = None,
|
||||
tensor_shapes: Optional[List[Tuple[int]]] = None,
|
||||
tensor_device: Optional[torch.device] = None,
|
||||
input_error_types: Optional[set[int]] = None,
|
||||
output_error_types: Optional[set[int]] = None, **kwargs):
|
||||
input_error_types: Optional[Set[int]] = None,
|
||||
output_error_types: Optional[Set[int]] = None, **kwargs):
|
||||
"""Generate invocations for basic two-tensor dtype functions.
|
||||
|
||||
This helper function is meant to be used to check dtype functions that
|
||||
|
|
|
@ -276,7 +276,7 @@ class SparsityMeta:
|
|||
batch_dim: int
|
||||
sparse_dim: int
|
||||
dense_dim: int
|
||||
blocksize: Optional[tuple[int, int]]
|
||||
blocksize: Optional[Tuple[int, int]]
|
||||
pos_dtype: torch.dtype
|
||||
crd_dtype: torch.dtype
|
||||
|
||||
|
@ -489,8 +489,8 @@ class FxImporter:
|
|||
default policy is to capture them as frozen values.
|
||||
"""
|
||||
# Create lookaside table of placeholders/outputs.
|
||||
placeholder_nodes: dict[str, Node] = {}
|
||||
all_producer_nodes: dict[str, Node] = {}
|
||||
placeholder_nodes: Dict[str, Node] = {}
|
||||
all_producer_nodes: Dict[str, Node] = {}
|
||||
loc: Optional[Location] = None
|
||||
for node in prog.graph.nodes:
|
||||
if loc is None:
|
||||
|
@ -522,15 +522,15 @@ class FxImporter:
|
|||
}
|
||||
|
||||
# Additional bindings that we need to set up after the function is created.
|
||||
mutable_buffer_target_producers: dict[str, str] = {}
|
||||
constant_tensors: dict[Node, torch.Tensor] = {}
|
||||
parameter_bindings: dict[Node, tuple[Any, InputInfo]] = {}
|
||||
buffer_bindings: dict[Node, tuple[Any, InputInfo]] = {}
|
||||
mutable_buffer_target_producers: Dict[str, str] = {}
|
||||
constant_tensors: Dict[Node, torch.Tensor] = {}
|
||||
parameter_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
|
||||
buffer_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
|
||||
|
||||
# Derive user outputs that we preserve. These will be nodes of the
|
||||
# producer for the output.
|
||||
user_outputs: list[Node] = []
|
||||
user_output_types: list[IrType] = []
|
||||
user_outputs: List[Node] = []
|
||||
user_output_types: List[IrType] = []
|
||||
for output_spec in sig.output_specs:
|
||||
kind = output_spec.kind
|
||||
arg = output_spec.arg
|
||||
|
@ -548,8 +548,8 @@ class FxImporter:
|
|||
mutable_buffer_target_producers[output_spec.target] = arg.name
|
||||
|
||||
# Derive user inputs. These will be op=='placeholder' nodes.
|
||||
user_inputs: list[Node] = []
|
||||
user_input_types: list[IrType] = []
|
||||
user_inputs: List[Node] = []
|
||||
user_input_types: List[IrType] = []
|
||||
for input_spec in sig.input_specs:
|
||||
arg = input_spec.arg
|
||||
if input_spec.kind == InputKind.USER_INPUT:
|
||||
|
@ -700,7 +700,7 @@ class FxImporter:
|
|||
"""
|
||||
sig = prog.graph_signature
|
||||
state_dict = prog.state_dict
|
||||
arg_replacements: dict[str, Any] = {}
|
||||
arg_replacements: Dict[str, Any] = {}
|
||||
|
||||
# If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
|
||||
# at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
|
||||
|
@ -1003,7 +1003,7 @@ class GraphNodeImporter:
|
|||
# constructs and returns a value.
|
||||
self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {}
|
||||
# Map of node name to hook that should be called when it is produced.
|
||||
self._on_node_produced: dict[str, Callable[[Value], None]] = {}
|
||||
self._on_node_produced: Dict[str, Callable[[Value], None]] = {}
|
||||
# Statically multi-result nodes which we have de-tupled are noted here.
|
||||
# They will have their getitem calls short-circuited.
|
||||
self._multi_result_nodes: Set[torch_fx.Node] = set()
|
||||
|
@ -1118,7 +1118,7 @@ class GraphNodeImporter:
|
|||
|
||||
self._on_node_produced[info.mutable_producer_node_name] = on_produced
|
||||
|
||||
def return_node_values(self, loc, nodes: list[Node]):
|
||||
def return_node_values(self, loc, nodes: List[Node]):
|
||||
with loc, InsertionPoint(self._b):
|
||||
operands = [self.resolve_node_value(n) for n in nodes]
|
||||
func_dialect.ReturnOp(operands, loc=loc)
|
||||
|
|
|
@ -33,7 +33,7 @@ except ModuleNotFoundError as e:
|
|||
"The onnx package (`pip install onnx`) is required to use the onnx importer"
|
||||
) from e
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
@ -113,16 +113,16 @@ class GraphInfo:
|
|||
def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto):
|
||||
self.model_info = model_info
|
||||
self.graph_proto = graph_proto
|
||||
self.initializer_map: dict[str, onnx.TensorProto] = {
|
||||
self.initializer_map: Dict[str, onnx.TensorProto] = {
|
||||
n.name: n for n in graph_proto.initializer
|
||||
}
|
||||
self.value_info_map: dict[str, onnx.ValueInfoProto] = {
|
||||
self.value_info_map: Dict[str, onnx.ValueInfoProto] = {
|
||||
n.name: n for n in graph_proto.value_info
|
||||
}
|
||||
self.declared_input_map: dict[str, onnx.ValueInfoProto] = {
|
||||
self.declared_input_map: Dict[str, onnx.ValueInfoProto] = {
|
||||
n.name: n for n in graph_proto.input
|
||||
}
|
||||
self.output_map: dict[str, onnx.ValueInfoProto] = {
|
||||
self.output_map: Dict[str, onnx.ValueInfoProto] = {
|
||||
n.name: n for n in graph_proto.output
|
||||
}
|
||||
|
||||
|
@ -191,7 +191,7 @@ class NodeImporter:
|
|||
self._gi = graph_info
|
||||
self._p = parent_op
|
||||
self._b = block
|
||||
self._nv_map: dict[str, Value] = {}
|
||||
self._nv_map: Dict[str, Value] = {}
|
||||
|
||||
@classmethod
|
||||
def define_function(
|
||||
|
@ -225,7 +225,7 @@ class NodeImporter:
|
|||
with container_op.context:
|
||||
i64_type = IntegerType.get_signed(64)
|
||||
default_opset_version = 0
|
||||
opset_versions: dict[str, IntegerAttr] = {}
|
||||
opset_versions: Dict[str, IntegerAttr] = {}
|
||||
for opset_import in m.opset_import:
|
||||
if opset_import.domain:
|
||||
opset_versions[opset_import.domain] = IntegerAttr.get(
|
||||
|
@ -335,7 +335,7 @@ class NodeImporter:
|
|||
for output_name, output_value in zip(output_names, custom_op.results):
|
||||
self._nv_map[output_name] = output_value
|
||||
|
||||
def import_attributes(self, onnx_attrs: list[onnx.AttributeProto]):
|
||||
def import_attributes(self, onnx_attrs: List[onnx.AttributeProto]):
|
||||
attrs = {}
|
||||
for onnx_attr in onnx_attrs:
|
||||
attr_type = onnx_attr.type
|
||||
|
@ -358,14 +358,14 @@ class NodeImporter:
|
|||
attrs[f"torch.onnx.{onnx_attr.name}"] = result
|
||||
return attrs
|
||||
|
||||
def count_regions(self, onnx_attrs: list[onnx.AttributeProto]):
|
||||
def count_regions(self, onnx_attrs: List[onnx.AttributeProto]):
|
||||
count = 0
|
||||
for onnx_attr in onnx_attrs:
|
||||
if onnx_attr.type == onnx.AttributeProto.AttributeType.GRAPH:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def import_regions(self, onnx_attrs: list[onnx.AttributeProto], op):
|
||||
def import_regions(self, onnx_attrs: List[onnx.AttributeProto], op):
|
||||
attr_map = {}
|
||||
for onnx_attr in onnx_attrs:
|
||||
attr_type = onnx_attr.type
|
||||
|
@ -458,10 +458,10 @@ class ContextCache:
|
|||
|
||||
def __init__(self, context: Context):
|
||||
self._c = context
|
||||
self._elem_type_map: dict[int, IrType] = {}
|
||||
self._list_type_map:dict[str, IrType] = {}
|
||||
self._optional_type_map:dict[str, IrType] = {}
|
||||
self._vtensor_type_map: dict[tuple[tuple[Optional[int]], IrType], IrType] = {}
|
||||
self._elem_type_map: Dict[int, IrType] = {}
|
||||
self._list_type_map:Dict[str, IrType] = {}
|
||||
self._optional_type_map:Dict[str, IrType] = {}
|
||||
self._vtensor_type_map: Dict[Tuple[Tuple[Optional[int]], IrType], IrType] = {}
|
||||
|
||||
def tensor_element_type(self, elem_type: int) -> IrType:
|
||||
t = self._elem_type_map.get(elem_type)
|
||||
|
@ -539,7 +539,7 @@ class ContextCache:
|
|||
f"Unsupport optional element type")
|
||||
|
||||
def get_vtensor_type(
|
||||
self, dims: tuple[Optional[int]], element_type: IrType
|
||||
self, dims: Tuple[Optional[int]], element_type: IrType
|
||||
) -> IrType:
|
||||
key = (dims, element_type)
|
||||
t = self._vtensor_type_map.get(key)
|
||||
|
|
1
setup.py
1
setup.py
|
@ -229,6 +229,7 @@ setup(
|
|||
"build_py": CMakeBuild,
|
||||
},
|
||||
ext_modules=EXT_MODULES,
|
||||
python_requires=">=3.8",
|
||||
install_requires=INSTALL_REQUIRES,
|
||||
extras_require={
|
||||
"onnx": [
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Tuple, Dict
|
||||
|
||||
import torch
|
||||
import torch.export
|
||||
|
@ -80,7 +80,7 @@ def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
|
|||
|
||||
|
||||
def sparse_export(
|
||||
f: Callable, args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None
|
||||
f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
|
||||
) -> torch.export.ExportedProgram:
|
||||
"""
|
||||
This is a ***temporary*** wrapper around `torch.export.export`
|
||||
|
|
Loading…
Reference in New Issue