mirror of https://github.com/llvm/torch-mlir
Normalize type hints to be compatible with multiple Python versions (#3028)
Although we provide a wheel package for Python 3.8, it may actually throw the following exception: `TypeError: 'type' object is not subscriptable`pull/3032/head
parent
4282eb9e76
commit
f34c187ac4
|
@ -3,7 +3,7 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
# Also available under a BSD-style license. See LICENSE.
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
from typing import List, Optional, Any, Tuple, Union
|
from typing import List, Optional, Any, Tuple, Union, Dict, Set
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
@ -1767,9 +1767,9 @@ _SORTED_TORCH_TYPES = [
|
||||||
|
|
||||||
def _check_tensors_with_the_same_dtype(
|
def _check_tensors_with_the_same_dtype(
|
||||||
num_of_tensors: Optional[int] = None,
|
num_of_tensors: Optional[int] = None,
|
||||||
tensor_shapes: Optional[list[tuple[int]]] = None,
|
tensor_shapes: Optional[List[Tuple[int]]] = None,
|
||||||
tensor_device: Optional[torch.device] = None,
|
tensor_device: Optional[torch.device] = None,
|
||||||
error_types: Optional[set[int]] = None, *args, **kwargs):
|
error_types: Optional[Set[int]] = None, *args, **kwargs):
|
||||||
"""Create invocations where all tensors have the same dtype.
|
"""Create invocations where all tensors have the same dtype.
|
||||||
|
|
||||||
This function generates invocations with `num_of_tensors` tensors
|
This function generates invocations with `num_of_tensors` tensors
|
||||||
|
@ -1801,10 +1801,10 @@ def _check_tensors_with_the_same_dtype(
|
||||||
return invocations
|
return invocations
|
||||||
|
|
||||||
def _check_two_tensor_op(
|
def _check_two_tensor_op(
|
||||||
tensor_shapes: Optional[list[tuple[int]]] = None,
|
tensor_shapes: Optional[List[Tuple[int]]] = None,
|
||||||
tensor_device: Optional[torch.device] = None,
|
tensor_device: Optional[torch.device] = None,
|
||||||
input_error_types: Optional[set[int]] = None,
|
input_error_types: Optional[Set[int]] = None,
|
||||||
output_error_types: Optional[set[int]] = None, **kwargs):
|
output_error_types: Optional[Set[int]] = None, **kwargs):
|
||||||
"""Generate invocations for basic two-tensor dtype functions.
|
"""Generate invocations for basic two-tensor dtype functions.
|
||||||
|
|
||||||
This helper function is meant to be used to check dtype functions that
|
This helper function is meant to be used to check dtype functions that
|
||||||
|
|
|
@ -276,7 +276,7 @@ class SparsityMeta:
|
||||||
batch_dim: int
|
batch_dim: int
|
||||||
sparse_dim: int
|
sparse_dim: int
|
||||||
dense_dim: int
|
dense_dim: int
|
||||||
blocksize: Optional[tuple[int, int]]
|
blocksize: Optional[Tuple[int, int]]
|
||||||
pos_dtype: torch.dtype
|
pos_dtype: torch.dtype
|
||||||
crd_dtype: torch.dtype
|
crd_dtype: torch.dtype
|
||||||
|
|
||||||
|
@ -489,8 +489,8 @@ class FxImporter:
|
||||||
default policy is to capture them as frozen values.
|
default policy is to capture them as frozen values.
|
||||||
"""
|
"""
|
||||||
# Create lookaside table of placeholders/outputs.
|
# Create lookaside table of placeholders/outputs.
|
||||||
placeholder_nodes: dict[str, Node] = {}
|
placeholder_nodes: Dict[str, Node] = {}
|
||||||
all_producer_nodes: dict[str, Node] = {}
|
all_producer_nodes: Dict[str, Node] = {}
|
||||||
loc: Optional[Location] = None
|
loc: Optional[Location] = None
|
||||||
for node in prog.graph.nodes:
|
for node in prog.graph.nodes:
|
||||||
if loc is None:
|
if loc is None:
|
||||||
|
@ -522,15 +522,15 @@ class FxImporter:
|
||||||
}
|
}
|
||||||
|
|
||||||
# Additional bindings that we need to set up after the function is created.
|
# Additional bindings that we need to set up after the function is created.
|
||||||
mutable_buffer_target_producers: dict[str, str] = {}
|
mutable_buffer_target_producers: Dict[str, str] = {}
|
||||||
constant_tensors: dict[Node, torch.Tensor] = {}
|
constant_tensors: Dict[Node, torch.Tensor] = {}
|
||||||
parameter_bindings: dict[Node, tuple[Any, InputInfo]] = {}
|
parameter_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
|
||||||
buffer_bindings: dict[Node, tuple[Any, InputInfo]] = {}
|
buffer_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
|
||||||
|
|
||||||
# Derive user outputs that we preserve. These will be nodes of the
|
# Derive user outputs that we preserve. These will be nodes of the
|
||||||
# producer for the output.
|
# producer for the output.
|
||||||
user_outputs: list[Node] = []
|
user_outputs: List[Node] = []
|
||||||
user_output_types: list[IrType] = []
|
user_output_types: List[IrType] = []
|
||||||
for output_spec in sig.output_specs:
|
for output_spec in sig.output_specs:
|
||||||
kind = output_spec.kind
|
kind = output_spec.kind
|
||||||
arg = output_spec.arg
|
arg = output_spec.arg
|
||||||
|
@ -548,8 +548,8 @@ class FxImporter:
|
||||||
mutable_buffer_target_producers[output_spec.target] = arg.name
|
mutable_buffer_target_producers[output_spec.target] = arg.name
|
||||||
|
|
||||||
# Derive user inputs. These will be op=='placeholder' nodes.
|
# Derive user inputs. These will be op=='placeholder' nodes.
|
||||||
user_inputs: list[Node] = []
|
user_inputs: List[Node] = []
|
||||||
user_input_types: list[IrType] = []
|
user_input_types: List[IrType] = []
|
||||||
for input_spec in sig.input_specs:
|
for input_spec in sig.input_specs:
|
||||||
arg = input_spec.arg
|
arg = input_spec.arg
|
||||||
if input_spec.kind == InputKind.USER_INPUT:
|
if input_spec.kind == InputKind.USER_INPUT:
|
||||||
|
@ -700,7 +700,7 @@ class FxImporter:
|
||||||
"""
|
"""
|
||||||
sig = prog.graph_signature
|
sig = prog.graph_signature
|
||||||
state_dict = prog.state_dict
|
state_dict = prog.state_dict
|
||||||
arg_replacements: dict[str, Any] = {}
|
arg_replacements: Dict[str, Any] = {}
|
||||||
|
|
||||||
# If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
|
# If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
|
||||||
# at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
|
# at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
|
||||||
|
@ -1003,7 +1003,7 @@ class GraphNodeImporter:
|
||||||
# constructs and returns a value.
|
# constructs and returns a value.
|
||||||
self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {}
|
self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {}
|
||||||
# Map of node name to hook that should be called when it is produced.
|
# Map of node name to hook that should be called when it is produced.
|
||||||
self._on_node_produced: dict[str, Callable[[Value], None]] = {}
|
self._on_node_produced: Dict[str, Callable[[Value], None]] = {}
|
||||||
# Statically multi-result nodes which we have de-tupled are noted here.
|
# Statically multi-result nodes which we have de-tupled are noted here.
|
||||||
# They will have their getitem calls short-circuited.
|
# They will have their getitem calls short-circuited.
|
||||||
self._multi_result_nodes: Set[torch_fx.Node] = set()
|
self._multi_result_nodes: Set[torch_fx.Node] = set()
|
||||||
|
@ -1118,7 +1118,7 @@ class GraphNodeImporter:
|
||||||
|
|
||||||
self._on_node_produced[info.mutable_producer_node_name] = on_produced
|
self._on_node_produced[info.mutable_producer_node_name] = on_produced
|
||||||
|
|
||||||
def return_node_values(self, loc, nodes: list[Node]):
|
def return_node_values(self, loc, nodes: List[Node]):
|
||||||
with loc, InsertionPoint(self._b):
|
with loc, InsertionPoint(self._b):
|
||||||
operands = [self.resolve_node_value(n) for n in nodes]
|
operands = [self.resolve_node_value(n) for n in nodes]
|
||||||
func_dialect.ReturnOp(operands, loc=loc)
|
func_dialect.ReturnOp(operands, loc=loc)
|
||||||
|
|
|
@ -33,7 +33,7 @@ except ModuleNotFoundError as e:
|
||||||
"The onnx package (`pip install onnx`) is required to use the onnx importer"
|
"The onnx package (`pip install onnx`) is required to use the onnx importer"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, List, Dict, Tuple
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
@ -113,16 +113,16 @@ class GraphInfo:
|
||||||
def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto):
|
def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto):
|
||||||
self.model_info = model_info
|
self.model_info = model_info
|
||||||
self.graph_proto = graph_proto
|
self.graph_proto = graph_proto
|
||||||
self.initializer_map: dict[str, onnx.TensorProto] = {
|
self.initializer_map: Dict[str, onnx.TensorProto] = {
|
||||||
n.name: n for n in graph_proto.initializer
|
n.name: n for n in graph_proto.initializer
|
||||||
}
|
}
|
||||||
self.value_info_map: dict[str, onnx.ValueInfoProto] = {
|
self.value_info_map: Dict[str, onnx.ValueInfoProto] = {
|
||||||
n.name: n for n in graph_proto.value_info
|
n.name: n for n in graph_proto.value_info
|
||||||
}
|
}
|
||||||
self.declared_input_map: dict[str, onnx.ValueInfoProto] = {
|
self.declared_input_map: Dict[str, onnx.ValueInfoProto] = {
|
||||||
n.name: n for n in graph_proto.input
|
n.name: n for n in graph_proto.input
|
||||||
}
|
}
|
||||||
self.output_map: dict[str, onnx.ValueInfoProto] = {
|
self.output_map: Dict[str, onnx.ValueInfoProto] = {
|
||||||
n.name: n for n in graph_proto.output
|
n.name: n for n in graph_proto.output
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,7 +191,7 @@ class NodeImporter:
|
||||||
self._gi = graph_info
|
self._gi = graph_info
|
||||||
self._p = parent_op
|
self._p = parent_op
|
||||||
self._b = block
|
self._b = block
|
||||||
self._nv_map: dict[str, Value] = {}
|
self._nv_map: Dict[str, Value] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_function(
|
def define_function(
|
||||||
|
@ -225,7 +225,7 @@ class NodeImporter:
|
||||||
with container_op.context:
|
with container_op.context:
|
||||||
i64_type = IntegerType.get_signed(64)
|
i64_type = IntegerType.get_signed(64)
|
||||||
default_opset_version = 0
|
default_opset_version = 0
|
||||||
opset_versions: dict[str, IntegerAttr] = {}
|
opset_versions: Dict[str, IntegerAttr] = {}
|
||||||
for opset_import in m.opset_import:
|
for opset_import in m.opset_import:
|
||||||
if opset_import.domain:
|
if opset_import.domain:
|
||||||
opset_versions[opset_import.domain] = IntegerAttr.get(
|
opset_versions[opset_import.domain] = IntegerAttr.get(
|
||||||
|
@ -335,7 +335,7 @@ class NodeImporter:
|
||||||
for output_name, output_value in zip(output_names, custom_op.results):
|
for output_name, output_value in zip(output_names, custom_op.results):
|
||||||
self._nv_map[output_name] = output_value
|
self._nv_map[output_name] = output_value
|
||||||
|
|
||||||
def import_attributes(self, onnx_attrs: list[onnx.AttributeProto]):
|
def import_attributes(self, onnx_attrs: List[onnx.AttributeProto]):
|
||||||
attrs = {}
|
attrs = {}
|
||||||
for onnx_attr in onnx_attrs:
|
for onnx_attr in onnx_attrs:
|
||||||
attr_type = onnx_attr.type
|
attr_type = onnx_attr.type
|
||||||
|
@ -358,14 +358,14 @@ class NodeImporter:
|
||||||
attrs[f"torch.onnx.{onnx_attr.name}"] = result
|
attrs[f"torch.onnx.{onnx_attr.name}"] = result
|
||||||
return attrs
|
return attrs
|
||||||
|
|
||||||
def count_regions(self, onnx_attrs: list[onnx.AttributeProto]):
|
def count_regions(self, onnx_attrs: List[onnx.AttributeProto]):
|
||||||
count = 0
|
count = 0
|
||||||
for onnx_attr in onnx_attrs:
|
for onnx_attr in onnx_attrs:
|
||||||
if onnx_attr.type == onnx.AttributeProto.AttributeType.GRAPH:
|
if onnx_attr.type == onnx.AttributeProto.AttributeType.GRAPH:
|
||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
def import_regions(self, onnx_attrs: list[onnx.AttributeProto], op):
|
def import_regions(self, onnx_attrs: List[onnx.AttributeProto], op):
|
||||||
attr_map = {}
|
attr_map = {}
|
||||||
for onnx_attr in onnx_attrs:
|
for onnx_attr in onnx_attrs:
|
||||||
attr_type = onnx_attr.type
|
attr_type = onnx_attr.type
|
||||||
|
@ -458,10 +458,10 @@ class ContextCache:
|
||||||
|
|
||||||
def __init__(self, context: Context):
|
def __init__(self, context: Context):
|
||||||
self._c = context
|
self._c = context
|
||||||
self._elem_type_map: dict[int, IrType] = {}
|
self._elem_type_map: Dict[int, IrType] = {}
|
||||||
self._list_type_map:dict[str, IrType] = {}
|
self._list_type_map:Dict[str, IrType] = {}
|
||||||
self._optional_type_map:dict[str, IrType] = {}
|
self._optional_type_map:Dict[str, IrType] = {}
|
||||||
self._vtensor_type_map: dict[tuple[tuple[Optional[int]], IrType], IrType] = {}
|
self._vtensor_type_map: Dict[Tuple[Tuple[Optional[int]], IrType], IrType] = {}
|
||||||
|
|
||||||
def tensor_element_type(self, elem_type: int) -> IrType:
|
def tensor_element_type(self, elem_type: int) -> IrType:
|
||||||
t = self._elem_type_map.get(elem_type)
|
t = self._elem_type_map.get(elem_type)
|
||||||
|
@ -539,7 +539,7 @@ class ContextCache:
|
||||||
f"Unsupport optional element type")
|
f"Unsupport optional element type")
|
||||||
|
|
||||||
def get_vtensor_type(
|
def get_vtensor_type(
|
||||||
self, dims: tuple[Optional[int]], element_type: IrType
|
self, dims: Tuple[Optional[int]], element_type: IrType
|
||||||
) -> IrType:
|
) -> IrType:
|
||||||
key = (dims, element_type)
|
key = (dims, element_type)
|
||||||
t = self._vtensor_type_map.get(key)
|
t = self._vtensor_type_map.get(key)
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -229,6 +229,7 @@ setup(
|
||||||
"build_py": CMakeBuild,
|
"build_py": CMakeBuild,
|
||||||
},
|
},
|
||||||
ext_modules=EXT_MODULES,
|
ext_modules=EXT_MODULES,
|
||||||
|
python_requires=">=3.8",
|
||||||
install_requires=INSTALL_REQUIRES,
|
install_requires=INSTALL_REQUIRES,
|
||||||
extras_require={
|
extras_require={
|
||||||
"onnx": [
|
"onnx": [
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
# RUN: %PYTHON %s | FileCheck %s
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional, Tuple, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.export
|
import torch.export
|
||||||
|
@ -80,7 +80,7 @@ def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
|
||||||
|
|
||||||
|
|
||||||
def sparse_export(
|
def sparse_export(
|
||||||
f: Callable, args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None
|
f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
|
||||||
) -> torch.export.ExportedProgram:
|
) -> torch.export.ExportedProgram:
|
||||||
"""
|
"""
|
||||||
This is a ***temporary*** wrapper around `torch.export.export`
|
This is a ***temporary*** wrapper around `torch.export.export`
|
||||||
|
|
Loading…
Reference in New Issue