mirror of https://github.com/llvm/torch-mlir
parent
079bff33f1
commit
0f40d98009
|
@ -11,6 +11,8 @@ from .framework import Test
|
|||
|
||||
# The global registry of tests.
|
||||
GLOBAL_TEST_REGISTRY = []
|
||||
# Ensure that there are no duplicate names in the global test registry.
|
||||
_SEEN_UNIQUE_NAMES = set()
|
||||
|
||||
|
||||
def register_test_case(module_factory: Callable[[], torch.nn.Module]):
|
||||
|
@ -22,6 +24,13 @@ def register_test_case(module_factory: Callable[[], torch.nn.Module]):
|
|||
`program_invoker` is the decorated function.
|
||||
"""
|
||||
def decorator(f):
|
||||
# Ensure that there are no duplicate names in the global test registry.
|
||||
if f.__name__ in _SEEN_UNIQUE_NAMES:
|
||||
raise Exception(
|
||||
f"Duplicate test name: '{f.__name__}'. Please make sure that the function wrapped by `register_test_case` has a unique name.")
|
||||
_SEEN_UNIQUE_NAMES.add(f.__name__)
|
||||
|
||||
# Store the test in the registry.
|
||||
GLOBAL_TEST_REGISTRY.append(
|
||||
Test(unique_name=f.__name__,
|
||||
program_factory=module_factory,
|
||||
|
|
|
@ -515,7 +515,7 @@ class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
|
|||
cudnn_enabled=True)
|
||||
|
||||
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
|
||||
def _Convolution2DCudnnModule_basic(module, tu: TestUtils):
|
||||
def _ConvolutionDeprecated2DCudnnModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
class ConvolutionModule2DGroups(torch.nn.Module):
|
||||
|
|
Loading…
Reference in New Issue