[FxImporter] small fixes for fx importer compatibility issues between different pytorch versions (#3577)

pull/3516/head
Jiawei Wu 2024-08-01 10:52:41 +08:00 committed by GitHub
parent edc87fc577
commit 6f7a5db801
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 2 deletions

View File

@ -48,9 +48,12 @@ DEFAULT_DECOMPOSITIONS = [
torch.ops.aten.triu.default,
torch.ops.aten.nan_to_num.default,
torch.ops.aten.unbind,
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
torch.ops.aten.diag,
]
if hasattr(torch.ops.aten, "_scaled_dot_product_flash_attention_for_cpu"):
DEFAULT_DECOMPOSITIONS.append(
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu
)
def get_decomposition_table():

View File

@ -1081,6 +1081,11 @@ class ContextCache:
mutable: bool = False,
):
if tensor_meta is not None:
# separately handle when tensor_meta is a list.
if isinstance(val, list) and all(
isinstance(x, TorchFakeTensor) for x in val
):
return IrType.parse("!torch.list<vtensor>", context=self._c)
assert isinstance(tensor_meta, TensorMetadata)
# Quantized tensor meta data is not preserved in our lowering,
# so throw error instead of silently doing wrong thing.

View File

@ -4,6 +4,7 @@
# Also available under a BSD-style license. See LICENSE.
from typing import Optional, Union, Dict, Tuple, Any, Callable
from packaging import version
import warnings
@ -70,7 +71,11 @@ def export_and_import(
if isinstance(f, ExportedProgram):
prog = f
else:
prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
# pytorch 2.1 or lower doesn't have `dyanmic_shapes` keyword argument in torch.export
if version.Version(torch.__version__) >= version.Version("2.2.0"):
prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
else:
prog = torch.export.export(f, args, kwargs)
if decomposition_table is None:
decomposition_table = get_decomposition_table()
if decomposition_table: