mirror of https://github.com/llvm/torch-mlir
Fix version comparison against stable (#2209)
parent
3a1b92c463
commit
816880774b
|
@ -54,6 +54,7 @@ if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
|
|||
_dynamo_fx_importer.py
|
||||
compiler_utils.py
|
||||
dynamo.py
|
||||
_version.py
|
||||
)
|
||||
endif()
|
||||
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from packaging import version
|
||||
import torch
|
||||
|
||||
def torch_version_for_comparison():
|
||||
# Ignore +cpu, +cu117m, etc. in comparisons
|
||||
return version.parse(torch.__version__.split("+", 1)[0])
|
|
@ -4,7 +4,7 @@
|
|||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from typing import List
|
||||
from packaging import version
|
||||
from ._version import torch_version_for_comparison, version
|
||||
|
||||
import torch
|
||||
from torch._functorch.compile_utils import strip_overloads
|
||||
|
@ -66,7 +66,7 @@ def _get_decomposition_table():
|
|||
aten.squeeze,
|
||||
]
|
||||
# TODO: enable test once 2.1.0 is stable
|
||||
if version.parse(torch.__version__) > version.parse("2.0.1+cpu"):
|
||||
if torch_version_for_comparison() >= version.parse("2.1.0.dev"):
|
||||
decomp_list += [aten._native_batch_norm_legit_no_training]
|
||||
return get_decompositions(decomp_list)
|
||||
|
||||
|
|
|
@ -6,6 +6,9 @@
|
|||
# Lists of tests that fail to even reach the backends.
|
||||
# These represent further work needed in torch-mlir to lower them properly
|
||||
# to the backend contract.
|
||||
|
||||
from torch_mlir._version import torch_version_for_comparison, version
|
||||
|
||||
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
||||
"NativeGroupNormModule_basic",
|
||||
"NativeGroupNormBackwardModule_basic",
|
||||
|
@ -14,15 +17,12 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
|||
}
|
||||
|
||||
# TODO: Delete once torch 2.1.0 is released
|
||||
# check for torch version and disable tests
|
||||
TORCH_2_1_REQUIRED = {
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
"ScaledDotProductAttentionSameModule_basic"
|
||||
}
|
||||
import torch
|
||||
from packaging import version
|
||||
if not version.parse(torch.__version__) > version.parse("2.0.1+cpu"):
|
||||
COMMON_TORCH_MLIR_LOWERING_XFAILS.update(TORCH_2_1_REQUIRED)
|
||||
if torch_version_for_comparison() < version.parse("2.1.0.dev"):
|
||||
COMMON_TORCH_MLIR_LOWERING_XFAILS.update({
|
||||
"ScaledDotProductAttentionDifferentModule_basic",
|
||||
"ScaledDotProductAttentionSameModule_basic"
|
||||
})
|
||||
|
||||
|
||||
def register_all_tests():
|
||||
"""Registers all the built-in E2E tests that Torch-MLIR provides."""
|
||||
|
|
Loading…
Reference in New Issue