[FxImporter] Synchronize the collection of symbolic torch ops (#3236)

pull/3256/head
penguin_wwy 2024-04-29 10:09:00 +08:00 committed by GitHub
parent 5684dc0441
commit 9f64748f97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 14 deletions

View File

@ -236,12 +236,6 @@ _IS_TORCH_2_1_OR_EARLIER = torch.__version__.split("+")[0] <= "2.1.0"
# set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP
if _IS_TORCH_2_1_OR_EARLIER:
SYMBOLIC_TORCH_OPS = {
torch.ops.aten.sym_size,
torch.ops.aten.sym_stride,
torch.ops.aten.sym_numel,
}
SYMBOLIC_OP_TO_TORCH_OP = {
(torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
(torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,
@ -249,13 +243,9 @@ if _IS_TORCH_2_1_OR_EARLIER:
(torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
(torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
}
else:
SYMBOLIC_TORCH_OPS = {
torch.ops.aten.sym_size.int,
torch.ops.aten.sym_stride.int,
torch.ops.aten.sym_numel.default,
}
SYMBOLIC_TORCH_OPS = {key[0] for key in SYMBOLIC_OP_TO_TORCH_OP}
else:
SYMBOLIC_OP_TO_TORCH_OP = {
torch.ops.aten.sym_size.default: torch.ops.aten.size.default,
torch.ops.aten.sym_size.int: torch.ops.aten.size.int,
@ -264,6 +254,8 @@ else:
torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default,
}
SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP}
@dataclass(frozen=True)
class SparsityMeta:

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from typing import Optional, Union, Dict, Tuple, Any
from typing import Optional, Union, Dict, Tuple, Any, Callable
import warnings
@ -25,7 +25,7 @@ def export_and_import(
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
experimental_support_mutation: bool = False,
hooks: Optional[FxImporterHooks] = None,
decomposition_table: Optional[list] = None,
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
func_name: str = "main",
enable_graph_printing: bool = False,
**kwargs,