Centralize all test serialization logic.

pull/709/head
Sean Silva 2022-03-25 21:47:11 +00:00
parent e59a91620a
commit 0378c75b35
5 changed files with 181 additions and 145 deletions

View File

@@ -4,14 +4,8 @@
# Also available under a BSD-style license. See LICENSE.
import argparse
import os
import pickle
import torch
from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
from torch_mlir_e2e_test.torchscript.framework import SerializableTest, generate_golden_trace
from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations
from torch_mlir_e2e_test.torchscript.serialization import serialize_all_tests_to
from . import hf_sequence_classification
@@ -25,22 +19,7 @@ def _get_argparse():
def main():
args = _get_argparse().parse_args()
serializable_tests = []
for test in GLOBAL_TEST_REGISTRY:
trace = generate_golden_trace(test)
module = torch.jit.script(test.program_factory())
torchscript_module_bytes = module.save_to_buffer({
"annotations.pkl":
pickle.dumps(extract_serializable_annotations(module))
})
serializable_tests.append(
SerializableTest(unique_name=test.unique_name,
program=torchscript_module_bytes,
trace=trace))
for test in serializable_tests:
with open(os.path.join(args.output_dir, f"{test.unique_name}.pkl"),
"wb") as f:
pickle.dump(test, f)
serialize_all_tests_to(args.output_dir)
if __name__ == "__main__":

View File

@@ -12,6 +12,7 @@ import sys
from torch_mlir_e2e_test.torchscript.framework import TestConfig, run_tests
from torch_mlir_e2e_test.torchscript.reporting import report_results
from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
from torch_mlir_e2e_test.torchscript.serialization import deserialize_all_tests_from
# Available test configs.
from torch_mlir_e2e_test.torchscript.configs import (
@@ -72,13 +73,10 @@ for more information on building these artifacts.
def main():
args = _get_argparse().parse_args()
all_tests = list(GLOBAL_TEST_REGISTRY)
if args.serialized_test_dir:
for root, dirs, files in os.walk(args.serialized_test_dir):
for filename in files:
with open(os.path.join(root, filename), 'rb') as f:
all_tests.append(pickle.load(f).as_test())
all_test_unique_names = set(test.unique_name for test in all_tests)
deserialize_all_tests_from(args.serialized_test_dir)
all_test_unique_names = set(
test.unique_name for test in GLOBAL_TEST_REGISTRY)
# Find the selected config.
if args.config == 'refbackend':
@@ -117,7 +115,7 @@ def main():
# Find the selected tests, and emit a diagnostic if none are found.
tests = [
test for test in all_tests
test for test in GLOBAL_TEST_REGISTRY
if re.match(args.filter, test.unique_name)
]
if len(tests) == 0:
@@ -125,7 +123,7 @@ def main():
f'ERROR: the provided filter {args.filter!r} does not match any tests'
)
print('The available tests are:')
for test in all_tests:
for test in GLOBAL_TEST_REGISTRY:
print(test.unique_name)
sys.exit(1)

View File

@@ -68,62 +68,3 @@ def annotate_args(annotations: List[Optional[ArgAnnotation]]):
return fn
return decorator
class SerializableMethodAnnotation(NamedTuple):
method_name: str
export: Optional[bool]
arg_annotations: Optional[List[ArgAnnotation]]
class SerializableModuleAnnotations(NamedTuple):
method_annotations: List[SerializableMethodAnnotation]
submodule_annotations: List[Tuple[str, "SerializableModuleAnnotations"]]
def extract_serializable_annotations(
module: torch.nn.Module) -> SerializableModuleAnnotations:
module_annotations = SerializableModuleAnnotations(
method_annotations=[], submodule_annotations=[])
# Extract information on methods.
for method_name, method in module.__dict__.items():
# See if it is a method.
if not callable(method):
continue
export = None
arg_annotations = None
if hasattr(method, TORCH_MLIR_EXPORT_ATTR_NAME):
export = method._torch_mlir_export
if hasattr(method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME):
arg_annotations = method._torch_mlir_arg_annotations
if export is not None and arg_annotations is not None:
module_annotations.method_annotations.append(
SerializableMethodAnnotation(method_name=method_name,
export=export,
arg_annotations=arg_annotations))
# Recurse.
for name, child in module.named_children():
annotations = extract_serializable_annotations(child)
module_annotations.submodule_annotations.append((name, annotations))
return module_annotations
def apply_serializable_annotations(module: torch.nn.Module,
annotations: SerializableModuleAnnotations):
# Apply annotations to methods.
for method_annotation in annotations.method_annotations:
# Imitate use of the decorators to keep a source of truth there.
if method_annotation.export is not None:
setattr(module, method_annotation.method_name,
export(getattr(module, method_annotation.method_name)))
if method_annotation.arg_annotations is not None:
setattr(
module, method_annotation.method_name,
annotate_args(method_annotation.arg_annotations)(getattr(
module, method_annotation.method_name)))
# Recurse.
for name, submodule_annotations in annotations.submodule_annotations:
child = getattr(module, name)
apply_serializable_annotations(child, submodule_annotations)

View File

@@ -23,14 +23,10 @@ compiling or TorchScript'ing).
import abc
from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Dict
import io
import pickle
import traceback
import torch
from .annotations import apply_serializable_annotations
TorchScriptValue = Union[int, float, List['TorchScriptValue'],
Dict['TorchScriptValue',
@@ -179,57 +175,6 @@ class Test(NamedTuple):
program_invoker: Callable[[Any, TestUtils], None]
class SerializableTest(NamedTuple):
"""A self-contained representation of a test that can be pickled.
We use serialized TorchScript programs here for two reasons:
1. The PyTorch pickling story isn't great, so in order to reliably pickle
this class, we rely on having the serialized bytes for the TorchScript
module already given to us.
2. The choice of a TorchScript module vs `torch.nn.Module` boils down to
the fact that `torch.nn.Module` cannot be deserialized without pulling
in the same set of Python dependencies that were used to serialize it
in the first place. This would defeat one of the
main use cases of this class, which is to transport a test from an
environment with a set of heavy dependencies to a dependency-light one.
Since TorchScript modules are self-contained, they fit the bill
perfectly.
"""
# See unique_name on `Test`.
unique_name: str
# Serialized TorchScript program.
program: bytes
# Trace for execution testing.
trace: Trace
def as_test(self) -> Test:
"""Create a `Test` from this class."""
# Conform the serialized program to the interface expected by Test.
# This is a bit of a hack, but it's the only way to keep the layering
# straight.
def factory():
_extra_files = {"annotations.pkl": ""}
module = torch.jit.load(io.BytesIO(self.program),
_extra_files=_extra_files)
# Load the pickled annotations.
annotations = pickle.loads(_extra_files["annotations.pkl"])
apply_serializable_annotations(module, annotations)
return module
def invoker(module, tu):
for item in self.trace:
attr = module
for part in item.symbol.split("."):
attr = getattr(attr, part)
attr(*item.inputs)
return Test(
unique_name=self.unique_name,
program_factory=factory,
program_invoker=invoker,
)
class TestResult(NamedTuple):
# Stable unique name for error reporting and test suite configuration.
#

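The `SerializableTest` docstring above (the class is moved verbatim into the new serialization module below) motivates storing the program as serialized TorchScript bytes: the pickled object can be loaded in an environment that has none of the heavy dependencies used to build the model. A minimal sketch of that consumer side, assuming a hypothetical SomeModel_basic.pkl produced by the tooling in this commit; only torch and the torch_mlir_e2e_test package need to be importable:

import pickle

with open("SomeModel_basic.pkl", "rb") as f:  # hypothetical file name
    serializable_test = pickle.load(f)

test = serializable_test.as_test()   # rebuild a regular `Test`
module = test.program_factory()      # torch.jit.load + re-applied annotations
test.program_invoker(module, None)   # replays the recorded golden trace;
                                     # this invoker ignores its TestUtils argument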
View File

@@ -0,0 +1,173 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
# Serialization utilities for the end-to-end test framework.
It is sometimes useful to be able to serialize tests to disk, such as when
multiple tests require mutually incompatible PyTorch versions. This takes
advantage of the strong backwards compatibility of serialized TorchScript, which
generally allows all such programs to be loaded by an appropriately recent
PyTorch.
"""
from typing import List, Tuple, Optional, NamedTuple
import io
import os
import pickle
import torch
from .framework import generate_golden_trace, Test, Trace
from .annotations import ArgAnnotation, export, annotate_args, TORCH_MLIR_EXPORT_ATTR_NAME, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
from .registry import GLOBAL_TEST_REGISTRY
# ==============================================================================
# Annotation serialization
# ==============================================================================
class SerializableMethodAnnotation(NamedTuple):
method_name: str
export: Optional[bool]
arg_annotations: Optional[List[ArgAnnotation]]
class SerializableModuleAnnotations(NamedTuple):
method_annotations: List[SerializableMethodAnnotation]
submodule_annotations: List[Tuple[str, "SerializableModuleAnnotations"]]
def extract_serializable_annotations(
module: torch.nn.Module) -> SerializableModuleAnnotations:
module_annotations = SerializableModuleAnnotations(
method_annotations=[], submodule_annotations=[])
# Extract information on methods.
for method_name, method in module.__dict__.items():
# See if it is a method.
if not callable(method):
continue
export = None
arg_annotations = None
if hasattr(method, TORCH_MLIR_EXPORT_ATTR_NAME):
export = method._torch_mlir_export
if hasattr(method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME):
arg_annotations = method._torch_mlir_arg_annotations
if export is not None and arg_annotations is not None:
module_annotations.method_annotations.append(
SerializableMethodAnnotation(method_name=method_name,
export=export,
arg_annotations=arg_annotations))
# Recurse.
for name, child in module.named_children():
annotations = extract_serializable_annotations(child)
module_annotations.submodule_annotations.append((name, annotations))
return module_annotations
def apply_serializable_annotations(module: torch.nn.Module,
annotations: SerializableModuleAnnotations):
# Apply annotations to methods.
for method_annotation in annotations.method_annotations:
# Imitate use of the decorators to keep a source of truth there.
if method_annotation.export is not None:
setattr(module, method_annotation.method_name,
export(getattr(module, method_annotation.method_name)))
if method_annotation.arg_annotations is not None:
setattr(
module, method_annotation.method_name,
annotate_args(method_annotation.arg_annotations)(getattr(
module, method_annotation.method_name)))
# Recurse.
for name, submodule_annotations in annotations.submodule_annotations:
child = getattr(module, name)
apply_serializable_annotations(child, submodule_annotations)
# ==============================================================================
# Serializable test definition
# ==============================================================================
class SerializableTest(NamedTuple):
"""A self-contained representation of a test that can be pickled.
We use serialized TorchScript programs here for two reasons:
1. The PyTorch pickling story isn't great, so in order to reliably pickle
this class, we rely on having the serialized bytes for the TorchScript
module already given to us.
2. The choice of a TorchScript module vs `torch.nn.Module` boils down to
the fact that `torch.nn.Module` cannot be deserialized without pulling
in the same set of Python dependencies that were used to serialize it
in the first place. This would defeat one of the
main use cases of this class, which is to transport a test from an
environment with a set of heavy dependencies to a dependency-light one.
Since TorchScript modules are self-contained, they fit the bill
perfectly.
"""
# See unique_name on `Test`.
unique_name: str
# Serialized TorchScript program.
program: bytes
# Trace for execution testing.
trace: Trace
def as_test(self) -> Test:
"""Create a `Test` from this class."""
# Conform the serialized program to the interface expected by Test.
# This is a bit of a hack, but it's the only way to keep the layering
# straight.
def factory():
_extra_files = {"annotations.pkl": ""}
module = torch.jit.load(io.BytesIO(self.program),
_extra_files=_extra_files)
# Load the pickled annotations.
annotations = pickle.loads(_extra_files["annotations.pkl"])
apply_serializable_annotations(module, annotations)
return module
def invoker(module, tu):
for item in self.trace:
attr = module
for part in item.symbol.split("."):
attr = getattr(attr, part)
attr(*item.inputs)
return Test(
unique_name=self.unique_name,
program_factory=factory,
program_invoker=invoker,
)
# ==============================================================================
# Filesystem operations
# ==============================================================================
def serialize_all_tests_to(output_dir: str):
serializable_tests = []
for test in GLOBAL_TEST_REGISTRY:
trace = generate_golden_trace(test)
module = torch.jit.script(test.program_factory())
torchscript_module_bytes = module.save_to_buffer({
"annotations.pkl":
pickle.dumps(extract_serializable_annotations(module))
})
serializable_tests.append(
SerializableTest(unique_name=test.unique_name,
program=torchscript_module_bytes,
trace=trace))
for test in serializable_tests:
with open(os.path.join(output_dir, f"{test.unique_name}.pkl"),
"wb") as f:
pickle.dump(test, f)
def deserialize_all_tests_from(serialized_test_dir: str):
for root, _, files in os.walk(serialized_test_dir):
for filename in files:
with open(os.path.join(root, filename), 'rb') as f:
GLOBAL_TEST_REGISTRY.append(pickle.load(f).as_test())
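
serialize_all_tests_to and deserialize_all_tests_from are the entry points that the two driver-script changes earlier in this commit call into. A minimal sketch of the round trip the module docstring describes, with placeholder names (the /tmp/serialized_tests path and the my_e2e_tests module are hypothetical):

# Producer side, run in the environment with the heavy dependencies installed.
from torch_mlir_e2e_test.torchscript.serialization import serialize_all_tests_to
import my_e2e_tests  # hypothetical: importing it registers tests in GLOBAL_TEST_REGISTRY
serialize_all_tests_to("/tmp/serialized_tests")  # writes one <unique_name>.pkl per test

# Consumer side, run later, possibly against a different PyTorch version.
from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
from torch_mlir_e2e_test.torchscript.serialization import deserialize_all_tests_from
deserialize_all_tests_from("/tmp/serialized_tests")  # appends the deserialized tests to the registry
print(sorted(t.unique_name for t in GLOBAL_TEST_REGISTRY))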