Fix version comparison against stable (#2209)

pull/2198/head
Matthias Gehre 2023-06-07 10:19:38 +02:00 committed by GitHub
parent 3a1b92c463
commit 816880774b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 11 deletions

View File

@ -54,6 +54,7 @@ if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
_dynamo_fx_importer.py
compiler_utils.py
dynamo.py
_version.py
)
endif()

View File

@ -0,0 +1,11 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from packaging import version
import torch
def torch_version_for_comparison():
# Ignore +cpu, +cu117m, etc. in comparisons
return version.parse(torch.__version__.split("+", 1)[0])

View File

@ -4,7 +4,7 @@
# Also available under a BSD-style license. See LICENSE.
from typing import List
from packaging import version
from ._version import torch_version_for_comparison, version
import torch
from torch._functorch.compile_utils import strip_overloads
@ -66,7 +66,7 @@ def _get_decomposition_table():
aten.squeeze,
]
# TODO: enable test once 2.1.0 is stable
if version.parse(torch.__version__) > version.parse("2.0.1+cpu"):
if torch_version_for_comparison() >= version.parse("2.1.0.dev"):
decomp_list += [aten._native_batch_norm_legit_no_training]
return get_decompositions(decomp_list)

View File

@ -6,6 +6,9 @@
# Lists of tests that fail to even reach the backends.
# These represent further work needed in torch-mlir to lower them properly
# to the backend contract.
from torch_mlir._version import torch_version_for_comparison, version
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"NativeGroupNormModule_basic",
"NativeGroupNormBackwardModule_basic",
@ -14,15 +17,12 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
}
# TODO: Delete once torch 2.1.0 is released
# check for torch version and disable tests
TORCH_2_1_REQUIRED = {
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionSameModule_basic"
}
import torch
from packaging import version
if not version.parse(torch.__version__) > version.parse("2.0.1+cpu"):
COMMON_TORCH_MLIR_LOWERING_XFAILS.update(TORCH_2_1_REQUIRED)
if torch_version_for_comparison() < version.parse("2.1.0.dev"):
COMMON_TORCH_MLIR_LOWERING_XFAILS.update({
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionSameModule_basic"
})
def register_all_tests():
"""Registers all the built-in E2E tests that Torch-MLIR provides."""