Remove the heavydep tests

We originally added these to help bring up more complex models with
heavier dependencies. However, over time it has become clear that these
models usually require more than just heavier dependencies -- they often
require a nontrivial amount of "one-off" code to extract the relevant
parts of the model and compile them. This is not a good fit for a
component in the core Torch-MLIR repo.

However, in the community, nod.ai has developed the ["Shark
Tank"](https://github.com/nod-ai/SHARK/tree/main/tank) which has all the
appropriate code to wrangle these models and organize them. We intend to
more heaviliy lean on that as a community and improve the symbiosis
there to serve the role that these heavydep tests were meant to play.
pull/1476/head
Sean Silva 2022-10-11 12:12:04 +00:00
parent 6403c0e56f
commit c8280d67bd
8 changed files with 0 additions and 887 deletions

View File

@ -1,33 +0,0 @@
### Additional end-to-end tests with heavy dependencies (heavydep tests)
Some end-to-end tests require additional dependencies which don't make sense to
include as part of the default torch-mlir setup. Additionally, these
dependencies often don't work with the same HEAD PyTorch version that torch-mlir
builds against at the C++ level.
We have a self-contained script that generates all the needed artifacts from a
self-contained virtual environment. It can be used like so:
```shell
# Build the virtual environment in the specified directory and generate the
# serialized test artifacts in the other specified directory.
# This command is safe to re-run if you have already built the virtual
# environment and just changed the tests.
build_tools/e2e_heavydep_tests/generate_serialized_tests.sh \
path/to/heavydep_venv \
path/to/heavydep_serialized_tests
# Add the --serialized-test-dir flag to point at the directory containing the
# serialized tests. All other functionality is the same as the normal invocation
# of e2e_test.sh, but the serialized tests will be available.
tools/e2e_test.sh --serialized-test-dir=path/to/heavydep_serialized_tests
```
The tests use the same (pure-Python) test framework as the normal
e2e_test.sh, but the tests are added in
`build_tools/e2e_heavydep_tests` instead of
`frontends/pytorch/e2e_testing/torchscript`.
We rely critically on serialized TorchScript compatibility across PyTorch
versions to transport the tests + pure-Python compatibility of the `torch`
API, which has worked well so far.

View File

@ -1,41 +0,0 @@
#!/bin/bash
set -euo pipefail
# Check that only two arugments are passed
if [ "$#" -ne 2 ]; then
echo "Usage: $0 <venv_dir> <serialized_test_dir>"
echo 'Description:
venv_dir: directory to put the Python venv used for generating the serialized tests
serialized_test_dir: directory to write the generated serialized tests to
'
exit 1
fi
venv_dir=$1
serialized_test_dir=$2
here="$(realpath "$(dirname "$0")")"
torch_mlir_src_root="$here/../../"
mkdir -p "$venv_dir"
mkdir -p "$serialized_test_dir"
python3 -m venv "$venv_dir"
source "$venv_dir"/bin/activate
# latest torch-version and torch-vision module is required.
python3 -m pip install --upgrade -r "$torch_mlir_src_root/requirements.txt"
# For minilm_seq_classification.py
python3 -m pip install 'transformers[torch]'
# TODO: Remove functorch after make_fx makes into pytorch core.
python3 -m pip install "git+https://github.com/pytorch/functorch.git"
python3 -m pip install networkx
# For pytorch image models.
python3 -m pip install timm
cd "$torch_mlir_src_root"
export PYTHONPATH=${PYTHONPATH-}
source "$torch_mlir_src_root/.env"
python3 -m build_tools.e2e_heavydep_tests.main --output_dir="$serialized_test_dir"

View File

@ -1,129 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export
torch.manual_seed(0)
def prepare_sentence_tokens(hf_model: str, sentence: str):
tokenizer = AutoTokenizer.from_pretrained(hf_model)
return torch.tensor([tokenizer.encode(sentence)])
def getTracedRecursiveScriptModule(module, trace_input):
traced_module = torch.jit.trace_module(module, trace_input)
script_module = traced_module._actual_script_module
export(script_module.forward)
annotate_args_decorator = annotate_args([
None,
([-1, -1], torch.int64, True),
])
annotate_args_decorator(script_module.forward)
return script_module
class HfSequenceClassification(torch.nn.Module):
def __init__(self, model_name: str):
super().__init__()
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name, # The pretrained model name.
# The number of output labels--2 for binary classification.
num_labels=2,
# Whether the model returns attentions weights.
output_attentions=False,
# Whether the model returns all hidden-states.
output_hidden_states=False,
torchscript=True,
)
self.model.eval()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, tokens):
return self.model.forward(tokens)[0]
# ==============================================================================
hf_minilm_model = "philschmid/MiniLM-L6-H384-uncased-sst2"
trace_input = {
"forward":
prepare_sentence_tokens(hf_minilm_model, "how do you like the project")
}
test_input = prepare_sentence_tokens(hf_minilm_model,
"this project is very interesting")
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
HfSequenceClassification(hf_minilm_model), trace_input))
def MiniLMSequenceClassification_basic(module, tu: TestUtils):
module.forward(test_input)
# ==============================================================================
hf_albert_model = "albert-base-v2"
trace_input = {
"forward":
prepare_sentence_tokens(hf_albert_model, "how do you like the project")
}
test_input = prepare_sentence_tokens(hf_albert_model,
"this project is very interesting")
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
HfSequenceClassification(hf_albert_model), trace_input))
def AlbertSequenceClassification_basic(module, tu: TestUtils):
module.forward(test_input)
# ==============================================================================
hf_bert_model = "bert-base-uncased"
trace_input = {
"forward":
prepare_sentence_tokens(hf_bert_model, "how do you like the project")
}
test_input = prepare_sentence_tokens(hf_bert_model,
"this project is very interesting")
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
HfSequenceClassification(hf_bert_model), trace_input))
def BertSequenceClassification_basic(module, tu: TestUtils):
module.forward(test_input)
# ==============================================================================
hf_distilbert_model = "distilbert-base-uncased"
trace_input = {
"forward":
prepare_sentence_tokens(hf_distilbert_model, "how do you like the project")
}
test_input = prepare_sentence_tokens(hf_distilbert_model,
"this project is very interesting")
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
HfSequenceClassification(hf_distilbert_model), trace_input))
def DistilBertSequenceClassification_basic(module, tu: TestUtils):
module.forward(test_input)
# ==============================================================================

View File

@ -1,28 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import argparse
from torch_mlir_e2e_test.serialization import serialize_all_tests_to
from . import hf_sequence_classification
from . import vision_models
from . import train_models
def _get_argparse():
parser = argparse.ArgumentParser(
description="Generate assets for TorchScript E2E tests")
parser.add_argument("--output_dir", help="The directory to put assets in.")
return parser
def main():
args = _get_argparse().parse_args()
serialize_all_tests_to(args.output_dir)
if __name__ == "__main__":
main()

View File

@ -1,209 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch._decomp import get_decompositions
from functorch import make_fx
from torch.nn.utils import _stateless
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export
from torch import fx
import copy
import tempfile
torch.manual_seed(0)
############################## Utility Functions ##############################
def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
"""Generates the annotation i.e., shape and dtype for the given inputs, required by torch-mlir module."""
annotations_list = [None]
for i in inputs:
temp_list = []
if dynamic:
temp_list.append([-1 for i in range(len(i.shape))])
else:
temp_list.append(list(i.shape))
temp_list.append(i.dtype)
temp_list.append(True)
annotations_list.append(tuple(temp_list))
return annotations_list
def change_fx_graph_return_to_tuple(fx_g: fx.GraphModule):
for node in fx_g.graph.nodes:
if node.op == "output":
# output nodes always have one argument
node_arg = node.args[0]
out_nodes = []
if isinstance(node_arg, list):
# Don't return NoneType elements.
for out_node in node_arg:
if not isinstance(out_node, type(None)):
out_nodes.append(out_node)
# If there is a single tensor/element to be returned don't
# create a tuple for it.
if len(out_nodes) == 1:
node.args = out_nodes
else:
node.args = (tuple(out_nodes), )
fx_g.graph.lint()
fx_g.recompile()
return fx_g
def generate_graph(model, inputs, training_fn):
# TODO: Pass the decomposition_table according to the model/needs.
fx_g = make_fx(training_fn,
decomposition_table=get_decompositions([
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward
]))(dict(model.named_parameters()),
dict(model.named_buffers()), inputs)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
fx_g = change_fx_graph_return_to_tuple(fx_g)
ts_g = torch.jit.script(fx_g)
# TODO: If not saved/load creates some unnecessary functions that
# causes problem during mlir-translate.
temp=tempfile.NamedTemporaryFile(suffix='_heavy_dep',
prefix='temp_ts_')
ts_g.save(temp.name)
new_ts = torch.jit.load(temp.name)
return new_ts
def getAnnotatedModule(ts_module, inputs):
export(ts_module.forward)
annotate_args_decorator = annotate_args(
get_input_annotations(inputs, dynamic=False))
annotate_args_decorator(ts_module.forward)
return ts_module
############################################################################
# Basic NeuralNet training test.
# This trains the Neural Net and returns the updated weights with the
# sgd optimzier.
class NeuralNet(torch.nn.Module):
def __init__(self):
super(NeuralNet, self).__init__()
self.l1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.l2 = torch.nn.Linear(16, 2)
def forward(self, x):
out = self.l1(x)
out = self.relu(out)
out = self.l2(out)
return out
neural_net_model = NeuralNet()
input = torch.randn(1, 10)
def get_sorted_params(named_params):
return [i[1] for i in sorted(named_params.items())]
# TODO: Pass and update the optimizer fn. Currently, we don't support passing
# elemental types.
def training_fn(params, buffers, args):
params_and_buffers = {**params, **buffers}
_stateless.functional_call(neural_net_model, params_and_buffers, args,
{}).sum().backward()
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
optim.step()
return params, buffers
# We need to pass the model parameters, buffers and the inputs respectively in
# order.
training_inputs = [i.detach() for i in neural_net_model.parameters()]
for i in neural_net_model.buffers():
training_inputs.append(i.detach())
training_inputs.append(input)
neural_net_ts = generate_graph(neural_net_model, (input, ), training_fn)
neural_net_ts_annotated = getAnnotatedModule(neural_net_ts, training_inputs)
@register_test_case(module_factory=lambda: neural_net_ts_annotated)
def NeuralNet_training_basic(module, tu: TestUtils):
module.forward(*training_inputs)
##############################################################################
# Bert training.
class MiniLMSequenceClassification(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = AutoModelForSequenceClassification.from_pretrained(
# The below model is a variant of
# `microsoft/MiniLM-L12-H384-uncased` with less parameters.
"nreimers/MiniLM-L6-H384-uncased", # The pretrained model.
num_labels=
2, # The number of output labels--2 for binary classification.
output_attentions=
False, # Whether the model returns attentions weights.
output_hidden_states=
False, # Whether the model returns all hidden-states.
torchscript=True,
)
def forward(self, tokens):
return self.model.forward(tokens)[0]
bert_model = MiniLMSequenceClassification()
input = torch.randint(2, (1, 128))
# TODO: Pass and update the optimizer fn. Currently, we don't support passing
# elemental types.
def training_fn(params, buffers, args):
params_and_buffers = {**params, **buffers}
_stateless.functional_call(bert_model, params_and_buffers, args,
{}).sum().backward()
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
optim.step()
return params, buffers
# We need to pass the model parameters, buffers and the inputs respectively in
# order.
bert_inputs = [i.detach() for i in bert_model.parameters()]
for i in bert_model.buffers():
bert_inputs.append(i.detach())
bert_inputs.append(input)
bert_ts = generate_graph(bert_model, (input, ), training_fn)
bert_ts_annotated = getAnnotatedModule(bert_ts, bert_inputs)
@register_test_case(module_factory=lambda: bert_ts_annotated)
def BERT_training_basic(module, tu: TestUtils):
module.forward(*bert_inputs)

View File

@ -1,263 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
import torchvision.models as models
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export
import timm
torch.manual_seed(0)
def getTracedRecursiveScriptModule(module):
script_module = torch.jit.script(module)
export(script_module.forward)
annotate_args_decorator = annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
annotate_args_decorator(script_module.forward)
return script_module
class VisionModule(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.model.eval()
def forward(self, input):
return self.model.forward(input)
# ==============================================================================
resnet18_model = models.resnet18(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(resnet18_model)))
def Resnet18VisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
resnext50_32x4d_model = models.resnext50_32x4d(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(resnext50_32x4d_model)))
def Resnext50_32x4dVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
mnasnet1_0_model = models.mnasnet1_0(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(mnasnet1_0_model)))
def Mnasnet1_0VisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
alexnet_model = models.alexnet(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(alexnet_model)))
def AlexnetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
shufflenet_model = models.shufflenet_v2_x1_0(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(shufflenet_model)))
def ShufflenetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
squeezenet_model = models.squeezenet1_0(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(squeezenet_model)))
def SqueezenetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
vgg16_model = models.vgg16(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(vgg16_model)))
def Vgg16VisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
wide_resnet_model = models.wide_resnet50_2(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(wide_resnet_model)))
def Wide_ResnetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
efficientnet_model = models.efficientnet_b0(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(efficientnet_model)))
def EfficientnetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
mobilenet_v2_model = models.mobilenet_v2(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(mobilenet_v2_model)))
def Mobilenet_v2VisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
mobilenet_v3_large_model = models.mobilenet_v3_large(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(mobilenet_v3_large_model)))
def Mobilenet_v3_largeVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
resnet50_model = models.resnet50(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(resnet50_model)))
def Resnet50VisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
densenet121_model = models.densenet121(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(densenet121_model)))
def Densenet121VisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
timm_regnet_model = models.regnet_y_1_6gf(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(timm_regnet_model)))
def Timm_RegnetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
pytorch_unet_model = torch.hub.load(
"mateuszbuda/brain-segmentation-pytorch",
"unet",
in_channels=3,
out_channels=1,
init_features=32,
pretrained=True,
)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(pytorch_unet_model)))
def PytorchUnetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
resnest_model = timm.create_model('resnest101e', pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(resnest_model)))
def ResnestVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
timm_vit_model = models.vit_b_16(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(timm_vit_model)))
def ViTVisionModel_basic(module, tu: TestUtils):
module.forward(input)

View File

@ -10,7 +10,6 @@ import sys
from torch_mlir_e2e_test.framework import run_tests
from torch_mlir_e2e_test.reporting import report_results
from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY
from torch_mlir_e2e_test.serialization import deserialize_all_tests_from
# Available test configs.
@ -57,14 +56,6 @@ Regular expression specifying which tests to include in this run.
default=False,
action='store_true',
help='report test results with additional detail')
parser.add_argument('--serialized-test-dir', default=None, type=str, help='''
The directory containing serialized pre-built tests.
Right now, these are additional tests which require heavy Python dependencies
to generate (or cannot even be generated with the version of PyTorch used by
torch-mlir).
See `build_tools/e2e_heavydep_tests/generate_serialized_tests.sh`
for more information on building these artifacts.
''')
parser.add_argument('-s', '--sequential',
default=False,
action='store_true',
@ -79,8 +70,6 @@ which make it easier to attach a debugger or get a stack trace.''')
def main():
args = _get_argparse().parse_args()
if args.serialized_test_dir:
deserialize_all_tests_from(args.serialized_test_dir)
all_test_unique_names = set(
test.unique_name for test in GLOBAL_TEST_REGISTRY)

View File

@ -1,173 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
# Serialization utilities for the end-to-end test framework.
It is sometimes useful to be able to serialize tests to disk, such as when
multiple tests require mutally incompatible PyTorch versions. This takes
advantage of the strong backwards compatibility of serialized TorchScript, which
generally allows all such programs to be loaded to an appropriately recent
PyTorch.
"""
from typing import List, Tuple, Optional, NamedTuple
import io
import os
import pickle
import torch
from .framework import generate_golden_trace, Test, Trace
from .annotations import ArgAnnotation, export, annotate_args, TORCH_MLIR_EXPORT_ATTR_NAME, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
from .registry import GLOBAL_TEST_REGISTRY
# ==============================================================================
# Annotation serialization
# ==============================================================================
class SerializableMethodAnnotation(NamedTuple):
method_name: str
export: Optional[bool]
arg_annotations: Optional[List[ArgAnnotation]]
class SerializableModuleAnnotations(NamedTuple):
method_annotations: List[SerializableMethodAnnotation]
submodule_annotations: List[Tuple[str, "SerializableModuleAnnotations"]]
def extract_serializable_annotations(
module: torch.nn.Module) -> SerializableModuleAnnotations:
module_annotations = SerializableModuleAnnotations(
method_annotations=[], submodule_annotations=[])
# Extract information on methods.
for method_name, method in module.__dict__.items():
# See if it is a method.
if not callable(method):
continue
export = None
arg_annotations = None
if hasattr(method, TORCH_MLIR_EXPORT_ATTR_NAME):
export = method._torch_mlir_export
if hasattr(method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME):
arg_annotations = method._torch_mlir_arg_annotations
if export is not None and arg_annotations is not None:
module_annotations.method_annotations.append(
SerializableMethodAnnotation(method_name=method_name,
export=export,
arg_annotations=arg_annotations))
# Recurse.
for name, child in module.named_children():
annotations = extract_serializable_annotations(child)
module_annotations.submodule_annotations.append((name, annotations))
return module_annotations
def apply_serializable_annotations(module: torch.nn.Module,
annotations: SerializableModuleAnnotations):
# Apply annotations to methods.
for method_annotation in annotations.method_annotations:
# Imitate use of the decorators to keep a source of truth there.
if method_annotation.export is not None:
setattr(module, method_annotation.method_name,
export(getattr(module, method_annotation.method_name)))
if method_annotation.arg_annotations is not None:
setattr(
module, method_annotation.method_name,
annotate_args(method_annotation.arg_annotations)(getattr(
module, method_annotation.method_name)))
# Recurse.
for name, submodule_annotations in annotations.submodule_annotations:
child = getattr(module, name)
apply_serializable_annotations(child, submodule_annotations)
# ==============================================================================
# Serializable test definition
# ==============================================================================
class SerializableTest(NamedTuple):
"""A self-contained representation of a test that can be pickled.
We use serialized TorchScript programs here for two reasons:
1. The PyTorch pickling story isn't great, so in order to reliably pickle
this class, we rely on having the serialized bytes for the TorchScript
module already given to us.
2. The choice of a TorchScript module vs `torch.nn.Module` boils down to
the fact that `torch.nn.Module` cannot be deserialized without pulling
in the same set of Python dependencies that were used to serialize it
in the first place. This would defeat one of the
main use cases of this class, which is to transport a test from an
environment with a set of heavy dependencies to a dependency-light one.
Since TorchScript modules are self-contained, they fit the bill
perfectly.
"""
# See unique_name on `Test`.
unique_name: str
# Serialized TorchScript program.
program: bytes
# Trace for execution testing.
trace: Trace
def as_test(self) -> Test:
"""Create a `Test` from this class."""
# Conform the serialized program to the interface expected by Test.
# This is a bit of a hack, but it's the only way to keep the layering
# straight.
def factory():
_extra_files = {"annotations.pkl": ""}
module = torch.jit.load(io.BytesIO(self.program),
_extra_files=_extra_files)
# Load the pickled annotations.
annotations = pickle.loads(_extra_files["annotations.pkl"])
apply_serializable_annotations(module, annotations)
return module
def invoker(module, tu):
for item in self.trace:
attr = module
for part in item.symbol.split("."):
attr = getattr(attr, part)
attr(*item.inputs)
return Test(
unique_name=self.unique_name,
program_factory=factory,
program_invoker=invoker,
)
# ==============================================================================
# Filesystem operations
# ==============================================================================
def serialize_all_tests_to(output_dir: str):
serializable_tests = []
for test in GLOBAL_TEST_REGISTRY:
trace = generate_golden_trace(test)
module = torch.jit.script(test.program_factory())
torchscript_module_bytes = module.save_to_buffer({
"annotations.pkl":
pickle.dumps(extract_serializable_annotations(module))
})
serializable_tests.append(
SerializableTest(unique_name=test.unique_name,
program=torchscript_module_bytes,
trace=trace))
for test in serializable_tests:
with open(os.path.join(output_dir, f"{test.unique_name}.pkl"),
"wb") as f:
pickle.dump(test, f)
def deserialize_all_tests_from(serialized_test_dir: str):
for root, _, files in os.walk(serialized_test_dir):
for filename in files:
with open(os.path.join(root, filename), 'rb') as f:
GLOBAL_TEST_REGISTRY.append(pickle.load(f).as_test())