diff --git a/build_tools/torchscript_e2e_heavydep_tests/main.py b/build_tools/torchscript_e2e_heavydep_tests/main.py
index eed396ba0..d05f5c12a 100644
--- a/build_tools/torchscript_e2e_heavydep_tests/main.py
+++ b/build_tools/torchscript_e2e_heavydep_tests/main.py
@@ -4,14 +4,8 @@
 # Also available under a BSD-style license. See LICENSE.
 
 import argparse
-import os
-import pickle
-import torch
-
-from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
-from torch_mlir_e2e_test.torchscript.framework import SerializableTest, generate_golden_trace
-from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations
+from torch_mlir_e2e_test.torchscript.serialization import serialize_all_tests_to
 
 from . import hf_sequence_classification
 
@@ -25,22 +19,7 @@ def _get_argparse():
 
 def main():
     args = _get_argparse().parse_args()
-    serializable_tests = []
-    for test in GLOBAL_TEST_REGISTRY:
-        trace = generate_golden_trace(test)
-        module = torch.jit.script(test.program_factory())
-        torchscript_module_bytes = module.save_to_buffer({
-            "annotations.pkl":
-            pickle.dumps(extract_serializable_annotations(module))
-        })
-        serializable_tests.append(
-            SerializableTest(unique_name=test.unique_name,
-                             program=torchscript_module_bytes,
-                             trace=trace))
-    for test in serializable_tests:
-        with open(os.path.join(args.output_dir, f"{test.unique_name}.pkl"),
-                  "wb") as f:
-            pickle.dump(test, f)
+    serialize_all_tests_to(args.output_dir)
 
 
 if __name__ == "__main__":
diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py
index 864608a99..7edad130d 100644
--- a/e2e_testing/torchscript/main.py
+++ b/e2e_testing/torchscript/main.py
@@ -12,6 +12,7 @@ import sys
 from torch_mlir_e2e_test.torchscript.framework import TestConfig, run_tests
 from torch_mlir_e2e_test.torchscript.reporting import report_results
 from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
+from torch_mlir_e2e_test.torchscript.serialization import deserialize_all_tests_from
 
 # Available test configs.
 from torch_mlir_e2e_test.torchscript.configs import (
@@ -72,13 +73,10 @@ for more information on building these artifacts.
 
 def main():
     args = _get_argparse().parse_args()
-    all_tests = list(GLOBAL_TEST_REGISTRY)
     if args.serialized_test_dir:
-        for root, dirs, files in os.walk(args.serialized_test_dir):
-            for filename in files:
-                with open(os.path.join(root, filename), 'rb') as f:
-                    all_tests.append(pickle.load(f).as_test())
-    all_test_unique_names = set(test.unique_name for test in all_tests)
+        deserialize_all_tests_from(args.serialized_test_dir)
+    all_test_unique_names = set(
+        test.unique_name for test in GLOBAL_TEST_REGISTRY)
 
     # Find the selected config.
     if args.config == 'refbackend':
@@ -117,7 +115,7 @@ def main():
 
     # Find the selected tests, and emit a diagnostic if none are found.
     tests = [
-        test for test in all_tests
+        test for test in GLOBAL_TEST_REGISTRY
         if re.match(args.filter, test.unique_name)
     ]
     if len(tests) == 0:
@@ -125,7 +123,7 @@ def main():
             f'ERROR: the provided filter {args.filter!r} does not match any tests'
         )
         print('The available tests are:')
-        for test in all_tests:
+        for test in GLOBAL_TEST_REGISTRY:
             print(test.unique_name)
         sys.exit(1)
 
diff --git a/python/torch_mlir_e2e_test/torchscript/annotations.py b/python/torch_mlir_e2e_test/torchscript/annotations.py
index bb578dd91..5a07b306e 100644
--- a/python/torch_mlir_e2e_test/torchscript/annotations.py
+++ b/python/torch_mlir_e2e_test/torchscript/annotations.py
@@ -68,62 +68,3 @@ def annotate_args(annotations: List[Optional[ArgAnnotation]]):
         return fn
 
     return decorator
-
-
-class SerializableMethodAnnotation(NamedTuple):
-    method_name: str
-    export: Optional[bool]
-    arg_annotations: Optional[List[ArgAnnotation]]
-
-
-class SerializableModuleAnnotations(NamedTuple):
-    method_annotations: List[SerializableMethodAnnotation]
-    submodule_annotations: List[Tuple[str, "SerializableModuleAnnotations"]]
-
-
-def extract_serializable_annotations(
-        module: torch.nn.Module) -> SerializableModuleAnnotations:
-    module_annotations = SerializableModuleAnnotations(
-        method_annotations=[], submodule_annotations=[])
-    # Extract information on methods.
-    for method_name, method in module.__dict__.items():
-        # See if it is a method.
-        if not callable(method):
-            continue
-        export = None
-        arg_annotations = None
-        if hasattr(method, TORCH_MLIR_EXPORT_ATTR_NAME):
-            export = method._torch_mlir_export
-        if hasattr(method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME):
-            arg_annotations = method._torch_mlir_arg_annotations
-        if export is not None and arg_annotations is not None:
-            module_annotations.method_annotations.append(
-                SerializableMethodAnnotation(method_name=method_name,
-                                             export=export,
-                                             arg_annotations=arg_annotations))
-
-    # Recurse.
-    for name, child in module.named_children():
-        annotations = extract_serializable_annotations(child)
-        module_annotations.submodule_annotations.append((name, annotations))
-    return module_annotations
-
-
-def apply_serializable_annotations(module: torch.nn.Module,
-                                   annotations: SerializableModuleAnnotations):
-    # Apply annotations to methods.
-    for method_annotation in annotations.method_annotations:
-        # Imitate use of the decorators to keep a source of truth there.
-        if method_annotation.export is not None:
-            setattr(module, method_annotation.method_name,
-                    export(getattr(module, method_annotation.method_name)))
-        if method_annotation.arg_annotations is not None:
-            setattr(
-                module, method_annotation.method_name,
-                annotate_args(method_annotation.arg_annotations)(getattr(
-                    module, method_annotation.method_name)))
-
-    # Recurse.
-    for name, submodule_annotations in annotations.submodule_annotations:
-        child = getattr(module, name)
-        apply_serializable_annotations(child, submodule_annotations)
diff --git a/python/torch_mlir_e2e_test/torchscript/framework.py b/python/torch_mlir_e2e_test/torchscript/framework.py
index b11318e14..4dd3bdea9 100644
--- a/python/torch_mlir_e2e_test/torchscript/framework.py
+++ b/python/torch_mlir_e2e_test/torchscript/framework.py
@@ -23,14 +23,10 @@ compiling or TorchScript'ing).
 import abc
 from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Dict
 
-import io
-import pickle
 import traceback
 
 import torch
 
-from .annotations import apply_serializable_annotations
-
 TorchScriptValue = Union[int, float, List['TorchScriptValue'],
                          Dict['TorchScriptValue',
@@ -179,57 +175,6 @@ class Test(NamedTuple):
     program_invoker: Callable[[Any, TestUtils], None]
 
 
-class SerializableTest(NamedTuple):
-    """A self-contained representation of a test that can be pickled.
-
-    We use serialized TorchScript programs here for two reasons:
-    1. The PyTorch pickling story isn't great, so in order to reliably pickle
-       this class, we rely on having the serialized bytes for the TorchScript
-       module already given to us.
-    2. The choice of a TorchScript module vs `torch.nn.Module` boils down to
-       the fact that `torch.nn.Module` cannot be deserialized without pulling
-       in the same set of Python dependencies that were used to serialize it
-       in the first place. This would defeat one of the
-       main use cases of this class, which is to transport a test from an
-       environment with a set of heavy dependencies to a dependency-light one.
-       Since TorchScript modules are self-contained, they fit the bill
-       perfectly.
-    """
-    # See unique_name on `Test`.
-    unique_name: str
-    # Serialized TorchScript program.
-    program: bytes
-    # Trace for execution testing.
-    trace: Trace
-
-    def as_test(self) -> Test:
-        """Create a `Test` from this class."""
-        # Conform the serialized program to the interface expected by Test.
-        # This is a bit of a hack, but it's the only way to keep the layering
-        # straight.
-        def factory():
-            _extra_files = {"annotations.pkl": ""}
-            module = torch.jit.load(io.BytesIO(self.program),
-                                    _extra_files=_extra_files)
-            # Load the pickled annotations.
-            annotations = pickle.loads(_extra_files["annotations.pkl"])
-            apply_serializable_annotations(module, annotations)
-            return module
-
-        def invoker(module, tu):
-            for item in self.trace:
-                attr = module
-                for part in item.symbol.split("."):
-                    attr = getattr(attr, part)
-                attr(*item.inputs)
-
-        return Test(
-            unique_name=self.unique_name,
-            program_factory=factory,
-            program_invoker=invoker,
-        )
-
-
 class TestResult(NamedTuple):
     # Stable unique name for error reporting and test suite configuration.
     #
diff --git a/python/torch_mlir_e2e_test/torchscript/serialization.py b/python/torch_mlir_e2e_test/torchscript/serialization.py
new file mode 100644
index 000000000..0f80c8a07
--- /dev/null
+++ b/python/torch_mlir_e2e_test/torchscript/serialization.py
@@ -0,0 +1,173 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+"""
+# Serialization utilities for the end-to-end test framework.
+
+It is sometimes useful to be able to serialize tests to disk, such as when
+multiple tests require mutually incompatible PyTorch versions. This takes
+advantage of the strong backwards compatibility of serialized TorchScript, which
+generally allows all such programs to be loaded by an appropriately recent
+PyTorch.
+"""
+
+from typing import List, Tuple, Optional, NamedTuple
+
+import io
+import os
+import pickle
+
+import torch
+
+from .framework import generate_golden_trace, Test, Trace
+from .annotations import ArgAnnotation, export, annotate_args, TORCH_MLIR_EXPORT_ATTR_NAME, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
+from .registry import GLOBAL_TEST_REGISTRY
+
+# ==============================================================================
+# Annotation serialization
+# ==============================================================================
+
+
+class SerializableMethodAnnotation(NamedTuple):
+    method_name: str
+    export: Optional[bool]
+    arg_annotations: Optional[List[ArgAnnotation]]
+
+
+class SerializableModuleAnnotations(NamedTuple):
+    method_annotations: List[SerializableMethodAnnotation]
+    submodule_annotations: List[Tuple[str, "SerializableModuleAnnotations"]]
+
+
+def extract_serializable_annotations(
+        module: torch.nn.Module) -> SerializableModuleAnnotations:
+    module_annotations = SerializableModuleAnnotations(
+        method_annotations=[], submodule_annotations=[])
+    # Extract information on methods.
+    for method_name, method in module.__dict__.items():
+        # See if it is a method.
+        if not callable(method):
+            continue
+        export = None
+        arg_annotations = None
+        if hasattr(method, TORCH_MLIR_EXPORT_ATTR_NAME):
+            export = method._torch_mlir_export
+        if hasattr(method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME):
+            arg_annotations = method._torch_mlir_arg_annotations
+        if export is not None and arg_annotations is not None:
+            module_annotations.method_annotations.append(
+                SerializableMethodAnnotation(method_name=method_name,
+                                             export=export,
+                                             arg_annotations=arg_annotations))
+
+    # Recurse.
+    for name, child in module.named_children():
+        annotations = extract_serializable_annotations(child)
+        module_annotations.submodule_annotations.append((name, annotations))
+    return module_annotations
+
+
+def apply_serializable_annotations(module: torch.nn.Module,
+                                   annotations: SerializableModuleAnnotations):
+    # Apply annotations to methods.
+    for method_annotation in annotations.method_annotations:
+        # Imitate use of the decorators to keep a source of truth there.
+        if method_annotation.export is not None:
+            setattr(module, method_annotation.method_name,
+                    export(getattr(module, method_annotation.method_name)))
+        if method_annotation.arg_annotations is not None:
+            setattr(
+                module, method_annotation.method_name,
+                annotate_args(method_annotation.arg_annotations)(getattr(
+                    module, method_annotation.method_name)))
+
+    # Recurse.
+    for name, submodule_annotations in annotations.submodule_annotations:
+        child = getattr(module, name)
+        apply_serializable_annotations(child, submodule_annotations)
+
+# ==============================================================================
+# Serializable test definition
+# ==============================================================================
+
+
+class SerializableTest(NamedTuple):
+    """A self-contained representation of a test that can be pickled.
+
+    We use serialized TorchScript programs here for two reasons:
+    1. The PyTorch pickling story isn't great, so in order to reliably pickle
+       this class, we rely on having the serialized bytes for the TorchScript
+       module already given to us.
+    2. The choice of a TorchScript module vs `torch.nn.Module` boils down to
+       the fact that `torch.nn.Module` cannot be deserialized without pulling
+       in the same set of Python dependencies that were used to serialize it
+       in the first place. This would defeat one of the
+       main use cases of this class, which is to transport a test from an
+       environment with a set of heavy dependencies to a dependency-light one.
+       Since TorchScript modules are self-contained, they fit the bill
+       perfectly.
+    """
+    # See unique_name on `Test`.
+    unique_name: str
+    # Serialized TorchScript program.
+    program: bytes
+    # Trace for execution testing.
+    trace: Trace
+
+    def as_test(self) -> Test:
+        """Create a `Test` from this class."""
+        # Conform the serialized program to the interface expected by Test.
+        # This is a bit of a hack, but it's the only way to keep the layering
+        # straight.
+        def factory():
+            _extra_files = {"annotations.pkl": ""}
+            module = torch.jit.load(io.BytesIO(self.program),
+                                    _extra_files=_extra_files)
+            # Load the pickled annotations.
+            annotations = pickle.loads(_extra_files["annotations.pkl"])
+            apply_serializable_annotations(module, annotations)
+            return module
+
+        def invoker(module, tu):
+            for item in self.trace:
+                attr = module
+                for part in item.symbol.split("."):
+                    attr = getattr(attr, part)
+                attr(*item.inputs)
+
+        return Test(
+            unique_name=self.unique_name,
+            program_factory=factory,
+            program_invoker=invoker,
+        )
+
+
+# ==============================================================================
+# Filesystem operations
+# ==============================================================================
+
+def serialize_all_tests_to(output_dir: str):
+    serializable_tests = []
+    for test in GLOBAL_TEST_REGISTRY:
+        trace = generate_golden_trace(test)
+        module = torch.jit.script(test.program_factory())
+        torchscript_module_bytes = module.save_to_buffer({
+            "annotations.pkl":
+            pickle.dumps(extract_serializable_annotations(module))
+        })
+        serializable_tests.append(
+            SerializableTest(unique_name=test.unique_name,
+                             program=torchscript_module_bytes,
+                             trace=trace))
+    for test in serializable_tests:
+        with open(os.path.join(output_dir, f"{test.unique_name}.pkl"),
+                  "wb") as f:
+            pickle.dump(test, f)
+
+
+def deserialize_all_tests_from(serialized_test_dir: str):
+    for root, _, files in os.walk(serialized_test_dir):
+        for filename in files:
+            with open(os.path.join(root, filename), 'rb') as f:
+                GLOBAL_TEST_REGISTRY.append(pickle.load(f).as_test())
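
A minimal round-trip sketch of the two new entry points, mirroring the callers shown above in build_tools/torchscript_e2e_heavydep_tests/main.py and e2e_testing/torchscript/main.py. The standalone driver and the /tmp/serialized_tests scratch directory are hypothetical illustrations, not part of this patch; only serialize_all_tests_to, deserialize_all_tests_from, and GLOBAL_TEST_REGISTRY come from the code above.

import os

from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
from torch_mlir_e2e_test.torchscript.serialization import (
    serialize_all_tests_to, deserialize_all_tests_from)

# Environment A (heavy dependencies): importing the test modules populates
# GLOBAL_TEST_REGISTRY; serialize_all_tests_to then scripts, traces, and
# pickles each registered test into one .pkl file per test in the directory.
os.makedirs("/tmp/serialized_tests", exist_ok=True)  # hypothetical scratch dir
serialize_all_tests_to("/tmp/serialized_tests")

# Environment B (dependency-light runner): walking the same directory appends
# the deserialized tests back onto GLOBAL_TEST_REGISTRY, so the runner in
# e2e_testing/torchscript/main.py can filter and run them as usual.
deserialize_all_tests_from("/tmp/serialized_tests")
print(sorted(test.unique_name for test in GLOBAL_TEST_REGISTRY))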