mirror of https://github.com/llvm/torch-mlir
Make e2e testing parallel
This change makes the e2e testing parallel using the multiprocessing python module.pull/838/head
parent
96fabc0036
commit
2ed90741eb
|
@ -133,7 +133,8 @@ class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module):
|
|||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input, (index,), value,
|
||||
return torch.ops.aten._index_put_impl_(input, (index, ),
|
||||
value,
|
||||
accumulate=True,
|
||||
unsafe=False)
|
||||
|
||||
|
@ -211,7 +212,8 @@ class IndexPutImpl1DIntAccumulateModule(torch.nn.Module):
|
|||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input, (index,), value,
|
||||
return torch.ops.aten._index_put_impl_(input, (index, ),
|
||||
value,
|
||||
accumulate=True,
|
||||
unsafe=False)
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Di
|
|||
import traceback
|
||||
|
||||
import torch
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
TorchScriptValue = Union[int, float, List['TorchScriptValue'],
|
||||
Dict['TorchScriptValue',
|
||||
|
@ -53,10 +53,6 @@ class TraceItem(NamedTuple):
|
|||
# public boundary.
|
||||
Trace = List[TraceItem]
|
||||
|
||||
# A type shared between the result of `TestConfig.compile` and the input
|
||||
# to `TestConfig.run`. Each backend will likely have a different definition of
|
||||
# this type.
|
||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||
|
||||
# Clone all the tensor values.
|
||||
def clone_torch_script_value(v: TorchScriptValue):
|
||||
|
@ -76,6 +72,26 @@ def clone_torch_script_value(v: TorchScriptValue):
|
|||
assert False, "unhandled cloning of TorchScriptValue value type"
|
||||
|
||||
|
||||
# This clone helper is used to work around issues with output tensors when
|
||||
# using multiprocessing module to run tests. The error happens for tests like
|
||||
# ContiguousModule_basic where the output tensor aliases with an input tensor.
|
||||
# When the output tensor is not cloned, the testing trace would be modified for
|
||||
# unknown reason when passed through the shared memory through synchronized
|
||||
# queue for example.
|
||||
# TODO: Figure out the root cause of the failure and fix properly.
|
||||
def clone_trace(trace: Trace) -> Trace:
|
||||
return [
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
output=clone_torch_script_value(item.output))
|
||||
for item in trace
|
||||
]
|
||||
|
||||
|
||||
# A type shared between the result of `TestConfig.compile` and the input
|
||||
# to `TestConfig.run`. Each backend will likely have a different definition of
|
||||
# this type.
|
||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||
|
||||
class TestConfig(abc.ABC):
|
||||
"""The interface implemented by backends to run tests.
|
||||
|
@ -158,6 +174,7 @@ class TestUtils:
|
|||
TODO: Figure out how to seed reset properly scoped to just a test case
|
||||
(such as when running tests in parallel)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
@ -260,40 +277,103 @@ def generate_golden_trace(test: Test) -> Trace:
|
|||
return trace
|
||||
|
||||
|
||||
def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
|
||||
"""Invoke the given `Test`'s with the provided `TestConfig`."""
|
||||
results = []
|
||||
for test in tests:
|
||||
# TODO: Precompile everything in parallel.
|
||||
def compile_and_run_test(test: Test, config: TestConfig) -> Any:
|
||||
try:
|
||||
golden_trace = generate_golden_trace(test)
|
||||
compiled = config.compile(test.program_factory())
|
||||
except Exception as e:
|
||||
results.append(
|
||||
TestResult(unique_name=test.unique_name,
|
||||
compilation_error="".join(traceback.format_exception(
|
||||
return TestResult(unique_name=test.unique_name,
|
||||
compilation_error="".join(
|
||||
traceback.format_exception(
|
||||
type(e), e, e.__traceback__)),
|
||||
runtime_error=None,
|
||||
trace=None,
|
||||
golden_trace=None))
|
||||
continue
|
||||
# TODO: Run in parallel.
|
||||
golden_trace=None)
|
||||
try:
|
||||
trace = config.run(compiled, golden_trace)
|
||||
except Exception as e:
|
||||
results.append(
|
||||
TestResult(unique_name=test.unique_name,
|
||||
return TestResult(unique_name=test.unique_name,
|
||||
compilation_error=None,
|
||||
runtime_error="".join(traceback.format_exception(
|
||||
runtime_error="".join(
|
||||
traceback.format_exception(
|
||||
type(e), e, e.__traceback__)),
|
||||
trace=None,
|
||||
golden_trace=None))
|
||||
continue
|
||||
|
||||
results.append(
|
||||
TestResult(unique_name=test.unique_name,
|
||||
golden_trace=None)
|
||||
return TestResult(unique_name=test.unique_name,
|
||||
compilation_error=None,
|
||||
runtime_error=None,
|
||||
trace=trace,
|
||||
golden_trace=golden_trace))
|
||||
trace=clone_trace(trace),
|
||||
golden_trace=golden_trace)
|
||||
|
||||
|
||||
queue_sentinel = "QUEUE_SENTINEL"
|
||||
|
||||
|
||||
def run_workers_in_parallel(task_queue: mp.Queue, worker):
|
||||
NUMBER_OF_PROCESSES = min(int(mp.cpu_count() * 1.1), task_queue.qsize())
|
||||
|
||||
# TODO: We've noticed that on certain 2 core machine parallelizing the tests
|
||||
# makes the llvm backend legacy pass manager 20x slower than using a
|
||||
# single process. Need to investigate the root cause eventually. This is a
|
||||
# hack to work around this issue.
|
||||
if mp.cpu_count() == 2:
|
||||
NUMBER_OF_PROCESSES = 1
|
||||
|
||||
processes = []
|
||||
for i in range(NUMBER_OF_PROCESSES):
|
||||
p = mp.Process(target=worker, args=(task_queue, ))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
for i in range(NUMBER_OF_PROCESSES):
|
||||
task_queue.put(queue_sentinel)
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
|
||||
def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
|
||||
"""Invoke the given `Test`'s with the provided `TestConfig`."""
|
||||
|
||||
# To run e2e tests in parallel:
|
||||
# The tests are put into a synchronized queue. Multiple worker processes are
|
||||
# created. Each worker takes one test at a time from the queue to compile
|
||||
# and execute it. If the test finishes, whether failed or passed, the result
|
||||
# of the test is put into a synchronized list which collects the tests
|
||||
# results from all worker processes.
|
||||
manager = mp.Manager()
|
||||
tests_queue = manager.Queue()
|
||||
sync_results = manager.list()
|
||||
# This is needed because autograd does not support crossing process
|
||||
# boundaries.
|
||||
torch.autograd.set_grad_enabled(False)
|
||||
for test in tests:
|
||||
tests_queue.put(test.unique_name)
|
||||
|
||||
tests_dict = {test.unique_name: test for test in tests}
|
||||
|
||||
def worker(tests_queue: mp.Queue):
|
||||
for test_name in iter(tests_queue.get, queue_sentinel):
|
||||
sync_results.append(
|
||||
compile_and_run_test(tests_dict[test_name], config))
|
||||
|
||||
run_workers_in_parallel(tests_queue, worker)
|
||||
tests_with_results = {result.unique_name for result in sync_results}
|
||||
all_tests = {test.unique_name for test in tests}
|
||||
# For processes that are crashed due to compile time or runtime error,
|
||||
# the error outputs are printed out all together but no TestResult is
|
||||
# produced when the process crashed.
|
||||
# TODO: Find a clean way to capture the output from crashed process and
|
||||
# create more detailed runtime_error for those tests.
|
||||
aborted_tests = all_tests - tests_with_results
|
||||
aborted_tests_results = [
|
||||
TestResult(
|
||||
unique_name=aborted_test_name,
|
||||
compilation_error=None,
|
||||
runtime_error=
|
||||
"Testing process terminated. Either the compiler crashed or the compiled code crashed at runtime.\n",
|
||||
trace=None,
|
||||
golden_trace=None) for aborted_test_name in aborted_tests
|
||||
]
|
||||
results = [result for result in sync_results]
|
||||
results.extend(aborted_tests_results)
|
||||
results.sort(key=lambda result: result.unique_name)
|
||||
return results
|
||||
|
|
Loading…
Reference in New Issue