Make a typing dependency that is not in older PyTorch backwards compatible. (#2948)

This was found in a downstream that is pegged to an older PyTorch
version.
pull/2949/head
Stella Laurenzo 2024-02-23 15:52:27 -08:00 committed by GitHub
parent ec2b80b433
commit 89e02c195b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 23 additions and 2 deletions

View File

@ -16,7 +16,18 @@ import operator
import re
from dataclasses import dataclass
from types import BuiltinMethodType, BuiltinFunctionType
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
import weakref
import numpy as np
@ -45,6 +56,16 @@ from torch.fx import (
Node,
)
try:
from torch.export.graph_signature import InputSpec as TypingInputSpec
except ModuleNotFoundError:
# PyTorch prior to 2.3 is missing certain things we use in typing
# signatures. Just make them be Any.
if not TYPE_CHECKING:
TypingInputSpec = Any
else:
raise
try:
import ml_dtypes
except ModuleNotFoundError:
@ -299,7 +320,7 @@ class InputInfo:
"""Provides additional metadata when resolving inputs."""
program: torch.export.ExportedProgram
input_spec: torch.export.graph_signature.InputSpec
input_spec: TypingInputSpec
node: Node
ir_type: IrType
mutable_producer_node_name: Optional[str] = None