Add per-test timeouts to catch infinite loops (#3650)

Previously we only had full-suite timeouts, making it impossible to
identify which specific tests were hanging. This patch adds:

1. Per-test timeout support in the test framework
2. A default 120s timeout for all tests (the `timeout_seconds` default of `register_test_case`)
3. A deliberately slow test to verify the timeout mechanism works

The timeout is implemented using Python's signal module (SIGALRM). Tests
that exceed their timeout are marked as failures with an appropriate
error message.

This should help catch and isolate problematic tests that enter infinite
loops, without needing to re-run the entire suite multiple times.
pull/3659/head
Xida Ren (Cedar) 2024-08-21 14:37:31 -04:00 committed by GitHub
parent 7f886cc270
commit 4358aaccd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 126 additions and 33 deletions

View File

@ -372,6 +372,7 @@ TORCHDYNAMO_CRASHING_SET = {
} }
FX_IMPORTER_XFAIL_SET = { FX_IMPORTER_XFAIL_SET = {
"TimeOutModule_basic", # this test is expected to time out
"ReduceAnyDimFloatModule_basic", "ReduceAnyDimFloatModule_basic",
"AddFloatIntModule_basic", "AddFloatIntModule_basic",
"AllBoolFalseModule_basic", "AllBoolFalseModule_basic",
@ -2302,6 +2303,8 @@ LTC_XFAIL_SET = {
} }
ONNX_XFAIL_SET = { ONNX_XFAIL_SET = {
# This test is expected to time out
"TimeOutModule_basic",
# Failure - cast error # Failure - cast error
"PermuteNegativeIndexModule_basic", "PermuteNegativeIndexModule_basic",
# Failure - incorrect numerics # Failure - incorrect numerics

View File

@ -27,6 +27,7 @@ from itertools import repeat
import os import os
import sys import sys
import traceback import traceback
import signal
import multiprocess as mp import multiprocess as mp
from multiprocess import set_start_method from multiprocess import set_start_method
@ -230,6 +231,7 @@ class Test(NamedTuple):
# module, actually). # module, actually).
# The second parameter is a `TestUtils` instance for convenience. # The second parameter is a `TestUtils` instance for convenience.
program_invoker: Callable[[Any, TestUtils], None] program_invoker: Callable[[Any, TestUtils], None]
timeout_seconds: int
class TestResult(NamedTuple): class TestResult(NamedTuple):
@ -305,43 +307,79 @@ def generate_golden_trace(test: Test) -> Trace:
return trace return trace
class timeout:
    """Context manager that raises ``TimeoutError`` when its body runs too long.

    On entry it installs ``handle_timeout`` as the ``SIGALRM`` handler and arms
    ``signal.alarm`` for ``seconds`` seconds. On exit it cancels the alarm and
    reinstates the handler that was active before entry, so the process-wide
    signal state does not leak out of the ``with`` block.

    ``signal.alarm`` / ``SIGALRM`` are only available on Unix, and
    ``signal.signal`` may only be called from the main thread.

    Args:
        seconds: Whole number of seconds to wait before the alarm fires.
        error_message: Message carried by the raised ``TimeoutError``.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message
        # Handler that was installed before __enter__; restored in __exit__.
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """Signal handler: abort the timed block by raising ``TimeoutError``."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # signal.signal returns the handler it replaces; remember it so that
        # __exit__ can put it back.
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Cancel the pending alarm first so it cannot fire once the previous
        # handler is back in place.
        signal.alarm(0)
        # getsignal/signal report None for handlers that were not installed
        # from Python; those cannot be passed back to signal.signal.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
            self._previous_handler = None
def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any:
    """Compile and run one test under its per-test timeout.

    Golden-trace generation, compilation and execution all run inside
    ``timeout(seconds=test.timeout_seconds)``. Any failure is captured in the
    returned ``TestResult`` instead of propagating to the caller.

    Args:
        test: The test to execute; ``test.timeout_seconds`` bounds the run.
        config: Provides ``compile`` and ``run`` for the backend under test.
        verbose: When true, progress is printed to stderr and forwarded to
            ``config.compile``.

    Returns:
        A ``TestResult`` with exactly one of the following populated:
        ``compilation_error`` (golden-trace generation or compilation failed
        or timed out), ``runtime_error`` (execution failed or timed out), or
        ``trace``/``golden_trace`` (both phases succeeded).
    """
    with timeout(seconds=test.timeout_seconds):
        # Phase 1: golden trace + compilation.
        try:
            golden_trace = generate_golden_trace(test)
            if verbose:
                print(f"Compiling {test.unique_name}...", file=sys.stderr)
            compiled = config.compile(test.program_factory(), verbose=verbose)
        except TimeoutError:
            return TestResult(
                unique_name=test.unique_name,
                compilation_error=f"Test timed out during compilation (timeout={test.timeout_seconds}s)",
                runtime_error=None,
                trace=None,
                golden_trace=None,
            )
        except Exception as e:
            return TestResult(
                unique_name=test.unique_name,
                compilation_error="".join(
                    traceback.format_exception(type(e), e, e.__traceback__)
                ),
                runtime_error=None,
                trace=None,
                golden_trace=None,
            )

        # Phase 2: execution against the golden trace.
        try:
            if verbose:
                print(f"Running {test.unique_name}...", file=sys.stderr)
            trace = config.run(compiled, golden_trace)
            # Disable the alarm
            signal.alarm(0)
        except TimeoutError:
            return TestResult(
                unique_name=test.unique_name,
                compilation_error=None,
                # Must be an f-string naming `timeout_seconds`; the Test tuple
                # has no `timeout` field.
                runtime_error=f"Test timed out during execution (timeout={test.timeout_seconds}s)",
                trace=None,
                golden_trace=None,
            )
        except Exception as e:
            return TestResult(
                unique_name=test.unique_name,
                compilation_error=None,
                runtime_error="".join(
                    traceback.format_exception(type(e), e, e.__traceback__)
                ),
                trace=None,
                golden_trace=None,
            )

        return TestResult(
            unique_name=test.unique_name,
            compilation_error=None,
            runtime_error=None,
            trace=clone_trace(trace),
            golden_trace=clone_trace(golden_trace),
        )
def run_tests( def run_tests(

View File

@ -15,7 +15,9 @@ GLOBAL_TEST_REGISTRY = []
_SEEN_UNIQUE_NAMES = set() _SEEN_UNIQUE_NAMES = set()
def register_test_case(module_factory: Callable[[], torch.nn.Module]): def register_test_case(
module_factory: Callable[[], torch.nn.Module], timeout_seconds: int = 120
):
"""Convenient decorator-based test registration. """Convenient decorator-based test registration.
Adds a `framework.Test` to the global test registry based on the decorated Adds a `framework.Test` to the global test registry based on the decorated
@ -38,6 +40,7 @@ def register_test_case(module_factory: Callable[[], torch.nn.Module]):
unique_name=f.__name__, unique_name=f.__name__,
program_factory=module_factory, program_factory=module_factory,
program_invoker=f, program_invoker=f,
timeout_seconds=timeout_seconds,
) )
) )
return f return f

View File

@ -17,6 +17,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"ReduceMaxAlongDimUnsignedInt_basic", "ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic",
"ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseToDtypeI64ToUI8Module_basic",
"TimeOutModule_basic", # This test is expected to time out
} }
@ -60,3 +61,4 @@ def register_all_tests():
from . import diagonal from . import diagonal
from . import gridsampler from . import gridsampler
from . import meshgrid from . import meshgrid
from . import timeout

View File

@ -0,0 +1,47 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class TimeOutModule(torch.nn.Module):
    """
    Deliberately non-terminating module used to exercise the timeout mechanism.

    `forward` spins in a loop that never exits, so a test built on this
    module is expected to hit its timeout and be terminated.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1, -1], torch.int64, True)])
    def forward(self, x):
        """
        Loop without end.

        The hang may occur in the compiler or in the runtime, depending on
        whether fx or torchscript is used.
        """
        # The bound is read from a runtime dimension instead of being written
        # as a literal: with a constant, the compiler would attempt to get
        # rid of the whole loop.
        lower_bound = x.size(0)
        counter = 100
        # counter only ever grows, so once the condition holds it holds forever.
        while lower_bound < counter:
            counter += 1
        return counter
@register_test_case(module_factory=lambda: TimeOutModule(), timeout_seconds=10)
def TimeOutModule_basic(module, tu: TestUtils):
    # Registered with a 10 second budget; the module's forward never returns,
    # so this test is expected to time out.
    # The input is created as int64 to agree with the dtype declared in
    # forward's annotate_args (the default for torch.ones is floating point).
    module.forward(torch.ones((42, 42), dtype=torch.int64))