mirror of https://github.com/llvm/torch-mlir
Make a typing dependency that is not in older PyTorch backwards compatible. (#2948)
This was found in a downstream that is pegged to an older PyTorch version.pull/2949/head
parent
ec2b80b433
commit
89e02c195b
|
@ -16,7 +16,18 @@ import operator
|
|||
import re
|
||||
from dataclasses import dataclass
|
||||
from types import BuiltinMethodType, BuiltinFunctionType
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
|
@ -45,6 +56,16 @@ from torch.fx import (
|
|||
Node,
|
||||
)
|
||||
|
||||
try:
|
||||
from torch.export.graph_signature import InputSpec as TypingInputSpec
|
||||
except ModuleNotFoundError:
|
||||
# PyTorch prior to 2.3 is missing certain things we use in typing
|
||||
# signatures. Just make them be Any.
|
||||
if not TYPE_CHECKING:
|
||||
TypingInputSpec = Any
|
||||
else:
|
||||
raise
|
||||
|
||||
try:
|
||||
import ml_dtypes
|
||||
except ModuleNotFoundError:
|
||||
|
@ -299,7 +320,7 @@ class InputInfo:
|
|||
"""Provides additional metadata when resolving inputs."""
|
||||
|
||||
program: torch.export.ExportedProgram
|
||||
input_spec: torch.export.graph_signature.InputSpec
|
||||
input_spec: TypingInputSpec
|
||||
node: Node
|
||||
ir_type: IrType
|
||||
mutable_producer_node_name: Optional[str] = None
|
||||
|
|
Loading…
Reference in New Issue