# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. """ # End-to-end testing framework for TorchScript. For the purposes of this framework, "end to end" means the first "end" is a `torch.nn.Module`, and the second "end" is execution. ## Architecture A program for this testing framework is considered to be a `torch.nn.Module`, which has a public interface consisting of its methods and instance attributes. A test in the framework consists conceputally of a list of calls into the methods of a module (TODO: extend to instance attributes). It is expected that the outputs match between the program run on a backend (controlled by a TestConfig) and a golden trace obtained by running on native Torch (without compiling or TorchScript'ing). """ import abc from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Dict import sys import traceback import torch import torch.multiprocessing as mp TorchScriptValue = Union[int, float, List['TorchScriptValue'], Dict['TorchScriptValue', 'TorchScriptValue'], torch.Tensor] class TraceItem(NamedTuple): # The externally visible symbol name that is called. # For example `"forward"` or `"submodule.forward"`. symbol: str # The inputs to the call. inputs: List[TorchScriptValue] # The output from the call. # In Python, there is only one output from a function. It might be a tuple # in case of "multiple results". # Sometimes this field is treated as golden outputs from a test. # Sometimes this field is treated as ignored, such as the input trace # provided to `TestConfig.run`. output: TorchScriptValue # A trace of invocations to the program. # This is an ordered sequence of external invocations to a program's # public boundary. Trace = List[TraceItem] # Clone all the tensor values. def clone_torch_script_value(v: TorchScriptValue): if isinstance(v, torch.Tensor): return v.clone() if isinstance(v, tuple): return tuple(clone_torch_script_value(field) for field in v) if isinstance(v, list): return [clone_torch_script_value(item) for item in v] if isinstance(v, dict): return { clone_torch_script_value(key): clone_torch_script_value(val) for key, val in v.items() } if isinstance(v, float) or isinstance(v, int) or isinstance(v, str): return v assert False, "unhandled cloning of TorchScriptValue value type" # This clone helper is used to work around issues with output tensors when # using multiprocessing module to run tests. The error happens for tests like # ContiguousModule_basic where the output tensor aliases with an input tensor. # When the output tensor is not cloned, the testing trace would be modified for # unknown reason when passed through the shared memory through synchronized # queue for example. # TODO: Figure out the root cause of the failure and fix properly. def clone_trace(trace: Trace) -> Trace: return [ TraceItem(symbol=item.symbol, inputs=item.inputs, output=clone_torch_script_value(item.output)) for item in trace ] # A type shared between the result of `TestConfig.compile` and the input # to `TestConfig.run`. Each backend will likely have a different definition of # this type. CompiledArtifact = TypeVar('CompiledArtifact') class TestConfig(abc.ABC): """The interface implemented by backends to run tests. The testing framework expects to be able to call `compile` to compile a torch.nn.Module, and then pass the compiled artifact to `run` to run it. Note that the definition of "compiled artifact" here is quite loose, and this interface allows for many different use cases besides simple testing. For example, this interface can be overridden to be a "data collector" to gather information across all the test cases. For example, a compiler backend could override "compile" to just return some IR at a useful intermediate abstraction level (rather than the final compiled artifact), and then have "run" save this intermediate IR + the trace as input to some lower-level software stack's testing format. The set of TestConfig's is expected to be pluggable and provided by users to suit their own needs. We provide a few configs out of the box in the `configs` submodule of this package, but those are intended to be for basic inspiration and enough for our own testing. Backends to torch-mlir will likely have more elaborate TestConfig's, such as `compile` being "compile for such-and-such DSP with these vectorization cost model flags" and `run` being "connect to Android phone with device ID 1234 and upload a program to run on it's DSP core, and also set power throttling settings to 'performance'". That is also why this class is not called "backend", as it encapsulates potentially many specific details of the test configuration process as well. There isn't a general way to disentangle test configuration from the compile/run process specific to a logical backend, since each backend (compiler backend and runtime target) will have an arbitrarily wild and wonderful set of possible configurations that we cannot predict. """ # This is not a frontend-lowered module, to allow various testing at the PyTorch level. # We can have a helper class LinalgOnTensorsBackendTestConfig which does that. @abc.abstractmethod def compile(self, program: torch.nn.Module) -> CompiledArtifact: """Compile the provided torch.nn.Module into a compiled artifact""" pass # Any should match result of `compile`. @abc.abstractmethod def run(self, artifact: CompiledArtifact, trace: Trace) -> Trace: """Run the compiled artifact produced by `compile`. The backend should load the compiled artifact and call the symbol names listed in `trace` with their respective inputs (the outputs of `trace` should be ignored). A new identical trace with outputs populated should be returned. This method should assume that `artifact` is being shared with multiple parallel invocations of `run`, and so it should not be mutated. This property is typicaly trivially satisfied for a true "compiled artifact", but some backends don't directly involve a compiled artifact per se (like a backend for which `CompiledArtifact` is `torch.nn.Module` and `run` just invokes the torch.nn.Module itself) Args: artifact: a compiled artifact produced by `compile`. trace: The external invocations to stimulate the module. Returns: A trace with outputs recorded according to the results of running on this backend. """ pass # Utilities for common testing trace generation. # Also, resets the random seed for reproducibility. # TODO: If generating in parallel, how to have manual_seed be local? class TestUtils: """Utilities for executing a test. Test cases are provided an instance of this class to make test cases more succinct. For reproducibility, this class also resets the random seed. TODO: Figure out how to seed reset properly scoped to just a test case (such as when running tests in parallel) """ def __init__(self): torch.manual_seed(0) # TODO: Add zeros/ones/etc. as convenient. def rand(self, *sizes, low=0.0, high=1.0): return torch.empty(sizes).uniform_(low, high) def randint(self, *sizes, low=0, high=10): return torch.randint(low, high, sizes) def nans(self, *sizes): vals = torch.empty(sizes) vals[...] = torch.nan return vals class Test(NamedTuple): """A description of a test as produced by the test frontend. """ # Stable name for error reporting. # # This name's stability is also useful for backend, which want to # generate their own lower-level test suites based on this framework. # # It is expected that those backends will need additional # metadata to describe their test configurations, so having a unique # key to keep that information associated is important. unique_name: str # A callable which produces the module under test. # This is a callable to allow lazily creating the module. program_factory: Callable[[], torch.nn.Module] # A callable which provides external stimuli to the module. # The first parameter is a torch.nn.Module (or a `_Tracer` wrapping that # module, actually). # The secon parameter is a `TestUtils` instance for convenience. program_invoker: Callable[[Any, TestUtils], None] class TestResult(NamedTuple): # Stable unique name for error reporting and test suite configuration. # # Tests frequently need some additional data (such as expected pass/fail # status, desired test configurations, etc.), and this gives a key to # associate to. This avoids extending this class arbitrarily for every # possible requirement from the test framework. # # This name is also useful for backends that are generating their own # lower-level test suites from this framework for the same reasons, though # those reasons are stronger because we cannot simply extend this # class. unique_name: str # Should match Test.unique_name for corresponding test. # If compilation failed, a string describing the failure. # If this is not None, then the `trace` and `golden_trace` fields are None, # and vice-versa. compilation_error: Optional[str] # If runtime failed, a string describing the failure. # If this is not None, then the `trace` and `golden_trace` fields are None, # and vice-versa. runtime_error: Optional[str] # The trace produced by the backend. trace: Optional[Trace] # The golden trace which `trace` is expected to match. golden_trace: Optional[Trace] class _Tracer: """Wrapper around a `torch.nn.Module` that records calls into it. The inputs and outputs of each call are recorded in a Trace. Recursive property accesses are also traced. """ def __init__(self, wrapped, property_base_path: List[str], trace: Trace): self.__wrapped__ = wrapped self.__trace__ = trace self.__property_base_path__ = property_base_path def __call__(self, *args, **kwargs): # Clone the inputs to capture the original tensors values. This is # needed because inplace mutation might happen to the input tensors. inputs = [clone_torch_script_value(arg) for arg in args] output = self.__wrapped__(*args, **kwargs) self.__trace__.append( TraceItem(symbol=".".join(self.__property_base_path__), inputs=inputs, output=output)) return output def __getattr__(self, name): return _Tracer(getattr(self.__wrapped__, name), self.__property_base_path__ + [name], self.__trace__) def generate_golden_trace(test: Test) -> Trace: """Generate a trace with the original program. If the original program is deterministic, then this the produced trace is suitable as a golden trace to compare against. """ trace = [] tracer = _Tracer(test.program_factory(), [], trace) test.program_invoker(tracer, TestUtils()) return trace def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any: try: golden_trace = generate_golden_trace(test) if verbose: print(f"Compiling {test.unique_name}...", file=sys.stderr) compiled = config.compile(test.program_factory()) except Exception as e: return TestResult(unique_name=test.unique_name, compilation_error="".join( traceback.format_exception( type(e), e, e.__traceback__)), runtime_error=None, trace=None, golden_trace=None) try: if verbose: print(f"Running {test.unique_name}...", file=sys.stderr) trace = config.run(compiled, golden_trace) except Exception as e: return TestResult(unique_name=test.unique_name, compilation_error=None, runtime_error="".join( traceback.format_exception( type(e), e, e.__traceback__)), trace=None, golden_trace=None) return TestResult(unique_name=test.unique_name, compilation_error=None, runtime_error=None, trace=clone_trace(trace), golden_trace=golden_trace) queue_sentinel = "QUEUE_SENTINEL" def run_workers_in_parallel(task_queue: mp.Queue, worker, num_processes: int): processes = [] for i in range(num_processes): p = mp.get_context("fork").Process(target=worker, args=(task_queue, )) p.start() processes.append(p) for i in range(num_processes): task_queue.put(queue_sentinel) for p in processes: p.join() def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=False) -> List[TestResult]: """Invoke the given `Test`'s with the provided `TestConfig`.""" num_processes = min(int(mp.cpu_count() * 1.1), len(tests)) # TODO: We've noticed that on certain 2 core machine parallelizing the tests # makes the llvm backend legacy pass manager 20x slower than using a # single process. Need to investigate the root cause eventually. This is a # hack to work around this issue. # Also our multiprocessing implementation is not the most efficient, so # the benefit at core count 2 is probably not worth it anyway. if mp.cpu_count() == 2: num_processes = 1 # Sort the tests to make output nicer. tests = list(sorted(tests, key=lambda t: t.unique_name)) # TODO: If num_processes == 1, then run without any of the multiprocessing # machinery. In theory it should work, but any crash in the testing process # seems to cause a cascade of failures resulting in undecipherable error # messages. if num_processes == 1 or sequential: return [compile_and_run_test(test, config, verbose) for test in tests] # To run e2e tests in parallel: # The tests are put into a synchronized queue. Multiple worker processes are # created. Each worker takes one test at a time from the queue to compile # and execute it. If the test finishes, whether failed or passed, the result # of the test is put into a synchronized list which collects the tests # results from all worker processes. manager = mp.Manager() tests_queue = manager.Queue() sync_results = manager.list() # This is needed because autograd does not support crossing process # boundaries. torch.autograd.set_grad_enabled(False) for test in tests: tests_queue.put(test.unique_name) tests_dict = {test.unique_name: test for test in tests} def worker(tests_queue: mp.Queue): for test_name in iter(tests_queue.get, queue_sentinel): sync_results.append( compile_and_run_test(tests_dict[test_name], config)) run_workers_in_parallel(tests_queue, worker, num_processes) tests_with_results = {result.unique_name for result in sync_results} all_tests = {test.unique_name for test in tests} # For processes that are crashed due to compile time or runtime error, # the error outputs are printed out all together but no TestResult is # produced when the process crashed. # TODO: Find a clean way to capture the output from crashed process and # create more detailed runtime_error for those tests. aborted_tests = all_tests - tests_with_results aborted_tests_results = [ TestResult( unique_name=aborted_test_name, compilation_error=None, runtime_error= "Testing process terminated. Either the compiler crashed or the compiled code crashed at runtime.\n", trace=None, golden_trace=None) for aborted_test_name in aborted_tests ] results = [result for result in sync_results] results.extend(aborted_tests_results) results.sort(key=lambda result: result.unique_name) return results