mirror of https://github.com/llvm/torch-mlir
Centralize all test serialization logic.
parent
e59a91620a
commit
0378c75b35
|
@ -4,14 +4,8 @@
|
|||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir_e2e_test.torchscript.framework import SerializableTest, generate_golden_trace
|
||||
from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations
|
||||
from torch_mlir_e2e_test.torchscript.serialization import serialize_all_tests_to
|
||||
|
||||
from . import hf_sequence_classification
|
||||
|
||||
|
@ -25,22 +19,7 @@ def _get_argparse():
|
|||
|
||||
def main():
|
||||
args = _get_argparse().parse_args()
|
||||
serializable_tests = []
|
||||
for test in GLOBAL_TEST_REGISTRY:
|
||||
trace = generate_golden_trace(test)
|
||||
module = torch.jit.script(test.program_factory())
|
||||
torchscript_module_bytes = module.save_to_buffer({
|
||||
"annotations.pkl":
|
||||
pickle.dumps(extract_serializable_annotations(module))
|
||||
})
|
||||
serializable_tests.append(
|
||||
SerializableTest(unique_name=test.unique_name,
|
||||
program=torchscript_module_bytes,
|
||||
trace=trace))
|
||||
for test in serializable_tests:
|
||||
with open(os.path.join(args.output_dir, f"{test.unique_name}.pkl"),
|
||||
"wb") as f:
|
||||
pickle.dump(test, f)
|
||||
serialize_all_tests_to(args.output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -12,6 +12,7 @@ import sys
|
|||
from torch_mlir_e2e_test.torchscript.framework import TestConfig, run_tests
|
||||
from torch_mlir_e2e_test.torchscript.reporting import report_results
|
||||
from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir_e2e_test.torchscript.serialization import deserialize_all_tests_from
|
||||
|
||||
# Available test configs.
|
||||
from torch_mlir_e2e_test.torchscript.configs import (
|
||||
|
@ -72,13 +73,10 @@ for more information on building these artifacts.
|
|||
def main():
|
||||
args = _get_argparse().parse_args()
|
||||
|
||||
all_tests = list(GLOBAL_TEST_REGISTRY)
|
||||
if args.serialized_test_dir:
|
||||
for root, dirs, files in os.walk(args.serialized_test_dir):
|
||||
for filename in files:
|
||||
with open(os.path.join(root, filename), 'rb') as f:
|
||||
all_tests.append(pickle.load(f).as_test())
|
||||
all_test_unique_names = set(test.unique_name for test in all_tests)
|
||||
deserialize_all_tests_from(args.serialized_test_dir)
|
||||
all_test_unique_names = set(
|
||||
test.unique_name for test in GLOBAL_TEST_REGISTRY)
|
||||
|
||||
# Find the selected config.
|
||||
if args.config == 'refbackend':
|
||||
|
@ -117,7 +115,7 @@ def main():
|
|||
|
||||
# Find the selected tests, and emit a diagnostic if none are found.
|
||||
tests = [
|
||||
test for test in all_tests
|
||||
test for test in GLOBAL_TEST_REGISTRY
|
||||
if re.match(args.filter, test.unique_name)
|
||||
]
|
||||
if len(tests) == 0:
|
||||
|
@ -125,7 +123,7 @@ def main():
|
|||
f'ERROR: the provided filter {args.filter!r} does not match any tests'
|
||||
)
|
||||
print('The available tests are:')
|
||||
for test in all_tests:
|
||||
for test in GLOBAL_TEST_REGISTRY:
|
||||
print(test.unique_name)
|
||||
sys.exit(1)
|
||||
|
||||
|
|
|
@ -68,62 +68,3 @@ def annotate_args(annotations: List[Optional[ArgAnnotation]]):
|
|||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class SerializableMethodAnnotation(NamedTuple):
|
||||
method_name: str
|
||||
export: Optional[bool]
|
||||
arg_annotations: Optional[List[ArgAnnotation]]
|
||||
|
||||
|
||||
class SerializableModuleAnnotations(NamedTuple):
|
||||
method_annotations: List[SerializableMethodAnnotation]
|
||||
submodule_annotations: List[Tuple[str, "SerializableModuleAnnotations"]]
|
||||
|
||||
|
||||
def extract_serializable_annotations(
|
||||
module: torch.nn.Module) -> SerializableModuleAnnotations:
|
||||
module_annotations = SerializableModuleAnnotations(
|
||||
method_annotations=[], submodule_annotations=[])
|
||||
# Extract information on methods.
|
||||
for method_name, method in module.__dict__.items():
|
||||
# See if it is a method.
|
||||
if not callable(method):
|
||||
continue
|
||||
export = None
|
||||
arg_annotations = None
|
||||
if hasattr(method, TORCH_MLIR_EXPORT_ATTR_NAME):
|
||||
export = method._torch_mlir_export
|
||||
if hasattr(method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME):
|
||||
arg_annotations = method._torch_mlir_arg_annotations
|
||||
if export is not None and arg_annotations is not None:
|
||||
module_annotations.method_annotations.append(
|
||||
SerializableMethodAnnotation(method_name=method_name,
|
||||
export=export,
|
||||
arg_annotations=arg_annotations))
|
||||
|
||||
# Recurse.
|
||||
for name, child in module.named_children():
|
||||
annotations = extract_serializable_annotations(child)
|
||||
module_annotations.submodule_annotations.append((name, annotations))
|
||||
return module_annotations
|
||||
|
||||
|
||||
def apply_serializable_annotations(module: torch.nn.Module,
|
||||
annotations: SerializableModuleAnnotations):
|
||||
# Apply annotations to methods.
|
||||
for method_annotation in annotations.method_annotations:
|
||||
# Imitate use of the decorators to keep a source of truth there.
|
||||
if method_annotation.export is not None:
|
||||
setattr(module, method_annotation.method_name,
|
||||
export(getattr(module, method_annotation.method_name)))
|
||||
if method_annotation.arg_annotations is not None:
|
||||
setattr(
|
||||
module, method_annotation.method_name,
|
||||
annotate_args(method_annotation.arg_annotations)(getattr(
|
||||
module, method_annotation.method_name)))
|
||||
|
||||
# Recurse.
|
||||
for name, submodule_annotations in annotations.submodule_annotations:
|
||||
child = getattr(module, name)
|
||||
apply_serializable_annotations(child, submodule_annotations)
|
||||
|
|
|
@ -23,14 +23,10 @@ compiling or TorchScript'ing).
|
|||
import abc
|
||||
from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Dict
|
||||
|
||||
import io
|
||||
import pickle
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
|
||||
from .annotations import apply_serializable_annotations
|
||||
|
||||
|
||||
TorchScriptValue = Union[int, float, List['TorchScriptValue'],
|
||||
Dict['TorchScriptValue',
|
||||
|
@ -179,57 +175,6 @@ class Test(NamedTuple):
|
|||
program_invoker: Callable[[Any, TestUtils], None]
|
||||
|
||||
|
||||
class SerializableTest(NamedTuple):
|
||||
"""A self-contained representation of a test that can be pickled.
|
||||
|
||||
We use serialized TorchScript programs here for two reasons:
|
||||
1. The PyTorch pickling story isn't great, so in order to reliably pickle
|
||||
this class, we rely on having the serialized bytes for the TorchScript
|
||||
module already given to us.
|
||||
2. The choice of a TorchScript module vs `torch.nn.Module` boils down to
|
||||
the fact that `torch.nn.Module` cannot be deserialized without pulling
|
||||
in the same set of Python dependencies that were used to serialize it
|
||||
in the first place. This would defeat one of the
|
||||
main use cases of this class, which is to transport a test from an
|
||||
environment with a set of heavy dependencies to a dependency-light one.
|
||||
Since TorchScript modules are self-contained, they fit the bill
|
||||
perfectly.
|
||||
"""
|
||||
# See unique_name on `Test`.
|
||||
unique_name: str
|
||||
# Serialized TorchScript program.
|
||||
program: bytes
|
||||
# Trace for execution testing.
|
||||
trace: Trace
|
||||
|
||||
def as_test(self) -> Test:
|
||||
"""Create a `Test` from this class."""
|
||||
# Conform the serialized program to the interface expected by Test.
|
||||
# This is a bit of a hack, but it's the only way to keep the layering
|
||||
# straight.
|
||||
def factory():
|
||||
_extra_files = {"annotations.pkl": ""}
|
||||
module = torch.jit.load(io.BytesIO(self.program),
|
||||
_extra_files=_extra_files)
|
||||
# Load the pickled annotations.
|
||||
annotations = pickle.loads(_extra_files["annotations.pkl"])
|
||||
apply_serializable_annotations(module, annotations)
|
||||
return module
|
||||
|
||||
def invoker(module, tu):
|
||||
for item in self.trace:
|
||||
attr = module
|
||||
for part in item.symbol.split("."):
|
||||
attr = getattr(attr, part)
|
||||
attr(*item.inputs)
|
||||
|
||||
return Test(
|
||||
unique_name=self.unique_name,
|
||||
program_factory=factory,
|
||||
program_invoker=invoker,
|
||||
)
|
||||
|
||||
|
||||
class TestResult(NamedTuple):
|
||||
# Stable unique name for error reporting and test suite configuration.
|
||||
#
|
||||
|
|
|
@ -0,0 +1,173 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
"""
|
||||
# Serialization utilities for the end-to-end test framework.
|
||||
|
||||
It is sometimes useful to be able to serialize tests to disk, such as when
|
||||
multiple tests require mutally incompatible PyTorch versions. This takes
|
||||
advantage of the strong backwards compatibility of serialized TorchScript, which
|
||||
generally allows all such programs to be loaded to an appropriately recent
|
||||
PyTorch.
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Optional, NamedTuple
|
||||
|
||||
import io
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import torch
|
||||
|
||||
from .framework import generate_golden_trace, Test, Trace
|
||||
from .annotations import ArgAnnotation, export, annotate_args, TORCH_MLIR_EXPORT_ATTR_NAME, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
|
||||
from .registry import GLOBAL_TEST_REGISTRY
|
||||
|
||||
# ==============================================================================
|
||||
# Annotation serialization
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SerializableMethodAnnotation(NamedTuple):
|
||||
method_name: str
|
||||
export: Optional[bool]
|
||||
arg_annotations: Optional[List[ArgAnnotation]]
|
||||
|
||||
|
||||
class SerializableModuleAnnotations(NamedTuple):
|
||||
method_annotations: List[SerializableMethodAnnotation]
|
||||
submodule_annotations: List[Tuple[str, "SerializableModuleAnnotations"]]
|
||||
|
||||
|
||||
def extract_serializable_annotations(
|
||||
module: torch.nn.Module) -> SerializableModuleAnnotations:
|
||||
module_annotations = SerializableModuleAnnotations(
|
||||
method_annotations=[], submodule_annotations=[])
|
||||
# Extract information on methods.
|
||||
for method_name, method in module.__dict__.items():
|
||||
# See if it is a method.
|
||||
if not callable(method):
|
||||
continue
|
||||
export = None
|
||||
arg_annotations = None
|
||||
if hasattr(method, TORCH_MLIR_EXPORT_ATTR_NAME):
|
||||
export = method._torch_mlir_export
|
||||
if hasattr(method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME):
|
||||
arg_annotations = method._torch_mlir_arg_annotations
|
||||
if export is not None and arg_annotations is not None:
|
||||
module_annotations.method_annotations.append(
|
||||
SerializableMethodAnnotation(method_name=method_name,
|
||||
export=export,
|
||||
arg_annotations=arg_annotations))
|
||||
|
||||
# Recurse.
|
||||
for name, child in module.named_children():
|
||||
annotations = extract_serializable_annotations(child)
|
||||
module_annotations.submodule_annotations.append((name, annotations))
|
||||
return module_annotations
|
||||
|
||||
|
||||
def apply_serializable_annotations(module: torch.nn.Module,
|
||||
annotations: SerializableModuleAnnotations):
|
||||
# Apply annotations to methods.
|
||||
for method_annotation in annotations.method_annotations:
|
||||
# Imitate use of the decorators to keep a source of truth there.
|
||||
if method_annotation.export is not None:
|
||||
setattr(module, method_annotation.method_name,
|
||||
export(getattr(module, method_annotation.method_name)))
|
||||
if method_annotation.arg_annotations is not None:
|
||||
setattr(
|
||||
module, method_annotation.method_name,
|
||||
annotate_args(method_annotation.arg_annotations)(getattr(
|
||||
module, method_annotation.method_name)))
|
||||
|
||||
# Recurse.
|
||||
for name, submodule_annotations in annotations.submodule_annotations:
|
||||
child = getattr(module, name)
|
||||
apply_serializable_annotations(child, submodule_annotations)
|
||||
|
||||
# ==============================================================================
|
||||
# Serializable test definition
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SerializableTest(NamedTuple):
|
||||
"""A self-contained representation of a test that can be pickled.
|
||||
|
||||
We use serialized TorchScript programs here for two reasons:
|
||||
1. The PyTorch pickling story isn't great, so in order to reliably pickle
|
||||
this class, we rely on having the serialized bytes for the TorchScript
|
||||
module already given to us.
|
||||
2. The choice of a TorchScript module vs `torch.nn.Module` boils down to
|
||||
the fact that `torch.nn.Module` cannot be deserialized without pulling
|
||||
in the same set of Python dependencies that were used to serialize it
|
||||
in the first place. This would defeat one of the
|
||||
main use cases of this class, which is to transport a test from an
|
||||
environment with a set of heavy dependencies to a dependency-light one.
|
||||
Since TorchScript modules are self-contained, they fit the bill
|
||||
perfectly.
|
||||
"""
|
||||
# See unique_name on `Test`.
|
||||
unique_name: str
|
||||
# Serialized TorchScript program.
|
||||
program: bytes
|
||||
# Trace for execution testing.
|
||||
trace: Trace
|
||||
|
||||
def as_test(self) -> Test:
|
||||
"""Create a `Test` from this class."""
|
||||
# Conform the serialized program to the interface expected by Test.
|
||||
# This is a bit of a hack, but it's the only way to keep the layering
|
||||
# straight.
|
||||
def factory():
|
||||
_extra_files = {"annotations.pkl": ""}
|
||||
module = torch.jit.load(io.BytesIO(self.program),
|
||||
_extra_files=_extra_files)
|
||||
# Load the pickled annotations.
|
||||
annotations = pickle.loads(_extra_files["annotations.pkl"])
|
||||
apply_serializable_annotations(module, annotations)
|
||||
return module
|
||||
|
||||
def invoker(module, tu):
|
||||
for item in self.trace:
|
||||
attr = module
|
||||
for part in item.symbol.split("."):
|
||||
attr = getattr(attr, part)
|
||||
attr(*item.inputs)
|
||||
|
||||
return Test(
|
||||
unique_name=self.unique_name,
|
||||
program_factory=factory,
|
||||
program_invoker=invoker,
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Filesystem operations
|
||||
# ==============================================================================
|
||||
|
||||
def serialize_all_tests_to(output_dir: str):
|
||||
serializable_tests = []
|
||||
for test in GLOBAL_TEST_REGISTRY:
|
||||
trace = generate_golden_trace(test)
|
||||
module = torch.jit.script(test.program_factory())
|
||||
torchscript_module_bytes = module.save_to_buffer({
|
||||
"annotations.pkl":
|
||||
pickle.dumps(extract_serializable_annotations(module))
|
||||
})
|
||||
serializable_tests.append(
|
||||
SerializableTest(unique_name=test.unique_name,
|
||||
program=torchscript_module_bytes,
|
||||
trace=trace))
|
||||
for test in serializable_tests:
|
||||
with open(os.path.join(output_dir, f"{test.unique_name}.pkl"),
|
||||
"wb") as f:
|
||||
pickle.dump(test, f)
|
||||
|
||||
|
||||
def deserialize_all_tests_from(serialized_test_dir: str):
|
||||
for root, _, files in os.walk(serialized_test_dir):
|
||||
for filename in files:
|
||||
with open(os.path.join(root, filename), 'rb') as f:
|
||||
GLOBAL_TEST_REGISTRY.append(pickle.load(f).as_test())
|
Loading…
Reference in New Issue