diff --git a/build_tools/torchscript_e2e_heavydep_tests/README.md b/build_tools/torchscript_e2e_heavydep_tests/README.md deleted file mode 100644 index 5c7a34bce..000000000 --- a/build_tools/torchscript_e2e_heavydep_tests/README.md +++ /dev/null @@ -1,33 +0,0 @@ -### Additional end-to-end tests with heavy dependencies (heavydep tests) - -Some end-to-end tests require additional dependencies which don't make sense to -include as part of the default torch-mlir setup. Additionally, these -dependencies often don't work with the same HEAD PyTorch version that torch-mlir -builds against at the C++ level. - -We have a self-contained script that generates all the needed artifacts from a -self-contained virtual environment. It can be used like so: - -```shell -# Build the virtual environment in the specified directory and generate the -# serialized test artifacts in the other specified directory. -# This command is safe to re-run if you have already built the virtual -# environment and just changed the tests. -build_tools/e2e_heavydep_tests/generate_serialized_tests.sh \ - path/to/heavydep_venv \ - path/to/heavydep_serialized_tests - -# Add the --serialized-test-dir flag to point at the directory containing the -# serialized tests. All other functionality is the same as the normal invocation -# of e2e_test.sh, but the serialized tests will be available. -tools/e2e_test.sh --serialized-test-dir=path/to/heavydep_serialized_tests -``` - -The tests use the same (pure-Python) test framework as the normal -e2e_test.sh, but the tests are added in -`build_tools/e2e_heavydep_tests` instead of -`frontends/pytorch/e2e_testing/torchscript`. - -We rely critically on serialized TorchScript compatibility across PyTorch -versions to transport the tests + pure-Python compatibility of the `torch` -API, which has worked well so far. diff --git a/build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh b/build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh deleted file mode 100755 index ddda4d9a6..000000000 --- a/build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash -set -euo pipefail - -# Check that only two arugments are passed -if [ "$#" -ne 2 ]; then - echo "Usage: $0 " - echo 'Description: - venv_dir: directory to put the Python venv used for generating the serialized tests - serialized_test_dir: directory to write the generated serialized tests to -' - exit 1 -fi - -venv_dir=$1 -serialized_test_dir=$2 -here="$(realpath "$(dirname "$0")")" -torch_mlir_src_root="$here/../../" - -mkdir -p "$venv_dir" -mkdir -p "$serialized_test_dir" -python3 -m venv "$venv_dir" -source "$venv_dir"/bin/activate - -# latest torch-version and torch-vision module is required. -python3 -m pip install --upgrade -r "$torch_mlir_src_root/requirements.txt" - -# For minilm_seq_classification.py -python3 -m pip install 'transformers[torch]' - -# TODO: Remove functorch after make_fx makes into pytorch core. -python3 -m pip install "git+https://github.com/pytorch/functorch.git" -python3 -m pip install networkx - -# For pytorch image models. -python3 -m pip install timm - -cd "$torch_mlir_src_root" -export PYTHONPATH=${PYTHONPATH-} -source "$torch_mlir_src_root/.env" - -python3 -m build_tools.e2e_heavydep_tests.main --output_dir="$serialized_test_dir" diff --git a/build_tools/torchscript_e2e_heavydep_tests/hf_sequence_classification.py b/build_tools/torchscript_e2e_heavydep_tests/hf_sequence_classification.py deleted file mode 100644 index f387dafca..000000000 --- a/build_tools/torchscript_e2e_heavydep_tests/hf_sequence_classification.py +++ /dev/null @@ -1,129 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - -import torch -from transformers import AutoTokenizer, AutoModelForSequenceClassification - -from torch_mlir_e2e_test.framework import TestUtils -from torch_mlir_e2e_test.registry import register_test_case -from torch_mlir_e2e_test.annotations import annotate_args, export - -torch.manual_seed(0) - - -def prepare_sentence_tokens(hf_model: str, sentence: str): - tokenizer = AutoTokenizer.from_pretrained(hf_model) - return torch.tensor([tokenizer.encode(sentence)]) - - -def getTracedRecursiveScriptModule(module, trace_input): - traced_module = torch.jit.trace_module(module, trace_input) - script_module = traced_module._actual_script_module - export(script_module.forward) - annotate_args_decorator = annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) - annotate_args_decorator(script_module.forward) - return script_module - - -class HfSequenceClassification(torch.nn.Module): - - def __init__(self, model_name: str): - super().__init__() - self.model = AutoModelForSequenceClassification.from_pretrained( - model_name, # The pretrained model name. - # The number of output labels--2 for binary classification. - num_labels=2, - # Whether the model returns attentions weights. - output_attentions=False, - # Whether the model returns all hidden-states. - output_hidden_states=False, - torchscript=True, - ) - self.model.eval() - - @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ]) - def forward(self, tokens): - return self.model.forward(tokens)[0] - - -# ============================================================================== - -hf_minilm_model = "philschmid/MiniLM-L6-H384-uncased-sst2" - -trace_input = { - "forward": - prepare_sentence_tokens(hf_minilm_model, "how do you like the project") -} -test_input = prepare_sentence_tokens(hf_minilm_model, - "this project is very interesting") - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - HfSequenceClassification(hf_minilm_model), trace_input)) -def MiniLMSequenceClassification_basic(module, tu: TestUtils): - module.forward(test_input) - - -# ============================================================================== - -hf_albert_model = "albert-base-v2" - -trace_input = { - "forward": - prepare_sentence_tokens(hf_albert_model, "how do you like the project") -} -test_input = prepare_sentence_tokens(hf_albert_model, - "this project is very interesting") - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - HfSequenceClassification(hf_albert_model), trace_input)) -def AlbertSequenceClassification_basic(module, tu: TestUtils): - module.forward(test_input) - - -# ============================================================================== - -hf_bert_model = "bert-base-uncased" - -trace_input = { - "forward": - prepare_sentence_tokens(hf_bert_model, "how do you like the project") -} -test_input = prepare_sentence_tokens(hf_bert_model, - "this project is very interesting") - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - HfSequenceClassification(hf_bert_model), trace_input)) -def BertSequenceClassification_basic(module, tu: TestUtils): - module.forward(test_input) - - -# ============================================================================== - -hf_distilbert_model = "distilbert-base-uncased" - -trace_input = { - "forward": - prepare_sentence_tokens(hf_distilbert_model, "how do you like the project") -} -test_input = prepare_sentence_tokens(hf_distilbert_model, - "this project is very interesting") - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - HfSequenceClassification(hf_distilbert_model), trace_input)) -def DistilBertSequenceClassification_basic(module, tu: TestUtils): - module.forward(test_input) - -# ============================================================================== diff --git a/build_tools/torchscript_e2e_heavydep_tests/main.py b/build_tools/torchscript_e2e_heavydep_tests/main.py deleted file mode 100644 index 28e374628..000000000 --- a/build_tools/torchscript_e2e_heavydep_tests/main.py +++ /dev/null @@ -1,28 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - -import argparse - -from torch_mlir_e2e_test.serialization import serialize_all_tests_to - -from . import hf_sequence_classification -from . import vision_models -from . import train_models - - -def _get_argparse(): - parser = argparse.ArgumentParser( - description="Generate assets for TorchScript E2E tests") - parser.add_argument("--output_dir", help="The directory to put assets in.") - return parser - - -def main(): - args = _get_argparse().parse_args() - serialize_all_tests_to(args.output_dir) - - -if __name__ == "__main__": - main() diff --git a/build_tools/torchscript_e2e_heavydep_tests/train_models.py b/build_tools/torchscript_e2e_heavydep_tests/train_models.py deleted file mode 100644 index d4d9e3999..000000000 --- a/build_tools/torchscript_e2e_heavydep_tests/train_models.py +++ /dev/null @@ -1,209 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - -import torch -from torch._decomp import get_decompositions -from functorch import make_fx -from torch.nn.utils import _stateless -from transformers import AutoTokenizer, AutoModelForSequenceClassification - -from torch_mlir_e2e_test.framework import TestUtils -from torch_mlir_e2e_test.registry import register_test_case -from torch_mlir_e2e_test.annotations import annotate_args, export - -from torch import fx -import copy -import tempfile - -torch.manual_seed(0) - -############################## Utility Functions ############################## - - -def get_input_annotations(inputs: tuple, dynamic: bool) -> list: - """Generates the annotation i.e., shape and dtype for the given inputs, required by torch-mlir module.""" - - annotations_list = [None] - for i in inputs: - temp_list = [] - if dynamic: - temp_list.append([-1 for i in range(len(i.shape))]) - else: - temp_list.append(list(i.shape)) - temp_list.append(i.dtype) - temp_list.append(True) - annotations_list.append(tuple(temp_list)) - return annotations_list - - -def change_fx_graph_return_to_tuple(fx_g: fx.GraphModule): - for node in fx_g.graph.nodes: - if node.op == "output": - # output nodes always have one argument - node_arg = node.args[0] - out_nodes = [] - if isinstance(node_arg, list): - # Don't return NoneType elements. - for out_node in node_arg: - if not isinstance(out_node, type(None)): - out_nodes.append(out_node) - # If there is a single tensor/element to be returned don't - # create a tuple for it. - if len(out_nodes) == 1: - node.args = out_nodes - else: - node.args = (tuple(out_nodes), ) - fx_g.graph.lint() - fx_g.recompile() - return fx_g - - -def generate_graph(model, inputs, training_fn): - # TODO: Pass the decomposition_table according to the model/needs. - fx_g = make_fx(training_fn, - decomposition_table=get_decompositions([ - torch.ops.aten.embedding_dense_backward, - torch.ops.aten.native_layer_norm_backward, - torch.ops.aten.slice_backward, - torch.ops.aten.select_backward - ]))(dict(model.named_parameters()), - dict(model.named_buffers()), inputs) - fx_g.graph.set_codegen(torch.fx.graph.CodeGen()) - fx_g.recompile() - fx_g = change_fx_graph_return_to_tuple(fx_g) - ts_g = torch.jit.script(fx_g) - # TODO: If not saved/load creates some unnecessary functions that - # causes problem during mlir-translate. - temp=tempfile.NamedTemporaryFile(suffix='_heavy_dep', - prefix='temp_ts_') - ts_g.save(temp.name) - new_ts = torch.jit.load(temp.name) - return new_ts - - -def getAnnotatedModule(ts_module, inputs): - export(ts_module.forward) - annotate_args_decorator = annotate_args( - get_input_annotations(inputs, dynamic=False)) - annotate_args_decorator(ts_module.forward) - return ts_module - - ############################################################################ - - -# Basic NeuralNet training test. -# This trains the Neural Net and returns the updated weights with the -# sgd optimzier. - - -class NeuralNet(torch.nn.Module): - - def __init__(self): - super(NeuralNet, self).__init__() - self.l1 = torch.nn.Linear(10, 16) - self.relu = torch.nn.ReLU() - self.l2 = torch.nn.Linear(16, 2) - - def forward(self, x): - out = self.l1(x) - out = self.relu(out) - out = self.l2(out) - return out - - -neural_net_model = NeuralNet() -input = torch.randn(1, 10) - - -def get_sorted_params(named_params): - return [i[1] for i in sorted(named_params.items())] - - -# TODO: Pass and update the optimizer fn. Currently, we don't support passing -# elemental types. -def training_fn(params, buffers, args): - params_and_buffers = {**params, **buffers} - _stateless.functional_call(neural_net_model, params_and_buffers, args, - {}).sum().backward() - optim = torch.optim.SGD(get_sorted_params(params), lr=0.01) - optim.step() - return params, buffers - - -# We need to pass the model parameters, buffers and the inputs respectively in -# order. -training_inputs = [i.detach() for i in neural_net_model.parameters()] -for i in neural_net_model.buffers(): - training_inputs.append(i.detach()) - -training_inputs.append(input) - -neural_net_ts = generate_graph(neural_net_model, (input, ), training_fn) - -neural_net_ts_annotated = getAnnotatedModule(neural_net_ts, training_inputs) - - -@register_test_case(module_factory=lambda: neural_net_ts_annotated) -def NeuralNet_training_basic(module, tu: TestUtils): - module.forward(*training_inputs) - - -############################################################################## - -# Bert training. - - -class MiniLMSequenceClassification(torch.nn.Module): - - def __init__(self): - super().__init__() - self.model = AutoModelForSequenceClassification.from_pretrained( - # The below model is a variant of - # `microsoft/MiniLM-L12-H384-uncased` with less parameters. - "nreimers/MiniLM-L6-H384-uncased", # The pretrained model. - num_labels= - 2, # The number of output labels--2 for binary classification. - output_attentions= - False, # Whether the model returns attentions weights. - output_hidden_states= - False, # Whether the model returns all hidden-states. - torchscript=True, - ) - - def forward(self, tokens): - return self.model.forward(tokens)[0] - - -bert_model = MiniLMSequenceClassification() -input = torch.randint(2, (1, 128)) - - -# TODO: Pass and update the optimizer fn. Currently, we don't support passing -# elemental types. -def training_fn(params, buffers, args): - params_and_buffers = {**params, **buffers} - _stateless.functional_call(bert_model, params_and_buffers, args, - {}).sum().backward() - optim = torch.optim.SGD(get_sorted_params(params), lr=0.01) - optim.step() - return params, buffers - - -# We need to pass the model parameters, buffers and the inputs respectively in -# order. -bert_inputs = [i.detach() for i in bert_model.parameters()] -for i in bert_model.buffers(): - bert_inputs.append(i.detach()) - -bert_inputs.append(input) - -bert_ts = generate_graph(bert_model, (input, ), training_fn) - -bert_ts_annotated = getAnnotatedModule(bert_ts, bert_inputs) - - -@register_test_case(module_factory=lambda: bert_ts_annotated) -def BERT_training_basic(module, tu: TestUtils): - module.forward(*bert_inputs) diff --git a/build_tools/torchscript_e2e_heavydep_tests/vision_models.py b/build_tools/torchscript_e2e_heavydep_tests/vision_models.py deleted file mode 100644 index 71043eb98..000000000 --- a/build_tools/torchscript_e2e_heavydep_tests/vision_models.py +++ /dev/null @@ -1,263 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - -import torch -import torchvision.models as models - -from torch_mlir_e2e_test.framework import TestUtils -from torch_mlir_e2e_test.registry import register_test_case -from torch_mlir_e2e_test.annotations import annotate_args, export -import timm - -torch.manual_seed(0) - - -def getTracedRecursiveScriptModule(module): - script_module = torch.jit.script(module) - export(script_module.forward) - annotate_args_decorator = annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) - annotate_args_decorator(script_module.forward) - return script_module - - -class VisionModule(torch.nn.Module): - - def __init__(self, model): - super().__init__() - self.model = model - self.model.eval() - - def forward(self, input): - return self.model.forward(input) - - -# ============================================================================== - -resnet18_model = models.resnet18(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(resnet18_model))) -def Resnet18VisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -resnext50_32x4d_model = models.resnext50_32x4d(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(resnext50_32x4d_model))) -def Resnext50_32x4dVisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -mnasnet1_0_model = models.mnasnet1_0(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(mnasnet1_0_model))) -def Mnasnet1_0VisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -alexnet_model = models.alexnet(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(alexnet_model))) -def AlexnetVisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -shufflenet_model = models.shufflenet_v2_x1_0(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(shufflenet_model))) -def ShufflenetVisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -squeezenet_model = models.squeezenet1_0(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(squeezenet_model))) -def SqueezenetVisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -vgg16_model = models.vgg16(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(vgg16_model))) -def Vgg16VisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -wide_resnet_model = models.wide_resnet50_2(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(wide_resnet_model))) -def Wide_ResnetVisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -efficientnet_model = models.efficientnet_b0(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(efficientnet_model))) -def EfficientnetVisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -mobilenet_v2_model = models.mobilenet_v2(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(mobilenet_v2_model))) -def Mobilenet_v2VisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -mobilenet_v3_large_model = models.mobilenet_v3_large(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(mobilenet_v3_large_model))) -def Mobilenet_v3_largeVisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -resnet50_model = models.resnet50(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(resnet50_model))) -def Resnet50VisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -densenet121_model = models.densenet121(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(densenet121_model))) -def Densenet121VisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -timm_regnet_model = models.regnet_y_1_6gf(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(timm_regnet_model))) -def Timm_RegnetVisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -pytorch_unet_model = torch.hub.load( - "mateuszbuda/brain-segmentation-pytorch", - "unet", - in_channels=3, - out_channels=1, - init_features=32, - pretrained=True, -) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(pytorch_unet_model))) -def PytorchUnetVisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -resnest_model = timm.create_model('resnest101e', pretrained=True) - -input = torch.randn(1, 3, 224, 224) - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(resnest_model))) -def ResnestVisionModel_basic(module, tu: TestUtils): - module.forward(input) - - -# ============================================================================== - -timm_vit_model = models.vit_b_16(pretrained=True) - -input = torch.randn(1, 3, 224, 224) - - -@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule( - VisionModule(timm_vit_model))) -def ViTVisionModel_basic(module, tu: TestUtils): - module.forward(input) diff --git a/e2e_testing/main.py b/e2e_testing/main.py index 7c5a18b8f..654b877ab 100644 --- a/e2e_testing/main.py +++ b/e2e_testing/main.py @@ -10,7 +10,6 @@ import sys from torch_mlir_e2e_test.framework import run_tests from torch_mlir_e2e_test.reporting import report_results from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY -from torch_mlir_e2e_test.serialization import deserialize_all_tests_from # Available test configs. @@ -57,14 +56,6 @@ Regular expression specifying which tests to include in this run. default=False, action='store_true', help='report test results with additional detail') - parser.add_argument('--serialized-test-dir', default=None, type=str, help=''' -The directory containing serialized pre-built tests. -Right now, these are additional tests which require heavy Python dependencies -to generate (or cannot even be generated with the version of PyTorch used by -torch-mlir). -See `build_tools/e2e_heavydep_tests/generate_serialized_tests.sh` -for more information on building these artifacts. -''') parser.add_argument('-s', '--sequential', default=False, action='store_true', @@ -79,8 +70,6 @@ which make it easier to attach a debugger or get a stack trace.''') def main(): args = _get_argparse().parse_args() - if args.serialized_test_dir: - deserialize_all_tests_from(args.serialized_test_dir) all_test_unique_names = set( test.unique_name for test in GLOBAL_TEST_REGISTRY) diff --git a/python/torch_mlir_e2e_test/serialization.py b/python/torch_mlir_e2e_test/serialization.py deleted file mode 100644 index 0f80c8a07..000000000 --- a/python/torch_mlir_e2e_test/serialization.py +++ /dev/null @@ -1,173 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. -""" -# Serialization utilities for the end-to-end test framework. - -It is sometimes useful to be able to serialize tests to disk, such as when -multiple tests require mutally incompatible PyTorch versions. This takes -advantage of the strong backwards compatibility of serialized TorchScript, which -generally allows all such programs to be loaded to an appropriately recent -PyTorch. -""" - -from typing import List, Tuple, Optional, NamedTuple - -import io -import os -import pickle - -import torch - -from .framework import generate_golden_trace, Test, Trace -from .annotations import ArgAnnotation, export, annotate_args, TORCH_MLIR_EXPORT_ATTR_NAME, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME -from .registry import GLOBAL_TEST_REGISTRY - -# ============================================================================== -# Annotation serialization -# ============================================================================== - - -class SerializableMethodAnnotation(NamedTuple): - method_name: str - export: Optional[bool] - arg_annotations: Optional[List[ArgAnnotation]] - - -class SerializableModuleAnnotations(NamedTuple): - method_annotations: List[SerializableMethodAnnotation] - submodule_annotations: List[Tuple[str, "SerializableModuleAnnotations"]] - - -def extract_serializable_annotations( - module: torch.nn.Module) -> SerializableModuleAnnotations: - module_annotations = SerializableModuleAnnotations( - method_annotations=[], submodule_annotations=[]) - # Extract information on methods. - for method_name, method in module.__dict__.items(): - # See if it is a method. - if not callable(method): - continue - export = None - arg_annotations = None - if hasattr(method, TORCH_MLIR_EXPORT_ATTR_NAME): - export = method._torch_mlir_export - if hasattr(method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME): - arg_annotations = method._torch_mlir_arg_annotations - if export is not None and arg_annotations is not None: - module_annotations.method_annotations.append( - SerializableMethodAnnotation(method_name=method_name, - export=export, - arg_annotations=arg_annotations)) - - # Recurse. - for name, child in module.named_children(): - annotations = extract_serializable_annotations(child) - module_annotations.submodule_annotations.append((name, annotations)) - return module_annotations - - -def apply_serializable_annotations(module: torch.nn.Module, - annotations: SerializableModuleAnnotations): - # Apply annotations to methods. - for method_annotation in annotations.method_annotations: - # Imitate use of the decorators to keep a source of truth there. - if method_annotation.export is not None: - setattr(module, method_annotation.method_name, - export(getattr(module, method_annotation.method_name))) - if method_annotation.arg_annotations is not None: - setattr( - module, method_annotation.method_name, - annotate_args(method_annotation.arg_annotations)(getattr( - module, method_annotation.method_name))) - - # Recurse. - for name, submodule_annotations in annotations.submodule_annotations: - child = getattr(module, name) - apply_serializable_annotations(child, submodule_annotations) - -# ============================================================================== -# Serializable test definition -# ============================================================================== - - -class SerializableTest(NamedTuple): - """A self-contained representation of a test that can be pickled. - - We use serialized TorchScript programs here for two reasons: - 1. The PyTorch pickling story isn't great, so in order to reliably pickle - this class, we rely on having the serialized bytes for the TorchScript - module already given to us. - 2. The choice of a TorchScript module vs `torch.nn.Module` boils down to - the fact that `torch.nn.Module` cannot be deserialized without pulling - in the same set of Python dependencies that were used to serialize it - in the first place. This would defeat one of the - main use cases of this class, which is to transport a test from an - environment with a set of heavy dependencies to a dependency-light one. - Since TorchScript modules are self-contained, they fit the bill - perfectly. - """ - # See unique_name on `Test`. - unique_name: str - # Serialized TorchScript program. - program: bytes - # Trace for execution testing. - trace: Trace - - def as_test(self) -> Test: - """Create a `Test` from this class.""" - # Conform the serialized program to the interface expected by Test. - # This is a bit of a hack, but it's the only way to keep the layering - # straight. - def factory(): - _extra_files = {"annotations.pkl": ""} - module = torch.jit.load(io.BytesIO(self.program), - _extra_files=_extra_files) - # Load the pickled annotations. - annotations = pickle.loads(_extra_files["annotations.pkl"]) - apply_serializable_annotations(module, annotations) - return module - - def invoker(module, tu): - for item in self.trace: - attr = module - for part in item.symbol.split("."): - attr = getattr(attr, part) - attr(*item.inputs) - - return Test( - unique_name=self.unique_name, - program_factory=factory, - program_invoker=invoker, - ) - - -# ============================================================================== -# Filesystem operations -# ============================================================================== - -def serialize_all_tests_to(output_dir: str): - serializable_tests = [] - for test in GLOBAL_TEST_REGISTRY: - trace = generate_golden_trace(test) - module = torch.jit.script(test.program_factory()) - torchscript_module_bytes = module.save_to_buffer({ - "annotations.pkl": - pickle.dumps(extract_serializable_annotations(module)) - }) - serializable_tests.append( - SerializableTest(unique_name=test.unique_name, - program=torchscript_module_bytes, - trace=trace)) - for test in serializable_tests: - with open(os.path.join(output_dir, f"{test.unique_name}.pkl"), - "wb") as f: - pickle.dump(test, f) - - -def deserialize_all_tests_from(serialized_test_dir: str): - for root, _, files in os.walk(serialized_test_dir): - for filename in files: - with open(os.path.join(root, filename), 'rb') as f: - GLOBAL_TEST_REGISTRY.append(pickle.load(f).as_test())