mirror of https://github.com/llvm/torch-mlir
Make e2e testing parallel
This change makes the e2e testing parallel using the multiprocessing Python module. (branch: pull/838/head)

parent 96fabc0036
commit 2ed90741eb
@@ -133,9 +133,10 @@ class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module):
         ([-1], torch.float32, True),
     ])
     def forward(self, input, index, value):
-        return torch.ops.aten._index_put_impl_(input, (index,), value,
-                                               accumulate=True,
-                                               unsafe=False)
+        return torch.ops.aten._index_put_impl_(input, (index, ),
+                                               value,
+                                               accumulate=True,
+                                               unsafe=False)


 @register_test_case(
@@ -211,9 +212,10 @@ class IndexPutImpl1DIntAccumulateModule(torch.nn.Module):
         ([-1], torch.int64, True),
     ])
     def forward(self, input, index, value):
-        return torch.ops.aten._index_put_impl_(input, (index,), value,
-                                               accumulate=True,
-                                               unsafe=False)
+        return torch.ops.aten._index_put_impl_(input, (index, ),
+                                               value,
+                                               accumulate=True,
+                                               unsafe=False)


 @register_test_case(module_factory=lambda: IndexPutImpl1DIntAccumulateModule())
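For context on what these two tests exercise: aten._index_put_impl_ is the in-place index_put op, and with accumulate=True the values scattered to a repeated index are summed rather than overwritten. A minimal standalone sketch (tensor values here are illustrative, not taken from the test suite):

```python
import torch

t = torch.zeros(3)
index = torch.tensor([0, 0, 2])
value = torch.tensor([1.0, 2.0, 5.0])

# accumulate=True sums contributions that land on the same index,
# so t[0] receives 1.0 + 2.0 = 3.0 instead of just the last write.
torch.ops.aten._index_put_impl_(t, (index,), value,
                                accumulate=True,
                                unsafe=False)
print(t)  # tensor([3., 0., 5.])
```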
@@ -26,7 +26,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Di
 import traceback

 import torch
+import torch.multiprocessing as mp

 TorchScriptValue = Union[int, float, List['TorchScriptValue'],
                          Dict['TorchScriptValue',
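Note the new import is torch.multiprocessing rather than the stdlib multiprocessing: it is an API-compatible wrapper that registers custom reducers so tensors move between processes through shared memory. A minimal sketch of passing a tensor through such a queue (the function and queue names are illustrative):

```python
import torch
import torch.multiprocessing as mp

def square(inq, outq):
    t = inq.get()       # the tensor arrives via shared memory
    outq.put(t * t)

if __name__ == "__main__":
    inq, outq = mp.Queue(), mp.Queue()
    p = mp.Process(target=square, args=(inq, outq))
    p.start()
    inq.put(torch.arange(4.0))
    print(outq.get())   # tensor([0., 1., 4., 9.])
    p.join()
```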
@@ -53,10 +53,6 @@ class TraceItem(NamedTuple):
 # public boundary.
 Trace = List[TraceItem]

-# A type shared between the result of `TestConfig.compile` and the input
-# to `TestConfig.run`. Each backend will likely have a different definition of
-# this type.
-CompiledArtifact = TypeVar('CompiledArtifact')

 # Clone all the tensor values.
 def clone_torch_script_value(v: TorchScriptValue):
@@ -76,6 +72,26 @@ def clone_torch_script_value(v: TorchScriptValue):
     assert False, "unhandled cloning of TorchScriptValue value type"


+# This clone helper is used to work around issues with output tensors when
+# using multiprocessing module to run tests. The error happens for tests like
+# ContiguousModule_basic where the output tensor aliases with an input tensor.
+# When the output tensor is not cloned, the testing trace would be modified for
+# unknown reason when passed through the shared memory through synchronized
+# queue for example.
+# TODO: Figure out the root cause of the failure and fix properly.
+def clone_trace(trace: Trace) -> Trace:
+    return [
+        TraceItem(symbol=item.symbol,
+                  inputs=item.inputs,
+                  output=clone_torch_script_value(item.output))
+        for item in trace
+    ]
+
+
+# A type shared between the result of `TestConfig.compile` and the input
+# to `TestConfig.run`. Each backend will likely have a different definition of
+# this type.
+CompiledArtifact = TypeVar('CompiledArtifact')

 class TestConfig(abc.ABC):
     """The interface implemented by backends to run tests.
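The aliasing described in the new comment is easy to reproduce: Tensor.contiguous() returns self when the tensor is already contiguous, so a recorded output can silently change when its input is later mutated. A standalone illustration of the hazard clone_trace guards against (not framework code):

```python
import torch

x = torch.zeros(3)
out = x.contiguous()          # already contiguous: same storage as x
x.add_(1)
print(out)                    # tensor([1., 1., 1.]) -- recorded output drifted

x = torch.zeros(3)
out = x.contiguous().clone()  # cloning breaks the alias
x.add_(1)
print(out)                    # tensor([0., 0., 0.]) -- stays as recorded
```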
@@ -158,6 +174,7 @@ class TestUtils:
     TODO: Figure out how to seed reset properly scoped to just a test case
+    (such as when running tests in parallel)
     """

     def __init__(self):
         torch.manual_seed(0)

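A sketch of the reseeding behavior the amended TODO refers to: because the constructor resets the global seed, a fresh TestUtils yields identical "random" data no matter which worker process, or in what order, a test runs (a simplified stand-in for the framework class):

```python
import torch

class TestUtils:
    def __init__(self):
        torch.manual_seed(0)  # global reset: coarse-grained, hence the TODO

    def rand(self, *sizes):
        return torch.rand(*sizes)

# Two fresh instances produce the same tensors, even in different processes.
a = TestUtils().rand(2, 3)
b = TestUtils().rand(2, 3)
assert torch.equal(a, b)
```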
@@ -260,40 +277,103 @@ def generate_golden_trace(test: Test) -> Trace:
     return trace


+def compile_and_run_test(test: Test, config: TestConfig) -> Any:
+    try:
+        golden_trace = generate_golden_trace(test)
+        compiled = config.compile(test.program_factory())
+    except Exception as e:
+        return TestResult(unique_name=test.unique_name,
+                          compilation_error="".join(
+                              traceback.format_exception(
+                                  type(e), e, e.__traceback__)),
+                          runtime_error=None,
+                          trace=None,
+                          golden_trace=None)
+    try:
+        trace = config.run(compiled, golden_trace)
+    except Exception as e:
+        return TestResult(unique_name=test.unique_name,
+                          compilation_error=None,
+                          runtime_error="".join(
+                              traceback.format_exception(
+                                  type(e), e, e.__traceback__)),
+                          trace=None,
+                          golden_trace=None)
+    return TestResult(unique_name=test.unique_name,
+                      compilation_error=None,
+                      runtime_error=None,
+                      trace=clone_trace(trace),
+                      golden_trace=golden_trace)
+
+
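The compile_and_run_test added above stringifies exceptions with traceback.format_exception so the full stack trace can cross the process boundary inside a plain, picklable TestResult. A minimal sketch of that idiom (capture is a hypothetical helper):

```python
import traceback

def capture(fn) -> str:
    try:
        fn()
        return ""
    except Exception as e:
        # Same call as above: a list of lines, joined into one string that
        # pickles cleanly and can be shipped back to the parent process.
        return "".join(traceback.format_exception(type(e), e, e.__traceback__))

print(capture(lambda: 1 / 0))  # full ZeroDivisionError traceback as text
```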
+queue_sentinel = "QUEUE_SENTINEL"
+
+
+def run_workers_in_parallel(task_queue: mp.Queue, worker):
+    NUMBER_OF_PROCESSES = min(int(mp.cpu_count() * 1.1), task_queue.qsize())
+
+    # TODO: We've noticed that on certain 2 core machine parallelizing the tests
+    # makes the llvm backend legacy pass manager 20x slower than using a
+    # single process. Need to investigate the root cause eventually. This is a
+    # hack to work around this issue.
+    if mp.cpu_count() == 2:
+        NUMBER_OF_PROCESSES = 1
+
+    processes = []
+    for i in range(NUMBER_OF_PROCESSES):
+        p = mp.Process(target=worker, args=(task_queue, ))
+        p.start()
+        processes.append(p)
+    for i in range(NUMBER_OF_PROCESSES):
+        task_queue.put(queue_sentinel)
+    for p in processes:
+        p.join()
+
+
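run_workers_in_parallel uses the classic sentinel-terminated queue: iter(queue.get, sentinel) keeps calling get() until the sentinel value comes back, and putting one sentinel per worker drains the pool. A self-contained sketch of the same shutdown pattern (the worker body is illustrative):

```python
import torch.multiprocessing as mp

SENTINEL = "QUEUE_SENTINEL"

def worker(q):
    # iter(callable, sentinel) calls q.get() repeatedly until it returns SENTINEL.
    for item in iter(q.get, SENTINEL):
        print("processed", item)

if __name__ == "__main__":
    q = mp.Queue()
    for name in ["test_a", "test_b", "test_c"]:
        q.put(name)
    procs = [mp.Process(target=worker, args=(q,)) for _ in range(2)]
    for p in procs:
        p.start()
    for _ in procs:
        q.put(SENTINEL)  # one sentinel per worker guarantees every loop exits
    for p in procs:
        p.join()
```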
 def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
     """Invoke the given `Test`'s with the provided `TestConfig`."""
-    results = []
-    for test in tests:
-        # TODO: Precompile everything in parallel.
-        try:
-            golden_trace = generate_golden_trace(test)
-            compiled = config.compile(test.program_factory())
-        except Exception as e:
-            results.append(
-                TestResult(unique_name=test.unique_name,
-                           compilation_error="".join(traceback.format_exception(
-                               type(e), e, e.__traceback__)),
-                           runtime_error=None,
-                           trace=None,
-                           golden_trace=None))
-            continue
-        # TODO: Run in parallel.
-        try:
-            trace = config.run(compiled, golden_trace)
-        except Exception as e:
-            results.append(
-                TestResult(unique_name=test.unique_name,
-                           compilation_error=None,
-                           runtime_error="".join(traceback.format_exception(
-                               type(e), e, e.__traceback__)),
-                           trace=None,
-                           golden_trace=None))
-            continue
-
-        results.append(
-            TestResult(unique_name=test.unique_name,
-                       compilation_error=None,
-                       runtime_error=None,
-                       trace=trace,
-                       golden_trace=golden_trace))
+    # To run e2e tests in parallel:
+    # The tests are put into a synchronized queue. Multiple worker processes are
+    # created. Each worker takes one test at a time from the queue to compile
+    # and execute it. If the test finishes, whether failed or passed, the result
+    # of the test is put into a synchronized list which collects the tests
+    # results from all worker processes.
+    manager = mp.Manager()
+    tests_queue = manager.Queue()
+    sync_results = manager.list()
+    # This is needed because autograd does not support crossing process
+    # boundaries.
+    torch.autograd.set_grad_enabled(False)
+    for test in tests:
+        tests_queue.put(test.unique_name)
+
+    tests_dict = {test.unique_name: test for test in tests}
+
+    def worker(tests_queue: mp.Queue):
+        for test_name in iter(tests_queue.get, queue_sentinel):
+            sync_results.append(
+                compile_and_run_test(tests_dict[test_name], config))
+
+    run_workers_in_parallel(tests_queue, worker)
+    tests_with_results = {result.unique_name for result in sync_results}
+    all_tests = {test.unique_name for test in tests}
+    # For processes that are crashed due to compile time or runtime error,
+    # the error outputs are printed out all together but no TestResult is
+    # produced when the process crashed.
+    # TODO: Find a clean way to capture the output from crashed process and
+    # create more detailed runtime_error for those tests.
+    aborted_tests = all_tests - tests_with_results
+    aborted_tests_results = [
+        TestResult(
+            unique_name=aborted_test_name,
+            compilation_error=None,
+            runtime_error=
+            "Testing process terminated. Either the compiler crashed or the compiled code crashed at runtime.\n",
+            trace=None,
+            golden_trace=None) for aborted_test_name in aborted_tests
+    ]
+    results = [result for result in sync_results]
+    results.extend(aborted_tests_results)
+    results.sort(key=lambda result: result.unique_name)
     return results
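One detail of the new run_tests worth spelling out: results accumulate in a Manager().list(), so a worker that hard-crashes never appends anything, and the set difference between scheduled and completed names is exactly the aborted tests. A toy reproduction of that bookkeeping (the crash is simulated with os._exit; test names are illustrative):

```python
import os
import torch.multiprocessing as mp

def run_one(name, results):
    if name == "crashy":
        os._exit(1)           # simulate a compiler/runtime crash: nothing lands
    results.append(name)

if __name__ == "__main__":
    tests = {"ok_1", "ok_2", "crashy"}
    manager = mp.Manager()
    results = manager.list()  # proxy list shared across processes
    procs = [mp.Process(target=run_one, args=(t, results)) for t in tests]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    aborted = tests - set(results)
    print(aborted)            # {'crashy'} -> reported as a terminated test
```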