mirror of https://github.com/llvm/torch-mlir
Add per-test timeouts to catch infinite loops (#3650)
Previously we only had full suite timeouts, making it impossible to identify which specific tests were hanging. This patch adds: 1. Per-test timeout support in the test framework 2. A default 600s timeout for all tests 3. A deliberately slow test to verify the timeout mechanism works The timeout is implemented using Python's signal module. Tests that exceed their timeout are marked as failures with an appropriate error message. This should help catch and isolate problematic tests that enter infinite loops, without needing to re-run the entire suite multiple times.pull/3659/head
parent
7f886cc270
commit
4358aaccd6
|
@ -372,6 +372,7 @@ TORCHDYNAMO_CRASHING_SET = {
|
|||
}
|
||||
|
||||
FX_IMPORTER_XFAIL_SET = {
|
||||
"TimeOutModule_basic", # this test is expected to time out
|
||||
"ReduceAnyDimFloatModule_basic",
|
||||
"AddFloatIntModule_basic",
|
||||
"AllBoolFalseModule_basic",
|
||||
|
@ -2302,6 +2303,8 @@ LTC_XFAIL_SET = {
|
|||
}
|
||||
|
||||
ONNX_XFAIL_SET = {
|
||||
# This test is expected to time out
|
||||
"TimeOutModule_basic",
|
||||
# Failure - cast error
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
# Failure - incorrect numerics
|
||||
|
|
|
@ -27,6 +27,7 @@ from itertools import repeat
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import signal
|
||||
|
||||
import multiprocess as mp
|
||||
from multiprocess import set_start_method
|
||||
|
@ -230,6 +231,7 @@ class Test(NamedTuple):
|
|||
# module, actually).
|
||||
# The secon parameter is a `TestUtils` instance for convenience.
|
||||
program_invoker: Callable[[Any, TestUtils], None]
|
||||
timeout_seconds: int
|
||||
|
||||
|
||||
class TestResult(NamedTuple):
|
||||
|
@ -305,12 +307,37 @@ def generate_golden_trace(test: Test) -> Trace:
|
|||
return trace
|
||||
|
||||
|
||||
class timeout:
|
||||
def __init__(self, seconds=1, error_message="Timeout"):
|
||||
self.seconds = seconds
|
||||
self.error_message = error_message
|
||||
|
||||
def handle_timeout(self, signum, frame):
|
||||
raise TimeoutError(self.error_message)
|
||||
|
||||
def __enter__(self):
|
||||
signal.signal(signal.SIGALRM, self.handle_timeout)
|
||||
signal.alarm(self.seconds)
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
signal.alarm(0)
|
||||
|
||||
|
||||
def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any:
|
||||
with timeout(seconds=test.timeout_seconds):
|
||||
try:
|
||||
golden_trace = generate_golden_trace(test)
|
||||
if verbose:
|
||||
print(f"Compiling {test.unique_name}...", file=sys.stderr)
|
||||
compiled = config.compile(test.program_factory(), verbose=verbose)
|
||||
except TimeoutError:
|
||||
return TestResult(
|
||||
unique_name=test.unique_name,
|
||||
compilation_error=f"Test timed out during compilation (timeout={test.timeout_seconds}s)",
|
||||
runtime_error=None,
|
||||
trace=None,
|
||||
golden_trace=None,
|
||||
)
|
||||
except Exception as e:
|
||||
return TestResult(
|
||||
unique_name=test.unique_name,
|
||||
|
@ -325,6 +352,17 @@ def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any:
|
|||
if verbose:
|
||||
print(f"Running {test.unique_name}...", file=sys.stderr)
|
||||
trace = config.run(compiled, golden_trace)
|
||||
|
||||
# Disable the alarm
|
||||
signal.alarm(0)
|
||||
except TimeoutError:
|
||||
return TestResult(
|
||||
unique_name=test.unique_name,
|
||||
compilation_error=None,
|
||||
runtime_error="Test timed out during execution (timeout={test.timeout}s)",
|
||||
trace=None,
|
||||
golden_trace=None,
|
||||
)
|
||||
except Exception as e:
|
||||
return TestResult(
|
||||
unique_name=test.unique_name,
|
||||
|
|
|
@ -15,7 +15,9 @@ GLOBAL_TEST_REGISTRY = []
|
|||
_SEEN_UNIQUE_NAMES = set()
|
||||
|
||||
|
||||
def register_test_case(module_factory: Callable[[], torch.nn.Module]):
|
||||
def register_test_case(
|
||||
module_factory: Callable[[], torch.nn.Module], timeout_seconds: int = 120
|
||||
):
|
||||
"""Convenient decorator-based test registration.
|
||||
|
||||
Adds a `framework.Test` to the global test registry based on the decorated
|
||||
|
@ -38,6 +40,7 @@ def register_test_case(module_factory: Callable[[], torch.nn.Module]):
|
|||
unique_name=f.__name__,
|
||||
program_factory=module_factory,
|
||||
program_invoker=f,
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
)
|
||||
return f
|
||||
|
|
|
@ -17,6 +17,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
|||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||
"TimeOutModule_basic", # This test is expected to time out
|
||||
}
|
||||
|
||||
|
||||
|
@ -60,3 +61,4 @@ def register_all_tests():
|
|||
from . import diagonal
|
||||
from . import gridsampler
|
||||
from . import meshgrid
|
||||
from . import timeout
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.framework import TestUtils
|
||||
from torch_mlir_e2e_test.registry import register_test_case
|
||||
from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
class TimeOutModule(torch.nn.Module):
|
||||
"""
|
||||
This test ensures that the timeout mechanism works as expected.
|
||||
|
||||
The module runs an infinite loop that will never terminate,
|
||||
and the test is expected to time out and get terminated
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-1, -1], torch.int64, True)])
|
||||
def forward(self, x):
|
||||
"""
|
||||
Run an infinite loop.
|
||||
|
||||
This may loop in the compiler or the runtime depending on whether
|
||||
fx or torchscript is used.
|
||||
"""
|
||||
# input_arg_2 is going to be 2
|
||||
# but we can't just specify it as a
|
||||
# constant because the compiler will
|
||||
# attempt to get rid of the whole loop
|
||||
input_arg_2 = x.size(0)
|
||||
sum = 100
|
||||
while input_arg_2 < sum: # sum will always > 2
|
||||
sum += 1
|
||||
return sum
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TimeOutModule(), timeout_seconds=10)
|
||||
def TimeOutModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.ones((42, 42)))
|
Loading…
Reference in New Issue