diff --git a/e2e_testing/main.py b/e2e_testing/main.py
index 40a164007..d48223ad4 100644
--- a/e2e_testing/main.py
+++ b/e2e_testing/main.py
@@ -34,37 +34,37 @@ from torch_mlir_e2e_test.test_suite import register_all_tests
 register_all_tests()
 
 def _get_argparse():
-    config_choices = ['native_torch', 'torchscript', 'linalg', 'mhlo', 'tosa', 'lazy_tensor_core', 'torchdynamo']
-    parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
-    parser.add_argument('-c', '--config',
+    config_choices = ["native_torch", "torchscript", "linalg", "mhlo", "tosa", "lazy_tensor_core", "torchdynamo"]
+    parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
+    parser.add_argument("-c", "--config",
         choices=config_choices,
-        default='linalg',
-        help=f'''
+        default="linalg",
+        help=f"""
 Meaning of options:
 "linalg": run through torch-mlir's default Linalg-on-Tensors backend.
 "mhlo": run through torch-mlir's default MHLO backend.
 "tosa": run through torch-mlir's default TOSA backend.
 "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
 "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
 "lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph.
 "torchdynamo": run the model through the TorchDynamo frontend and execute the graph using Linalg-on-Tensors.
-''')
-    parser.add_argument('-f', '--filter', default='.*', help='''
+""")
+    parser.add_argument("-f", "--filter", default=".*", help="""
 Regular expression specifying which tests to include in this run.
-''')
-    parser.add_argument('-v', '--verbose',
+""")
+    parser.add_argument("-v", "--verbose",
                         default=False,
-                        action='store_true',
-                        help='report test results with additional detail')
-    parser.add_argument('-s', '--sequential',
+                        action="store_true",
+                        help="report test results with additional detail")
+    parser.add_argument("-s", "--sequential",
                         default=False,
-                        action='store_true',
-                        help='''Run tests sequentially rather than in parallel.
+                        action="store_true",
+                        help="""Run tests sequentially rather than in parallel.
 This can be useful for debugging, since it runs the tests in the same process,
-which make it easier to attach a debugger or get a stack trace.''')
-    parser.add_argument('--crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed',
-                        metavar="TEST", type=str, nargs='+',
-                        help='A set of tests to not attempt to run, since they crash and cannot be XFAILed.')
+which makes it easier to attach a debugger or get a stack trace.""")
+    parser.add_argument("--crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed",
+                        metavar="TEST", type=str, nargs="+",
+                        help="A set of tests to not attempt to run, since they crash and cannot be XFAILed.")
     return parser
 
 def main():
@@ -74,25 +74,25 @@ def main():
         test.unique_name for test in GLOBAL_TEST_REGISTRY)
 
     # Find the selected config.
-    if args.config == 'linalg':
+    if args.config == "linalg":
         config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
         xfail_set = LINALG_XFAIL_SET
-    if args.config == 'tosa':
+    if args.config == "tosa":
         config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
         xfail_set = all_test_unique_names - TOSA_PASS_SET
-    if args.config == 'mhlo':
+    if args.config == "mhlo":
         config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
         xfail_set = all_test_unique_names - MHLO_PASS_SET
-    elif args.config == 'native_torch':
+    elif args.config == "native_torch":
         config = NativeTorchTestConfig()
         xfail_set = {}
-    elif args.config == 'torchscript':
+    elif args.config == "torchscript":
         config = TorchScriptTestConfig()
         xfail_set = {}
-    elif args.config == 'lazy_tensor_core':
+    elif args.config == "lazy_tensor_core":
         config = LazyTensorCoreTestConfig()
         xfail_set = LTC_XFAIL_SET
-    elif args.config == 'torchdynamo':
+    elif args.config == "torchdynamo":
         config = TorchDynamoTestConfig()
         xfail_set = TORCHDYNAMO_XFAIL_SET
 
@@ -101,7 +101,7 @@ def main():
     if args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed is not None:
         for arg in args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed:
             if arg not in all_test_unique_names:
-                print(f'ERROR: --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed argument "{arg}" is not a valid test name')
+                print(f"ERROR: --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed argument '{arg}' is not a valid test name")
                 sys.exit(1)
 
     # Find the selected tests, and emit a diagnostic if none are found.
@@ -111,9 +111,9 @@ def main():
     ]
     if len(tests) == 0:
         print(
-            f'ERROR: the provided filter {args.filter!r} does not match any tests'
+            f"ERROR: the provided filter {args.filter!r} does not match any tests"
        )
-        print('The available tests are:')
+        print("The available tests are:")
         for test in available_tests:
             print(test.unique_name)
         sys.exit(1)
@@ -132,6 +132,6 @@ def _suppress_warnings():
     warnings.filterwarnings("ignore",
                             message="A builtin ctypes object gave a PEP3118 format string that does not match its itemsize")
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     _suppress_warnings()
     main()
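
A note on the mechanical requote above: a plain textual replace of ' with " also rewrites apostrophes inside string bodies (e.g. `torch-mlir's` in the `--config` help text would become `torch-mlir"s`). A safer basis is the stdlib tokenizer, which knows exactly where each string literal begins and ends, so only delimiters get touched. Below is a minimal, illustrative sketch (not part of this patch) that lists the single-quoted literals left in a file:

```python
# Illustrative sketch, not part of this patch: locate single-quoted string
# literals with the stdlib tokenizer instead of a file-wide find-and-replace.
# tokenize understands string boundaries, so an apostrophe inside a string
# (as in "torch-mlir's") is never mistaken for a delimiter.
import tokenize

def find_single_quoted(path):
    """Print the location of every single-quoted string literal in *path*."""
    with tokenize.open(path) as f:
        for tok in tokenize.generate_tokens(f.readline):
            if tok.type != tokenize.STRING:
                continue
            # Strip any f/r/b/u prefix before inspecting the quote character.
            if tok.string.lstrip("rbfuRBFU").startswith("'"):
                print(f"{path}:{tok.start[0]}: {tok.string[:60]}")

find_single_quoted("e2e_testing/main.py")
```

Run against the patched file this should print nothing; any hit marks a delimiter still to convert, and editing only those delimiters leaves embedded apostrophes intact. (black's default string normalization performs the same single-to-double conversion automatically.)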