From 4358aaccd69939379b2dc875ba7beebef166cc64 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Wed, 21 Aug 2024 14:37:31 -0400 Subject: [PATCH] Add per-test timeouts to catch infinite loops (#3650) Previously we only had full suite timeouts, making it impossible to identify which specific tests were hanging. This patch adds: 1. Per-test timeout support in the test framework 2. A default 600s timeout for all tests 3. A deliberately slow test to verify the timeout mechanism works The timeout is implemented using Python's signal module. Tests that exceed their timeout are marked as failures with an appropriate error message. This should help catch and isolate problematic tests that enter infinite loops, without needing to re-run the entire suite multiple times. --- projects/pt1/e2e_testing/xfail_sets.py | 3 + .../python/torch_mlir_e2e_test/framework.py | 102 ++++++++++++------ .../python/torch_mlir_e2e_test/registry.py | 5 +- .../test_suite/__init__.py | 2 + .../torch_mlir_e2e_test/test_suite/timeout.py | 47 ++++++++ 5 files changed, 126 insertions(+), 33 deletions(-) create mode 100644 projects/pt1/python/torch_mlir_e2e_test/test_suite/timeout.py diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 044c8154f..5c613eae0 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -372,6 +372,7 @@ TORCHDYNAMO_CRASHING_SET = { } FX_IMPORTER_XFAIL_SET = { + "TimeOutModule_basic", # this test is expected to time out "ReduceAnyDimFloatModule_basic", "AddFloatIntModule_basic", "AllBoolFalseModule_basic", @@ -2302,6 +2303,8 @@ LTC_XFAIL_SET = { } ONNX_XFAIL_SET = { + # This test is expected to time out + "TimeOutModule_basic", # Failure - cast error "PermuteNegativeIndexModule_basic", # Failure - incorrect numerics diff --git a/projects/pt1/python/torch_mlir_e2e_test/framework.py b/projects/pt1/python/torch_mlir_e2e_test/framework.py index 89d802349..c24af96f3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/framework.py +++ b/projects/pt1/python/torch_mlir_e2e_test/framework.py @@ -27,6 +27,7 @@ from itertools import repeat import os import sys import traceback +import signal import multiprocess as mp from multiprocess import set_start_method @@ -230,6 +231,7 @@ class Test(NamedTuple): # module, actually). # The secon parameter is a `TestUtils` instance for convenience. program_invoker: Callable[[Any, TestUtils], None] + timeout_seconds: int class TestResult(NamedTuple): @@ -305,43 +307,79 @@ def generate_golden_trace(test: Test) -> Trace: return trace +class timeout: + def __init__(self, seconds=1, error_message="Timeout"): + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) + + def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any: - try: - golden_trace = generate_golden_trace(test) - if verbose: - print(f"Compiling {test.unique_name}...", file=sys.stderr) - compiled = config.compile(test.program_factory(), verbose=verbose) - except Exception as e: - return TestResult( - unique_name=test.unique_name, - compilation_error="".join( - traceback.format_exception(type(e), e, e.__traceback__) - ), - runtime_error=None, - trace=None, - golden_trace=None, - ) - try: - if verbose: - print(f"Running {test.unique_name}...", file=sys.stderr) - trace = config.run(compiled, golden_trace) - except Exception as e: + with timeout(seconds=test.timeout_seconds): + try: + golden_trace = generate_golden_trace(test) + if verbose: + print(f"Compiling {test.unique_name}...", file=sys.stderr) + compiled = config.compile(test.program_factory(), verbose=verbose) + except TimeoutError: + return TestResult( + unique_name=test.unique_name, + compilation_error=f"Test timed out during compilation (timeout={test.timeout_seconds}s)", + runtime_error=None, + trace=None, + golden_trace=None, + ) + except Exception as e: + return TestResult( + unique_name=test.unique_name, + compilation_error="".join( + traceback.format_exception(type(e), e, e.__traceback__) + ), + runtime_error=None, + trace=None, + golden_trace=None, + ) + try: + if verbose: + print(f"Running {test.unique_name}...", file=sys.stderr) + trace = config.run(compiled, golden_trace) + + # Disable the alarm + signal.alarm(0) + except TimeoutError: + return TestResult( + unique_name=test.unique_name, + compilation_error=None, + runtime_error="Test timed out during execution (timeout={test.timeout}s)", + trace=None, + golden_trace=None, + ) + except Exception as e: + return TestResult( + unique_name=test.unique_name, + compilation_error=None, + runtime_error="".join( + traceback.format_exception(type(e), e, e.__traceback__) + ), + trace=None, + golden_trace=None, + ) return TestResult( unique_name=test.unique_name, compilation_error=None, - runtime_error="".join( - traceback.format_exception(type(e), e, e.__traceback__) - ), - trace=None, - golden_trace=None, + runtime_error=None, + trace=clone_trace(trace), + golden_trace=clone_trace(golden_trace), ) - return TestResult( - unique_name=test.unique_name, - compilation_error=None, - runtime_error=None, - trace=clone_trace(trace), - golden_trace=clone_trace(golden_trace), - ) def run_tests( diff --git a/projects/pt1/python/torch_mlir_e2e_test/registry.py b/projects/pt1/python/torch_mlir_e2e_test/registry.py index d2116bafe..a98a6d34e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/registry.py +++ b/projects/pt1/python/torch_mlir_e2e_test/registry.py @@ -15,7 +15,9 @@ GLOBAL_TEST_REGISTRY = [] _SEEN_UNIQUE_NAMES = set() -def register_test_case(module_factory: Callable[[], torch.nn.Module]): +def register_test_case( + module_factory: Callable[[], torch.nn.Module], timeout_seconds: int = 120 +): """Convenient decorator-based test registration. Adds a `framework.Test` to the global test registry based on the decorated @@ -38,6 +40,7 @@ def register_test_case(module_factory: Callable[[], torch.nn.Module]): unique_name=f.__name__, program_factory=module_factory, program_invoker=f, + timeout_seconds=timeout_seconds, ) ) return f diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py index b90dff335..8166562b0 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -17,6 +17,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = { "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", "ElementwiseToDtypeI64ToUI8Module_basic", + "TimeOutModule_basic", # This test is expected to time out } @@ -60,3 +61,4 @@ def register_all_tests(): from . import diagonal from . import gridsampler from . import meshgrid + from . import timeout diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/timeout.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/timeout.py new file mode 100644 index 000000000..387ff6cfc --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/timeout.py @@ -0,0 +1,47 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import torch + +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export + + +# ============================================================================== +class TimeOutModule(torch.nn.Module): + """ + This test ensures that the timeout mechanism works as expected. + + The module runs an infinite loop that will never terminate, + and the test is expected to time out and get terminated + """ + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.int64, True)]) + def forward(self, x): + """ + Run an infinite loop. + + This may loop in the compiler or the runtime depending on whether + fx or torchscript is used. + """ + # input_arg_2 is going to be 2 + # but we can't just specify it as a + # constant because the compiler will + # attempt to get rid of the whole loop + input_arg_2 = x.size(0) + sum = 100 + while input_arg_2 < sum: # sum will always > 2 + sum += 1 + return sum + + +@register_test_case(module_factory=lambda: TimeOutModule(), timeout_seconds=10) +def TimeOutModule_basic(module, tu: TestUtils): + module.forward(torch.ones((42, 42)))