Ensure that tests have unique names

pull/1314/head snapshot-20220830.581
Sean Silva 2022-08-29 21:45:09 +00:00
parent 079bff33f1
commit 0f40d98009
2 changed files with 10 additions and 1 deletions

View File

@ -11,6 +11,8 @@ from .framework import Test
# The global registry of tests.
GLOBAL_TEST_REGISTRY = []
# Ensure that there are no duplicate names in the global test registry.
_SEEN_UNIQUE_NAMES = set()
def register_test_case(module_factory: Callable[[], torch.nn.Module]):
@ -22,6 +24,13 @@ def register_test_case(module_factory: Callable[[], torch.nn.Module]):
`program_invoker` is the decorated function.
"""
def decorator(f):
# Ensure that there are no duplicate names in the global test registry.
if f.__name__ in _SEEN_UNIQUE_NAMES:
raise Exception(
f"Duplicate test name: '{f.__name__}'. Please make sure that the function wrapped by `register_test_case` has a unique name.")
_SEEN_UNIQUE_NAMES.add(f.__name__)
# Store the test in the registry.
GLOBAL_TEST_REGISTRY.append(
Test(unique_name=f.__name__,
program_factory=module_factory,

View File

@ -515,7 +515,7 @@ class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
cudnn_enabled=True)
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
def _Convolution2DCudnnModule_basic(module, tu: TestUtils):
def _ConvolutionDeprecated2DCudnnModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class ConvolutionModule2DGroups(torch.nn.Module):