# torch-mlir/python/torch_mlir_e2e_test/framework.py (398 lines, 16 KiB, Python)

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
# End-to-end testing framework for TorchScript.
For the purposes of this framework, "end to end" means the first "end" is
a `torch.nn.Module`, and the second "end" is execution.
## Architecture
A program for this testing framework is considered to be a `torch.nn.Module`,
which has a public interface consisting of its methods and instance attributes.
A test in the framework consists conceptually of a list of calls into
the methods of a module (TODO: extend to instance attributes). It is expected
that the outputs match between the program run on a backend (controlled by
a TestConfig) and a golden trace obtained by running on native Torch (without
compiling or TorchScript'ing).
"""
import abc
from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Dict
import sys
import traceback
import torch
import torch.multiprocessing as mp
# Recursive alias for the values used as inputs and outputs of traced calls:
# ints, floats, tensors, and lists/dicts built from those.
# NOTE(review): `clone_torch_script_value` also accepts `tuple` and `str`
# values, which this alias does not list.
TorchScriptValue = Union[int, float, List['TorchScriptValue'],
                         Dict['TorchScriptValue',
                              'TorchScriptValue'], torch.Tensor]
class TraceItem(NamedTuple):
    """A single recorded call across the program's public boundary."""

    # Externally visible name of the callable that was invoked,
    # e.g. `"forward"` or `"submodule.forward"`.
    symbol: str

    # Arguments that were passed to the call.
    inputs: List[TorchScriptValue]

    # Value returned by the call.
    # A Python function returns exactly one value; "multiple results" show up
    # here as a tuple.
    # Depending on the context this field holds the golden (expected) output
    # of a test, or it is ignored altogether, as is the case for the input
    # trace handed to `TestConfig.run`.
    output: TorchScriptValue
# A trace of invocations to the program.
# This is an ordered sequence of external invocations to a program's
# public boundary, one `TraceItem` per call.
Trace = List[TraceItem]
def clone_torch_script_value(v: 'TorchScriptValue') -> 'TorchScriptValue':
    """Return a copy of `v` in which every tensor has been cloned.

    Tuples, lists and dicts are rebuilt recursively (dict keys are cloned as
    well as dict values), so no tensor in the result is the same object as a
    tensor in `v`. `int`, `float` and `str` values are returned unchanged.

    Args:
        v: The value to clone.

    Returns:
        The cloned value. Any tuple subclass is returned as a plain `tuple`.

    Raises:
        AssertionError: If `v`, or a value nested inside it, has a type that
            is not handled above.
    """
    if isinstance(v, torch.Tensor):
        return v.clone()
    if isinstance(v, tuple):
        return tuple(clone_torch_script_value(field) for field in v)
    if isinstance(v, list):
        return [clone_torch_script_value(item) for item in v]
    if isinstance(v, dict):
        return {
            clone_torch_script_value(key): clone_torch_script_value(val)
            for key, val in v.items()
        }
    if isinstance(v, (float, int, str)):
        return v
    # Raise explicitly instead of `assert False`: assert statements are
    # stripped under `python -O`, which would make this function silently
    # return None for unsupported values.
    raise AssertionError("unhandled cloning of TorchScriptValue value type")
def clone_trace(trace: Trace) -> Trace:
    """Return a copy of `trace` whose outputs have been cloned.

    Only the `output` of each item is cloned; `symbol` and `inputs` are
    carried over as they are.

    This works around a problem seen when tests are run through the
    multiprocessing module, e.g. in ContiguousModule_basic where the output
    tensor aliases an input tensor: without the clone, the trace ends up
    modified, for a reason not yet understood, once it has been passed through
    shared memory (for example via a synchronized queue).

    TODO: Figure out the root cause of the failure and fix properly.
    """
    cloned: Trace = []
    for entry in trace:
        fresh_output = clone_torch_script_value(entry.output)
        cloned.append(
            TraceItem(symbol=entry.symbol,
                      inputs=entry.inputs,
                      output=fresh_output))
    return cloned
# The type of the value that `TestConfig.compile` returns and that
# `TestConfig.run` takes as its `artifact` argument. Each backend will likely
# have a different definition of this type.
CompiledArtifact = TypeVar('CompiledArtifact')
class TestConfig(abc.ABC):
    """Interface a backend implements so that tests can be run on it.

    The framework first calls `compile` on a torch.nn.Module and then hands
    the resulting compiled artifact to `run`.

    What counts as a "compiled artifact" is deliberately loose, which makes
    this interface useful beyond plain testing. A config can, for instance,
    act as a "data collector" over all test cases: a compiler backend could
    have `compile` return IR at some useful intermediate abstraction level
    (instead of a final compiled artifact) and have `run` save that IR
    together with the trace as input for the testing format of a lower-level
    software stack.

    TestConfig's are meant to be pluggable and supplied by users according to
    their own needs. A few configs ship in the `configs` submodule of this
    package; they serve as basic inspiration and cover our own testing.

    Backends to torch-mlir will likely need more elaborate TestConfig's, such
    as `compile` meaning "compile for such-and-such DSP with these
    vectorization cost model flags" and `run` meaning "connect to Android
    phone with device ID 1234 and upload a program to run on its DSP core,
    and also set power throttling settings to 'performance'".

    This is also the reason the class is not named "backend": it bundles
    potentially many details of the test configuration process too. Test
    configuration cannot in general be separated from the compile/run process
    of a logical backend, because every backend (compiler backend and runtime
    target) comes with an arbitrarily wild and wonderful set of possible
    configurations that we cannot predict.
    """

    # `program` is not a frontend-lowered module, to allow various testing at
    # the PyTorch level. We can have a helper class
    # LinalgOnTensorsBackendTestConfig which does that.
    @abc.abstractmethod
    def compile(self, program: torch.nn.Module) -> CompiledArtifact:
        """Compile the provided torch.nn.Module into a compiled artifact."""

    # The type of `artifact` should match the result of `compile`.
    @abc.abstractmethod
    def run(self, artifact: CompiledArtifact, trace: Trace) -> Trace:
        """Run the compiled artifact produced by `compile`.

        The backend should load the compiled artifact and call the symbol
        names listed in `trace` with their respective inputs; the outputs
        stored in `trace` should be ignored. It should return a new trace that
        is identical except that the outputs are populated.

        Implementations should assume `artifact` is shared between multiple
        parallel invocations of `run`, and therefore must not mutate it. For
        a true "compiled artifact" this is typically trivially satisfied, but
        some backends do not involve a compiled artifact per se (for example
        a backend whose `CompiledArtifact` is a `torch.nn.Module` and whose
        `run` simply invokes that module).

        Args:
            artifact: a compiled artifact produced by `compile`.
            trace: The external invocations to stimulate the module.
        Returns:
            A trace with outputs recorded according to the results of running
            on this backend.
        """
class TestUtils:
    """Helpers for generating the tensors used in a test's trace.

    Every test case receives an instance of this class so that the test body
    can stay short.

    Creating an instance also resets torch's random seed to 0, for
    reproducibility.

    TODO: Figure out how to scope the seed reset to a single test case
    (for example when traces are generated in parallel, how can manual_seed
    be made local?).
    """

    def __init__(self):
        torch.manual_seed(0)

    # TODO: Add zeros/ones/etc. as convenient.
    def rand(self, *sizes, low=0.0, high=1.0):
        """Return a tensor of shape `sizes` filled by `uniform_(low, high)`."""
        values = torch.empty(sizes)
        values.uniform_(low, high)
        return values

    def randint(self, *sizes, low=0, high=10):
        """Return `torch.randint(low, high, sizes)`."""
        return torch.randint(low=low, high=high, size=sizes)

    def nans(self, *sizes):
        """Return a tensor of shape `sizes` in which every element is NaN."""
        return torch.empty(sizes).fill_(torch.nan)
class Test(NamedTuple):
    """Description of one test, as produced by the test frontend."""

    # Stable name used for error reporting.
    #
    # The stability of this name also helps backends that want to generate
    # their own lower-level test suites on top of this framework.
    #
    # Such backends are expected to need extra metadata describing their test
    # configurations, so a unique key to attach that information to is
    # important.
    unique_name: str

    # Produces the module under test.
    # It is a callable so that the module can be created lazily.
    program_factory: Callable[[], torch.nn.Module]

    # Provides the external stimuli to the module.
    # The first parameter is a torch.nn.Module (or, actually, a `_Tracer`
    # wrapping that module).
    # The second parameter is a `TestUtils` instance for convenience.
    program_invoker: Callable[[Any, TestUtils], None]
class TestResult(NamedTuple):
    """Outcome of compiling and running a single `Test`."""

    # Stable unique name for error reporting and test suite configuration.
    # Should match Test.unique_name for the corresponding test.
    #
    # Tests often need additional data (expected pass/fail status, desired
    # test configurations, etc.); this name is the key such data is associated
    # with, so this class does not have to grow a field for every possible
    # requirement of the test framework.
    #
    # The same holds, even more strongly, for backends that generate their own
    # lower-level test suites from this framework, because they cannot simply
    # extend this class.
    unique_name: str

    # A string describing the failure if compilation failed.
    # When this is not None, the `trace` and `golden_trace` fields are None,
    # and vice-versa.
    compilation_error: Optional[str]

    # A string describing the failure if running failed.
    # When this is not None, the `trace` and `golden_trace` fields are None,
    # and vice-versa.
    runtime_error: Optional[str]

    # The trace produced by the backend.
    trace: Optional[Trace]

    # The golden trace which `trace` is expected to match.
    golden_trace: Optional[Trace]
class _Tracer:
    """Proxy for a `torch.nn.Module` that logs the calls made through it.

    Each call appends a `TraceItem` with the call's inputs and output to the
    shared trace. Attribute access yields another `_Tracer`, so calls reached
    through nested attributes are traced too, under a dotted symbol name.
    """

    def __init__(self, wrapped, property_base_path: List[str], trace: Trace):
        self.__wrapped__ = wrapped
        self.__trace__ = trace
        self.__property_base_path__ = property_base_path

    def __call__(self, *args, **kwargs):
        # Snapshot the positional inputs before the call: the wrapped callable
        # may mutate its input tensors in place. Keyword arguments are
        # forwarded to the callable but are not stored in the trace.
        recorded_inputs = list(map(clone_torch_script_value, args))
        result = self.__wrapped__(*args, **kwargs)
        symbol = ".".join(self.__property_base_path__)
        self.__trace__.append(
            TraceItem(symbol=symbol, inputs=recorded_inputs, output=result))
        return result

    def __getattr__(self, name):
        # Wrap the attribute so that calls on it are recorded in the same
        # trace, with `name` appended to the symbol path.
        attribute = getattr(self.__wrapped__, name)
        attribute_path = self.__property_base_path__ + [name]
        return _Tracer(attribute, attribute_path, self.__trace__)
def generate_golden_trace(test: Test) -> Trace:
    """Record a trace by running the test against the original program.

    Provided the original program is deterministic, the returned trace is
    suitable as a golden trace to compare against.
    """
    recorded: Trace = []
    program = test.program_factory()
    # Every call the invoker makes through the tracer lands in `recorded`.
    test.program_invoker(_Tracer(program, [], recorded), TestUtils())
    return recorded
def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any:
    """Compile and run a single test, returning a `TestResult`.

    An exception raised while generating the golden trace or compiling is
    reported through `compilation_error`; an exception raised by `config.run`
    is reported through `runtime_error`. In both cases the value is the
    formatted traceback and the two trace fields are None. On success the
    result carries a clone of the backend's trace and the golden trace.

    Args:
        test: The test to execute.
        config: The config used to compile and run the program.
        verbose: If true, print progress messages to stderr.
    """

    def failed(compilation_error=None, runtime_error=None):
        # Result for a test that produced no traces.
        return TestResult(unique_name=test.unique_name,
                          compilation_error=compilation_error,
                          runtime_error=runtime_error,
                          trace=None,
                          golden_trace=None)

    try:
        golden_trace = generate_golden_trace(test)
        if verbose:
            print(f"Compiling {test.unique_name}...", file=sys.stderr)
        compiled = config.compile(test.program_factory())
    except Exception:
        return failed(compilation_error=traceback.format_exc())

    try:
        if verbose:
            print(f"Running {test.unique_name}...", file=sys.stderr)
        trace = config.run(compiled, golden_trace)
    except Exception:
        return failed(runtime_error=traceback.format_exc())

    return TestResult(unique_name=test.unique_name,
                      compilation_error=None,
                      runtime_error=None,
                      trace=clone_trace(trace),
                      golden_trace=golden_trace)
# Value put on a task queue to tell a worker that there is no more work.
queue_sentinel = "QUEUE_SENTINEL"


def run_workers_in_parallel(task_queue: mp.Queue, worker, num_processes: int):
    """Run `worker(task_queue)` in forked processes and wait for all of them.

    Args:
        task_queue: Queue passed to every worker as its only argument.
        worker: Callable executed in each process.
        num_processes: Number of processes to start.
    """

    def launch():
        # Start one worker process using the "fork" start method.
        proc = mp.get_context("fork").Process(target=worker,
                                              args=(task_queue, ))
        proc.start()
        return proc

    workers = [launch() for _ in range(num_processes)]
    # One sentinel per started process is appended to the queue.
    for _ in workers:
        task_queue.put(queue_sentinel)
    for proc in workers:
        proc.join()
def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=False) -> List[TestResult]:
    """Invoke the given `Test`'s with the provided `TestConfig`.

    Tests run either in the current process, one after another, or spread
    over several forked worker processes.

    Args:
        tests: The tests to run.
        config: The config used to compile and run every test.
        sequential: If true, run all tests in the current process.
        verbose: If true, `compile_and_run_test` prints progress messages to
            stderr for each test, in both the sequential and parallel paths.

    Returns:
        The results of the tests. In the sequential path they follow the
        order of `tests`; in the parallel path they are sorted by
        `unique_name`. A test whose worker process terminated without
        producing a result is reported with a generic `runtime_error`.

    Note:
        The parallel path calls `torch.autograd.set_grad_enabled(False)` in
        the calling process and does not re-enable it afterwards.
    """
    # Nothing to run: skip the process count computation and the
    # multiprocessing setup altogether.
    if not tests:
        return []

    num_processes = min(int(mp.cpu_count() * 1.1), len(tests))

    # TODO: We've noticed that on certain 2 core machine parallelizing the
    # tests makes the llvm backend legacy pass manager 20x slower than using a
    # single process. Need to investigate the root cause eventually. This is a
    # hack to work around this issue.
    # Also our multiprocessing implementation is not the most efficient, so
    # the benefit at core count 2 is probably not worth it anyway.
    if mp.cpu_count() == 2:
        num_processes = 1

    # With a single process, run without any of the multiprocessing
    # machinery. In theory the machinery should work, but any crash in the
    # testing process seems to cause a cascade of failures resulting in
    # undecipherable error messages.
    if num_processes == 1 or sequential:
        return [compile_and_run_test(test, config, verbose) for test in tests]

    # Sort the tests to make output nicer.
    tests = sorted(tests, key=lambda t: t.unique_name)

    # To run e2e tests in parallel:
    # The test names are put into a synchronized queue. Multiple worker
    # processes are created. Each worker takes one test at a time from the
    # queue to compile and execute it. If the test finishes, whether failed or
    # passed, its result is put into a synchronized list which collects the
    # results from all worker processes.
    manager = mp.Manager()
    tests_queue = manager.Queue()
    sync_results = manager.list()

    # This is needed because autograd does not support crossing process
    # boundaries.
    torch.autograd.set_grad_enabled(False)

    for test in tests:
        tests_queue.put(test.unique_name)
    tests_dict = {test.unique_name: test for test in tests}

    def worker(queue: mp.Queue):
        for test_name in iter(queue.get, queue_sentinel):
            # Forward `verbose` so that the parallel path reports progress
            # exactly like the sequential path does.
            sync_results.append(
                compile_and_run_test(tests_dict[test_name], config, verbose))

    run_workers_in_parallel(tests_queue, worker, num_processes)

    # Copy the results out of the manager-backed list once.
    results = list(sync_results)
    tests_with_results = {result.unique_name for result in results}

    # For processes that crashed due to a compile time or runtime error,
    # the error outputs are printed out all together but no TestResult is
    # produced when the process crashed.
    # TODO: Find a clean way to capture the output from crashed process and
    # create more detailed runtime_error for those tests.
    aborted_tests = set(tests_dict) - tests_with_results
    results.extend(
        TestResult(
            unique_name=aborted_test_name,
            compilation_error=None,
            runtime_error=
            "Testing process terminated. Either the compiler crashed or the compiled code crashed at runtime.\n",
            trace=None,
            golden_trace=None) for aborted_test_name in aborted_tests)
    results.sort(key=lambda result: result.unique_name)
    return results