mirror of https://github.com/llvm/torch-mlir
[FxImporter] Synchronize the collection of symbolic torch ops (#3236)
parent
5684dc0441
commit
9f64748f97
|
@ -236,12 +236,6 @@ _IS_TORCH_2_1_OR_EARLIER = torch.__version__.split("+")[0] <= "2.1.0"
|
|||
# set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP
|
||||
|
||||
if _IS_TORCH_2_1_OR_EARLIER:
|
||||
SYMBOLIC_TORCH_OPS = {
|
||||
torch.ops.aten.sym_size,
|
||||
torch.ops.aten.sym_stride,
|
||||
torch.ops.aten.sym_numel,
|
||||
}
|
||||
|
||||
SYMBOLIC_OP_TO_TORCH_OP = {
|
||||
(torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
|
||||
(torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,
|
||||
|
@ -249,13 +243,9 @@ if _IS_TORCH_2_1_OR_EARLIER:
|
|||
(torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
|
||||
(torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
|
||||
}
|
||||
else:
|
||||
SYMBOLIC_TORCH_OPS = {
|
||||
torch.ops.aten.sym_size.int,
|
||||
torch.ops.aten.sym_stride.int,
|
||||
torch.ops.aten.sym_numel.default,
|
||||
}
|
||||
|
||||
SYMBOLIC_TORCH_OPS = {key[0] for key in SYMBOLIC_OP_TO_TORCH_OP}
|
||||
else:
|
||||
SYMBOLIC_OP_TO_TORCH_OP = {
|
||||
torch.ops.aten.sym_size.default: torch.ops.aten.size.default,
|
||||
torch.ops.aten.sym_size.int: torch.ops.aten.size.int,
|
||||
|
@ -264,6 +254,8 @@ else:
|
|||
torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default,
|
||||
}
|
||||
|
||||
SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SparsityMeta:
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from typing import Optional, Union, Dict, Tuple, Any
|
||||
from typing import Optional, Union, Dict, Tuple, Any, Callable
|
||||
|
||||
import warnings
|
||||
|
||||
|
@ -25,7 +25,7 @@ def export_and_import(
|
|||
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
||||
experimental_support_mutation: bool = False,
|
||||
hooks: Optional[FxImporterHooks] = None,
|
||||
decomposition_table: Optional[list] = None,
|
||||
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
|
||||
func_name: str = "main",
|
||||
enable_graph_printing: bool = False,
|
||||
**kwargs,
|
||||
|
|
Loading…
Reference in New Issue