diff --git a/python/torch_mlir/extras/fx_decomp_util.py b/python/torch_mlir/extras/fx_decomp_util.py index 8dddede2d..0b3da8ad2 100644 --- a/python/torch_mlir/extras/fx_decomp_util.py +++ b/python/torch_mlir/extras/fx_decomp_util.py @@ -48,9 +48,12 @@ DEFAULT_DECOMPOSITIONS = [ torch.ops.aten.triu.default, torch.ops.aten.nan_to_num.default, torch.ops.aten.unbind, - torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, torch.ops.aten.diag, ] +if hasattr(torch.ops.aten, "_scaled_dot_product_flash_attention_for_cpu"): + DEFAULT_DECOMPOSITIONS.append( + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu + ) def get_decomposition_table(): diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index c95df2504..91d81de01 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1081,6 +1081,11 @@ class ContextCache: mutable: bool = False, ): if tensor_meta is not None: + # separately handle when tensor_meta is a list. + if isinstance(val, list) and all( + isinstance(x, TorchFakeTensor) for x in val + ): + return IrType.parse("!torch.list", context=self._c) assert isinstance(tensor_meta, TensorMetadata) # Quantized tensor meta data is not preserved in our lowering, # so throw error instead of silently doing wrong thing. diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 0d9ad77d2..d26e79afb 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -4,6 +4,7 @@ # Also available under a BSD-style license. See LICENSE. from typing import Optional, Union, Dict, Tuple, Any, Callable +from packaging import version import warnings @@ -70,7 +71,11 @@ def export_and_import( if isinstance(f, ExportedProgram): prog = f else: - prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes) + # pytorch 2.1 or lower doesn't have `dyanmic_shapes` keyword argument in torch.export + if version.Version(torch.__version__) >= version.Version("2.2.0"): + prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes) + else: + prog = torch.export.export(f, args, kwargs) if decomposition_table is None: decomposition_table = get_decomposition_table() if decomposition_table: