diff --git a/python/torch_mlir_e2e_test/test_suite/index_put.py b/python/torch_mlir_e2e_test/test_suite/index_put.py
index 838f0496d..e12b2627c 100644
--- a/python/torch_mlir_e2e_test/test_suite/index_put.py
+++ b/python/torch_mlir_e2e_test/test_suite/index_put.py
@@ -133,9 +133,10 @@ class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module):
         ([-1], torch.float32, True),
     ])
     def forward(self, input, index, value):
-        return torch.ops.aten._index_put_impl_(input, (index,), value,
-                                               accumulate=True,
-                                               unsafe=False)
+        return torch.ops.aten._index_put_impl_(input, (index, ),
+                                               value,
+                                               accumulate=True,
+                                               unsafe=False)
 
 
 @register_test_case(
@@ -211,9 +212,10 @@ class IndexPutImpl1DIntAccumulateModule(torch.nn.Module):
         ([-1], torch.int64, True),
     ])
     def forward(self, input, index, value):
-        return torch.ops.aten._index_put_impl_(input, (index,), value,
-                                               accumulate=True,
-                                               unsafe=False)
+        return torch.ops.aten._index_put_impl_(input, (index, ),
+                                               value,
+                                               accumulate=True,
+                                               unsafe=False)
 
 
 @register_test_case(module_factory=lambda: IndexPutImpl1DIntAccumulateModule())
diff --git a/python/torch_mlir_e2e_test/torchscript/framework.py b/python/torch_mlir_e2e_test/torchscript/framework.py
index 135969df2..080fbe561 100644
--- a/python/torch_mlir_e2e_test/torchscript/framework.py
+++ b/python/torch_mlir_e2e_test/torchscript/framework.py
@@ -26,7 +26,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Di
 import traceback
 
 import torch
-
+import torch.multiprocessing as mp
 
 TorchScriptValue = Union[int, float, List['TorchScriptValue'],
                          Dict['TorchScriptValue',
@@ -53,10 +53,6 @@ class TraceItem(NamedTuple):
 # public boundary.
 Trace = List[TraceItem]
 
-# A type shared between the result of `TestConfig.compile` and the input
-# to `TestConfig.run`. Each backend will likely have a different definition of
-# this type.
-CompiledArtifact = TypeVar('CompiledArtifact')
 
 # Clone all the tensor values.
 def clone_torch_script_value(v: TorchScriptValue):
@@ -76,6 +72,26 @@ def clone_torch_script_value(v: TorchScriptValue):
     assert False, "unhandled cloning of TorchScriptValue value type"
 
 
+# This clone helper works around issues with output tensors when the
+# multiprocessing module is used to run tests. The problem shows up for tests
+# like ContiguousModule_basic, where the output tensor aliases an input tensor:
+# if the output is not cloned, the testing trace is modified, for reasons not
+# yet understood, when it is passed through shared memory (e.g. via a
+# synchronized queue).
+# TODO: Figure out the root cause of the failure and fix it properly.
+def clone_trace(trace: Trace) -> Trace:
+    return [
+        TraceItem(symbol=item.symbol,
+                  inputs=item.inputs,
+                  output=clone_torch_script_value(item.output))
+        for item in trace
+    ]
+
+
+# A type shared between the result of `TestConfig.compile` and the input
+# to `TestConfig.run`. Each backend will likely have a different definition of
+# this type.
+CompiledArtifact = TypeVar('CompiledArtifact')
 
 class TestConfig(abc.ABC):
     """The interface implemented by backends to run tests.
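The aliasing behavior that `clone_trace` defends against can be illustrated outside the test framework. The standalone sketch below is not part of the patch: it relies on the documented behavior of `torch.multiprocessing`, which moves a tensor's storage into shared memory when the tensor is sent through a queue, so a producer that keeps an aliasing handle can still mutate what the consumer receives. It suggests one way the trace corruption could arise; it is not a confirmed root cause.

import torch
import torch.multiprocessing as mp


def producer(queue):
    x = torch.ones(3)
    # .contiguous() is a no-op on an already-contiguous tensor, so `y`
    # aliases `x` -- the same situation as ContiguousModule's output.
    y = x.contiguous()
    queue.put(y)  # the underlying storage is moved into shared memory
    # Mutating the input after the send may be observed by the consumer,
    # because producer and consumer can end up sharing the same storage.
    x.add_(1)


if __name__ == '__main__':
    queue = mp.Queue()
    p = mp.Process(target=producer, args=(queue, ))
    p.start()
    print(queue.get())  # may print tensor([2., 2., 2.]) instead of ones
    p.join()

Cloning the output before it crosses the process boundary, as `clone_trace` does, breaks the alias and sidesteps this class of problem.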
@@ -158,6 +174,7 @@ class TestUtils:
     TODO: Figure out how to do seed reset properly, scoped to just a test case
     (such as when running tests in parallel)
     """
+
     def __init__(self):
         torch.manual_seed(0)
 
@@ -260,40 +277,103 @@ def generate_golden_trace(test: Test) -> Trace:
     return trace
 
 
+def compile_and_run_test(test: Test, config: TestConfig) -> Any:
+    try:
+        golden_trace = generate_golden_trace(test)
+        compiled = config.compile(test.program_factory())
+    except Exception as e:
+        return TestResult(unique_name=test.unique_name,
+                          compilation_error="".join(
+                              traceback.format_exception(
+                                  type(e), e, e.__traceback__)),
+                          runtime_error=None,
+                          trace=None,
+                          golden_trace=None)
+    try:
+        trace = config.run(compiled, golden_trace)
+    except Exception as e:
+        return TestResult(unique_name=test.unique_name,
+                          compilation_error=None,
+                          runtime_error="".join(
+                              traceback.format_exception(
+                                  type(e), e, e.__traceback__)),
+                          trace=None,
+                          golden_trace=None)
+    return TestResult(unique_name=test.unique_name,
+                      compilation_error=None,
+                      runtime_error=None,
+                      trace=clone_trace(trace),
+                      golden_trace=golden_trace)
+
+
+queue_sentinel = "QUEUE_SENTINEL"
+
+
+def run_workers_in_parallel(task_queue: mp.Queue, worker):
+    NUMBER_OF_PROCESSES = min(int(mp.cpu_count() * 1.1), task_queue.qsize())
+
+    # TODO: We've noticed that on certain 2-core machines, parallelizing the
+    # tests makes the LLVM backend's legacy pass manager ~20x slower than a
+    # single process. The root cause still needs to be investigated; this is
+    # a hack to work around that issue.
+    if mp.cpu_count() == 2:
+        NUMBER_OF_PROCESSES = 1
+
+    processes = []
+    for i in range(NUMBER_OF_PROCESSES):
+        p = mp.Process(target=worker, args=(task_queue, ))
+        p.start()
+        processes.append(p)
+    for i in range(NUMBER_OF_PROCESSES):
+        task_queue.put(queue_sentinel)
+    for p in processes:
+        p.join()
+
+
 def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
     """Invoke the given `Test`s with the provided `TestConfig`."""
-    results = []
-    for test in tests:
-        # TODO: Precompile everything in parallel.
-        try:
-            golden_trace = generate_golden_trace(test)
-            compiled = config.compile(test.program_factory())
-        except Exception as e:
-            results.append(
-                TestResult(unique_name=test.unique_name,
-                           compilation_error="".join(traceback.format_exception(
-                               type(e), e, e.__traceback__)),
-                           runtime_error=None,
-                           trace=None,
-                           golden_trace=None))
-            continue
-        # TODO: Run in parallel.
-        try:
-            trace = config.run(compiled, golden_trace)
-        except Exception as e:
-            results.append(
-                TestResult(unique_name=test.unique_name,
-                           compilation_error=None,
-                           runtime_error="".join(traceback.format_exception(
-                               type(e), e, e.__traceback__)),
-                           trace=None,
-                           golden_trace=None))
-            continue
-        results.append(
-            TestResult(unique_name=test.unique_name,
-                       compilation_error=None,
-                       runtime_error=None,
-                       trace=trace,
-                       golden_trace=golden_trace))
+    # To run e2e tests in parallel:
+    # The tests are put into a synchronized queue, and multiple worker
+    # processes are created. Each worker takes one test at a time from the
+    # queue, then compiles and executes it. When a test finishes, whether it
+    # passed or failed, its result is put into a synchronized list that
+    # collects the results from all worker processes.
+    manager = mp.Manager()
+    tests_queue = manager.Queue()
+    sync_results = manager.list()
+    # This is needed because autograd does not support crossing process
+    # boundaries.
+    torch.autograd.set_grad_enabled(False)
+    for test in tests:
+        tests_queue.put(test.unique_name)
+
+    tests_dict = {test.unique_name: test for test in tests}
+
+    def worker(tests_queue: mp.Queue):
+        for test_name in iter(tests_queue.get, queue_sentinel):
+            sync_results.append(
+                compile_and_run_test(tests_dict[test_name], config))
+
+    run_workers_in_parallel(tests_queue, worker)
+    tests_with_results = {result.unique_name for result in sync_results}
+    all_tests = {test.unique_name for test in tests}
+    # If a worker process crashes due to a compile-time or runtime error, its
+    # error output is printed along with everything else, but the crashed
+    # process produces no TestResult.
+    # TODO: Find a clean way to capture the output from a crashed process and
+    # create a more detailed runtime_error for those tests.
+    aborted_tests = all_tests - tests_with_results
+    aborted_tests_results = [
+        TestResult(
+            unique_name=aborted_test_name,
+            compilation_error=None,
+            runtime_error=
            "Testing process terminated. Either the compiler crashed or the compiled code crashed at runtime.\n",
+            trace=None,
+            golden_trace=None) for aborted_test_name in aborted_tests
+    ]
+    results = [result for result in sync_results]
+    results.extend(aborted_tests_results)
+    results.sort(key=lambda result: result.unique_name)
     return results
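Two details of the worker loop may be non-obvious. First, the two-argument form of `iter`: `iter(tests_queue.get, queue_sentinel)` calls `get()` repeatedly and stops the first time it returns the sentinel, which is why `run_workers_in_parallel` enqueues exactly one sentinel per worker. Second, the sizing heuristic: on an 8-core machine, `min(int(8 * 1.1), qsize)` yields at most 8 workers, capped by the number of queued tests. The following self-contained sketch exercises the same pattern with a made-up `square_worker` standing in for the test runner; it is illustrative, not part of the patch.

import torch.multiprocessing as mp

SENTINEL = "QUEUE_SENTINEL"


def square_worker(task_queue, result_list):
    # iter(get, SENTINEL) keeps calling task_queue.get() until it returns
    # the sentinel, so each worker consumes exactly one sentinel and exits.
    for n in iter(task_queue.get, SENTINEL):
        result_list.append(n * n)


if __name__ == '__main__':
    manager = mp.Manager()
    queue = manager.Queue()
    results = manager.list()
    for n in range(10):
        queue.put(n)
    workers = [
        mp.Process(target=square_worker, args=(queue, results))
        for _ in range(4)
    ]
    for w in workers:
        w.start()
    # One sentinel per worker guarantees every worker's loop terminates.
    for _ in workers:
        queue.put(SENTINEL)
    for w in workers:
        w.join()
    print(sorted(results))  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

As in the patch, a `Manager`-backed queue and list are used so that both the tasks and the results can safely cross process boundaries.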