mirror of https://github.com/llvm/torch-mlir
[FxImporter] small fixes for fx importer compatibility issues between different pytorch versions (#3577)
parent
edc87fc577
commit
6f7a5db801
|
@ -48,9 +48,12 @@ DEFAULT_DECOMPOSITIONS = [
|
|||
torch.ops.aten.triu.default,
|
||||
torch.ops.aten.nan_to_num.default,
|
||||
torch.ops.aten.unbind,
|
||||
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
|
||||
torch.ops.aten.diag,
|
||||
]
|
||||
if hasattr(torch.ops.aten, "_scaled_dot_product_flash_attention_for_cpu"):
|
||||
DEFAULT_DECOMPOSITIONS.append(
|
||||
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu
|
||||
)
|
||||
|
||||
|
||||
def get_decomposition_table():
|
||||
|
|
|
@ -1081,6 +1081,11 @@ class ContextCache:
|
|||
mutable: bool = False,
|
||||
):
|
||||
if tensor_meta is not None:
|
||||
# separately handle when tensor_meta is a list.
|
||||
if isinstance(val, list) and all(
|
||||
isinstance(x, TorchFakeTensor) for x in val
|
||||
):
|
||||
return IrType.parse("!torch.list<vtensor>", context=self._c)
|
||||
assert isinstance(tensor_meta, TensorMetadata)
|
||||
# Quantized tensor meta data is not preserved in our lowering,
|
||||
# so throw error instead of silently doing wrong thing.
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from typing import Optional, Union, Dict, Tuple, Any, Callable
|
||||
from packaging import version
|
||||
|
||||
import warnings
|
||||
|
||||
|
@ -70,7 +71,11 @@ def export_and_import(
|
|||
if isinstance(f, ExportedProgram):
|
||||
prog = f
|
||||
else:
|
||||
prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
|
||||
# pytorch 2.1 or lower doesn't have `dyanmic_shapes` keyword argument in torch.export
|
||||
if version.Version(torch.__version__) >= version.Version("2.2.0"):
|
||||
prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
|
||||
else:
|
||||
prog = torch.export.export(f, args, kwargs)
|
||||
if decomposition_table is None:
|
||||
decomposition_table = get_decomposition_table()
|
||||
if decomposition_table:
|
||||
|
|
Loading…
Reference in New Issue