Make e2e testing parallel

This change makes the e2e testing parallel using the multiprocessing
python module.
pull/838/head
Yi Zhang 2022-03-23 11:34:02 -04:00
parent 96fabc0036
commit 2ed90741eb
2 changed files with 126 additions and 44 deletions

View File

@ -133,7 +133,8 @@ class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module):
([-1], torch.float32, True), ([-1], torch.float32, True),
]) ])
def forward(self, input, index, value): def forward(self, input, index, value):
return torch.ops.aten._index_put_impl_(input, (index,), value, return torch.ops.aten._index_put_impl_(input, (index, ),
value,
accumulate=True, accumulate=True,
unsafe=False) unsafe=False)
@ -211,7 +212,8 @@ class IndexPutImpl1DIntAccumulateModule(torch.nn.Module):
([-1], torch.int64, True), ([-1], torch.int64, True),
]) ])
def forward(self, input, index, value): def forward(self, input, index, value):
return torch.ops.aten._index_put_impl_(input, (index,), value, return torch.ops.aten._index_put_impl_(input, (index, ),
value,
accumulate=True, accumulate=True,
unsafe=False) unsafe=False)

View File

@ -26,7 +26,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Di
import traceback import traceback
import torch import torch
import torch.multiprocessing as mp
TorchScriptValue = Union[int, float, List['TorchScriptValue'], TorchScriptValue = Union[int, float, List['TorchScriptValue'],
Dict['TorchScriptValue', Dict['TorchScriptValue',
@ -53,10 +53,6 @@ class TraceItem(NamedTuple):
# public boundary. # public boundary.
Trace = List[TraceItem] Trace = List[TraceItem]
# A type shared between the result of `TestConfig.compile` and the input
# to `TestConfig.run`. Each backend will likely have a different definition of
# this type.
CompiledArtifact = TypeVar('CompiledArtifact')
# Clone all the tensor values. # Clone all the tensor values.
def clone_torch_script_value(v: TorchScriptValue): def clone_torch_script_value(v: TorchScriptValue):
@ -76,6 +72,26 @@ def clone_torch_script_value(v: TorchScriptValue):
assert False, "unhandled cloning of TorchScriptValue value type" assert False, "unhandled cloning of TorchScriptValue value type"
# This clone helper is used to work around issues with output tensors when
# using multiprocessing module to run tests. The error happens for tests like
# ContiguousModule_basic where the output tensor aliases with an input tensor.
# When the output tensor is not cloned, the testing trace would be modified for
# unknown reason when passed through the shared memory through synchronized
# queue for example.
# TODO: Figure out the root cause of the failure and fix properly.
def clone_trace(trace: Trace) -> Trace:
return [
TraceItem(symbol=item.symbol,
inputs=item.inputs,
output=clone_torch_script_value(item.output))
for item in trace
]
# A type shared between the result of `TestConfig.compile` and the input
# to `TestConfig.run`. Each backend will likely have a different definition of
# this type.
CompiledArtifact = TypeVar('CompiledArtifact')
class TestConfig(abc.ABC): class TestConfig(abc.ABC):
"""The interface implemented by backends to run tests. """The interface implemented by backends to run tests.
@ -158,6 +174,7 @@ class TestUtils:
TODO: Figure out how to seed reset properly scoped to just a test case TODO: Figure out how to seed reset properly scoped to just a test case
(such as when running tests in parallel) (such as when running tests in parallel)
""" """
def __init__(self): def __init__(self):
torch.manual_seed(0) torch.manual_seed(0)
@ -260,40 +277,103 @@ def generate_golden_trace(test: Test) -> Trace:
return trace return trace
def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]: def compile_and_run_test(test: Test, config: TestConfig) -> Any:
"""Invoke the given `Test`'s with the provided `TestConfig`."""
results = []
for test in tests:
# TODO: Precompile everything in parallel.
try: try:
golden_trace = generate_golden_trace(test) golden_trace = generate_golden_trace(test)
compiled = config.compile(test.program_factory()) compiled = config.compile(test.program_factory())
except Exception as e: except Exception as e:
results.append( return TestResult(unique_name=test.unique_name,
TestResult(unique_name=test.unique_name, compilation_error="".join(
compilation_error="".join(traceback.format_exception( traceback.format_exception(
type(e), e, e.__traceback__)), type(e), e, e.__traceback__)),
runtime_error=None, runtime_error=None,
trace=None, trace=None,
golden_trace=None)) golden_trace=None)
continue
# TODO: Run in parallel.
try: try:
trace = config.run(compiled, golden_trace) trace = config.run(compiled, golden_trace)
except Exception as e: except Exception as e:
results.append( return TestResult(unique_name=test.unique_name,
TestResult(unique_name=test.unique_name,
compilation_error=None, compilation_error=None,
runtime_error="".join(traceback.format_exception( runtime_error="".join(
traceback.format_exception(
type(e), e, e.__traceback__)), type(e), e, e.__traceback__)),
trace=None, trace=None,
golden_trace=None)) golden_trace=None)
continue return TestResult(unique_name=test.unique_name,
results.append(
TestResult(unique_name=test.unique_name,
compilation_error=None, compilation_error=None,
runtime_error=None, runtime_error=None,
trace=trace, trace=clone_trace(trace),
golden_trace=golden_trace)) golden_trace=golden_trace)
queue_sentinel = "QUEUE_SENTINEL"
def run_workers_in_parallel(task_queue: mp.Queue, worker):
NUMBER_OF_PROCESSES = min(int(mp.cpu_count() * 1.1), task_queue.qsize())
# TODO: We've noticed that on certain 2 core machine parallelizing the tests
# makes the llvm backend legacy pass manager 20x slower than using a
# single process. Need to investigate the root cause eventually. This is a
# hack to work around this issue.
if mp.cpu_count() == 2:
NUMBER_OF_PROCESSES = 1
processes = []
for i in range(NUMBER_OF_PROCESSES):
p = mp.Process(target=worker, args=(task_queue, ))
p.start()
processes.append(p)
for i in range(NUMBER_OF_PROCESSES):
task_queue.put(queue_sentinel)
for p in processes:
p.join()
def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
"""Invoke the given `Test`'s with the provided `TestConfig`."""
# To run e2e tests in parallel:
# The tests are put into a synchronized queue. Multiple worker processes are
# created. Each worker takes one test at a time from the queue to compile
# and execute it. If the test finishes, whether failed or passed, the result
# of the test is put into a synchronized list which collects the tests
# results from all worker processes.
manager = mp.Manager()
tests_queue = manager.Queue()
sync_results = manager.list()
# This is needed because autograd does not support crossing process
# boundaries.
torch.autograd.set_grad_enabled(False)
for test in tests:
tests_queue.put(test.unique_name)
tests_dict = {test.unique_name: test for test in tests}
def worker(tests_queue: mp.Queue):
for test_name in iter(tests_queue.get, queue_sentinel):
sync_results.append(
compile_and_run_test(tests_dict[test_name], config))
run_workers_in_parallel(tests_queue, worker)
tests_with_results = {result.unique_name for result in sync_results}
all_tests = {test.unique_name for test in tests}
# For processes that are crashed due to compile time or runtime error,
# the error outputs are printed out all together but no TestResult is
# produced when the process crashed.
# TODO: Find a clean way to capture the output from crashed process and
# create more detailed runtime_error for those tests.
aborted_tests = all_tests - tests_with_results
aborted_tests_results = [
TestResult(
unique_name=aborted_test_name,
compilation_error=None,
runtime_error=
"Testing process terminated. Either the compiler crashed or the compiled code crashed at runtime.\n",
trace=None,
golden_trace=None) for aborted_test_name in aborted_tests
]
results = [result for result in sync_results]
results.extend(aborted_tests_results)
results.sort(key=lambda result: result.unique_name)
return results return results