diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index aea76c621..91f3c27ee 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -16,7 +16,18 @@ import operator import re from dataclasses import dataclass from types import BuiltinMethodType, BuiltinFunctionType -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, + Union, +) import weakref import numpy as np @@ -45,6 +56,16 @@ from torch.fx import ( Node, ) +try: + from torch.export.graph_signature import InputSpec as TypingInputSpec +except ModuleNotFoundError: + # PyTorch prior to 2.3 is missing certain things we use in typing + # signatures. Just make them be Any. + if not TYPE_CHECKING: + TypingInputSpec = Any + else: + raise + try: import ml_dtypes except ModuleNotFoundError: @@ -299,7 +320,7 @@ class InputInfo: """Provides additional metadata when resolving inputs.""" program: torch.export.ExportedProgram - input_spec: torch.export.graph_signature.InputSpec + input_spec: TypingInputSpec node: Node ir_type: IrType mutable_producer_node_name: Optional[str] = None