mirror of https://github.com/llvm/torch-mlir
Remove the heavydep tests
We originally added these to help bring up more complex models with heavier dependencies. However, over time it has become clear that these models usually require more than just heavier dependencies -- they often require a nontrivial amount of "one-off" code to extract the relevant parts of the model and compile them. This is not a good fit for a component in the core Torch-MLIR repo. However, in the community, nod.ai has developed the ["Shark Tank"](https://github.com/nod-ai/SHARK/tree/main/tank) which has all the appropriate code to wrangle these models and organize them. We intend to more heaviliy lean on that as a community and improve the symbiosis there to serve the role that these heavydep tests were meant to play.pull/1476/head
parent
6403c0e56f
commit
c8280d67bd
|
@ -1,33 +0,0 @@
|
|||
### Additional end-to-end tests with heavy dependencies (heavydep tests)
|
||||
|
||||
Some end-to-end tests require additional dependencies which don't make sense to
|
||||
include as part of the default torch-mlir setup. Additionally, these
|
||||
dependencies often don't work with the same HEAD PyTorch version that torch-mlir
|
||||
builds against at the C++ level.
|
||||
|
||||
We have a self-contained script that generates all the needed artifacts from a
|
||||
self-contained virtual environment. It can be used like so:
|
||||
|
||||
```shell
|
||||
# Build the virtual environment in the specified directory and generate the
|
||||
# serialized test artifacts in the other specified directory.
|
||||
# This command is safe to re-run if you have already built the virtual
|
||||
# environment and just changed the tests.
|
||||
build_tools/e2e_heavydep_tests/generate_serialized_tests.sh \
|
||||
path/to/heavydep_venv \
|
||||
path/to/heavydep_serialized_tests
|
||||
|
||||
# Add the --serialized-test-dir flag to point at the directory containing the
|
||||
# serialized tests. All other functionality is the same as the normal invocation
|
||||
# of e2e_test.sh, but the serialized tests will be available.
|
||||
tools/e2e_test.sh --serialized-test-dir=path/to/heavydep_serialized_tests
|
||||
```
|
||||
|
||||
The tests use the same (pure-Python) test framework as the normal
|
||||
e2e_test.sh, but the tests are added in
|
||||
`build_tools/e2e_heavydep_tests` instead of
|
||||
`frontends/pytorch/e2e_testing/torchscript`.
|
||||
|
||||
We rely critically on serialized TorchScript compatibility across PyTorch
|
||||
versions to transport the tests + pure-Python compatibility of the `torch`
|
||||
API, which has worked well so far.
|
|
@ -1,41 +0,0 @@
|
|||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
# Check that only two arugments are passed
|
||||
if [ "$#" -ne 2 ]; then
|
||||
echo "Usage: $0 <venv_dir> <serialized_test_dir>"
|
||||
echo 'Description:
|
||||
venv_dir: directory to put the Python venv used for generating the serialized tests
|
||||
serialized_test_dir: directory to write the generated serialized tests to
|
||||
'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
venv_dir=$1
|
||||
serialized_test_dir=$2
|
||||
here="$(realpath "$(dirname "$0")")"
|
||||
torch_mlir_src_root="$here/../../"
|
||||
|
||||
mkdir -p "$venv_dir"
|
||||
mkdir -p "$serialized_test_dir"
|
||||
python3 -m venv "$venv_dir"
|
||||
source "$venv_dir"/bin/activate
|
||||
|
||||
# latest torch-version and torch-vision module is required.
|
||||
python3 -m pip install --upgrade -r "$torch_mlir_src_root/requirements.txt"
|
||||
|
||||
# For minilm_seq_classification.py
|
||||
python3 -m pip install 'transformers[torch]'
|
||||
|
||||
# TODO: Remove functorch after make_fx makes into pytorch core.
|
||||
python3 -m pip install "git+https://github.com/pytorch/functorch.git"
|
||||
python3 -m pip install networkx
|
||||
|
||||
# For pytorch image models.
|
||||
python3 -m pip install timm
|
||||
|
||||
cd "$torch_mlir_src_root"
|
||||
export PYTHONPATH=${PYTHONPATH-}
|
||||
source "$torch_mlir_src_root/.env"
|
||||
|
||||
python3 -m build_tools.e2e_heavydep_tests.main --output_dir="$serialized_test_dir"
|
|
@ -1,129 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
|
||||
from torch_mlir_e2e_test.framework import TestUtils
|
||||
from torch_mlir_e2e_test.registry import register_test_case
|
||||
from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def prepare_sentence_tokens(hf_model: str, sentence: str):
|
||||
tokenizer = AutoTokenizer.from_pretrained(hf_model)
|
||||
return torch.tensor([tokenizer.encode(sentence)])
|
||||
|
||||
|
||||
def getTracedRecursiveScriptModule(module, trace_input):
|
||||
traced_module = torch.jit.trace_module(module, trace_input)
|
||||
script_module = traced_module._actual_script_module
|
||||
export(script_module.forward)
|
||||
annotate_args_decorator = annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
annotate_args_decorator(script_module.forward)
|
||||
return script_module
|
||||
|
||||
|
||||
class HfSequenceClassification(torch.nn.Module):
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, # The pretrained model name.
|
||||
# The number of output labels--2 for binary classification.
|
||||
num_labels=2,
|
||||
# Whether the model returns attentions weights.
|
||||
output_attentions=False,
|
||||
# Whether the model returns all hidden-states.
|
||||
output_hidden_states=False,
|
||||
torchscript=True,
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
hf_minilm_model = "philschmid/MiniLM-L6-H384-uncased-sst2"
|
||||
|
||||
trace_input = {
|
||||
"forward":
|
||||
prepare_sentence_tokens(hf_minilm_model, "how do you like the project")
|
||||
}
|
||||
test_input = prepare_sentence_tokens(hf_minilm_model,
|
||||
"this project is very interesting")
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
HfSequenceClassification(hf_minilm_model), trace_input))
|
||||
def MiniLMSequenceClassification_basic(module, tu: TestUtils):
|
||||
module.forward(test_input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
hf_albert_model = "albert-base-v2"
|
||||
|
||||
trace_input = {
|
||||
"forward":
|
||||
prepare_sentence_tokens(hf_albert_model, "how do you like the project")
|
||||
}
|
||||
test_input = prepare_sentence_tokens(hf_albert_model,
|
||||
"this project is very interesting")
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
HfSequenceClassification(hf_albert_model), trace_input))
|
||||
def AlbertSequenceClassification_basic(module, tu: TestUtils):
|
||||
module.forward(test_input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
hf_bert_model = "bert-base-uncased"
|
||||
|
||||
trace_input = {
|
||||
"forward":
|
||||
prepare_sentence_tokens(hf_bert_model, "how do you like the project")
|
||||
}
|
||||
test_input = prepare_sentence_tokens(hf_bert_model,
|
||||
"this project is very interesting")
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
HfSequenceClassification(hf_bert_model), trace_input))
|
||||
def BertSequenceClassification_basic(module, tu: TestUtils):
|
||||
module.forward(test_input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
hf_distilbert_model = "distilbert-base-uncased"
|
||||
|
||||
trace_input = {
|
||||
"forward":
|
||||
prepare_sentence_tokens(hf_distilbert_model, "how do you like the project")
|
||||
}
|
||||
test_input = prepare_sentence_tokens(hf_distilbert_model,
|
||||
"this project is very interesting")
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
HfSequenceClassification(hf_distilbert_model), trace_input))
|
||||
def DistilBertSequenceClassification_basic(module, tu: TestUtils):
|
||||
module.forward(test_input)
|
||||
|
||||
# ==============================================================================
|
|
@ -1,28 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import argparse
|
||||
|
||||
from torch_mlir_e2e_test.serialization import serialize_all_tests_to
|
||||
|
||||
from . import hf_sequence_classification
|
||||
from . import vision_models
|
||||
from . import train_models
|
||||
|
||||
|
||||
def _get_argparse():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate assets for TorchScript E2E tests")
|
||||
parser.add_argument("--output_dir", help="The directory to put assets in.")
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = _get_argparse().parse_args()
|
||||
serialize_all_tests_to(args.output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,209 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
from torch._decomp import get_decompositions
|
||||
from functorch import make_fx
|
||||
from torch.nn.utils import _stateless
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
|
||||
from torch_mlir_e2e_test.framework import TestUtils
|
||||
from torch_mlir_e2e_test.registry import register_test_case
|
||||
from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||
|
||||
from torch import fx
|
||||
import copy
|
||||
import tempfile
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
############################## Utility Functions ##############################
|
||||
|
||||
|
||||
def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
|
||||
"""Generates the annotation i.e., shape and dtype for the given inputs, required by torch-mlir module."""
|
||||
|
||||
annotations_list = [None]
|
||||
for i in inputs:
|
||||
temp_list = []
|
||||
if dynamic:
|
||||
temp_list.append([-1 for i in range(len(i.shape))])
|
||||
else:
|
||||
temp_list.append(list(i.shape))
|
||||
temp_list.append(i.dtype)
|
||||
temp_list.append(True)
|
||||
annotations_list.append(tuple(temp_list))
|
||||
return annotations_list
|
||||
|
||||
|
||||
def change_fx_graph_return_to_tuple(fx_g: fx.GraphModule):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
# output nodes always have one argument
|
||||
node_arg = node.args[0]
|
||||
out_nodes = []
|
||||
if isinstance(node_arg, list):
|
||||
# Don't return NoneType elements.
|
||||
for out_node in node_arg:
|
||||
if not isinstance(out_node, type(None)):
|
||||
out_nodes.append(out_node)
|
||||
# If there is a single tensor/element to be returned don't
|
||||
# create a tuple for it.
|
||||
if len(out_nodes) == 1:
|
||||
node.args = out_nodes
|
||||
else:
|
||||
node.args = (tuple(out_nodes), )
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return fx_g
|
||||
|
||||
|
||||
def generate_graph(model, inputs, training_fn):
|
||||
# TODO: Pass the decomposition_table according to the model/needs.
|
||||
fx_g = make_fx(training_fn,
|
||||
decomposition_table=get_decompositions([
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward
|
||||
]))(dict(model.named_parameters()),
|
||||
dict(model.named_buffers()), inputs)
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
fx_g = change_fx_graph_return_to_tuple(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
# TODO: If not saved/load creates some unnecessary functions that
|
||||
# causes problem during mlir-translate.
|
||||
temp=tempfile.NamedTemporaryFile(suffix='_heavy_dep',
|
||||
prefix='temp_ts_')
|
||||
ts_g.save(temp.name)
|
||||
new_ts = torch.jit.load(temp.name)
|
||||
return new_ts
|
||||
|
||||
|
||||
def getAnnotatedModule(ts_module, inputs):
|
||||
export(ts_module.forward)
|
||||
annotate_args_decorator = annotate_args(
|
||||
get_input_annotations(inputs, dynamic=False))
|
||||
annotate_args_decorator(ts_module.forward)
|
||||
return ts_module
|
||||
|
||||
############################################################################
|
||||
|
||||
|
||||
# Basic NeuralNet training test.
|
||||
# This trains the Neural Net and returns the updated weights with the
|
||||
# sgd optimzier.
|
||||
|
||||
|
||||
class NeuralNet(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(NeuralNet, self).__init__()
|
||||
self.l1 = torch.nn.Linear(10, 16)
|
||||
self.relu = torch.nn.ReLU()
|
||||
self.l2 = torch.nn.Linear(16, 2)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.l1(x)
|
||||
out = self.relu(out)
|
||||
out = self.l2(out)
|
||||
return out
|
||||
|
||||
|
||||
neural_net_model = NeuralNet()
|
||||
input = torch.randn(1, 10)
|
||||
|
||||
|
||||
def get_sorted_params(named_params):
|
||||
return [i[1] for i in sorted(named_params.items())]
|
||||
|
||||
|
||||
# TODO: Pass and update the optimizer fn. Currently, we don't support passing
|
||||
# elemental types.
|
||||
def training_fn(params, buffers, args):
|
||||
params_and_buffers = {**params, **buffers}
|
||||
_stateless.functional_call(neural_net_model, params_and_buffers, args,
|
||||
{}).sum().backward()
|
||||
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
|
||||
optim.step()
|
||||
return params, buffers
|
||||
|
||||
|
||||
# We need to pass the model parameters, buffers and the inputs respectively in
|
||||
# order.
|
||||
training_inputs = [i.detach() for i in neural_net_model.parameters()]
|
||||
for i in neural_net_model.buffers():
|
||||
training_inputs.append(i.detach())
|
||||
|
||||
training_inputs.append(input)
|
||||
|
||||
neural_net_ts = generate_graph(neural_net_model, (input, ), training_fn)
|
||||
|
||||
neural_net_ts_annotated = getAnnotatedModule(neural_net_ts, training_inputs)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: neural_net_ts_annotated)
|
||||
def NeuralNet_training_basic(module, tu: TestUtils):
|
||||
module.forward(*training_inputs)
|
||||
|
||||
|
||||
##############################################################################
|
||||
|
||||
# Bert training.
|
||||
|
||||
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
# The below model is a variant of
|
||||
# `microsoft/MiniLM-L12-H384-uncased` with less parameters.
|
||||
"nreimers/MiniLM-L6-H384-uncased", # The pretrained model.
|
||||
num_labels=
|
||||
2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=
|
||||
False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=
|
||||
False, # Whether the model returns all hidden-states.
|
||||
torchscript=True,
|
||||
)
|
||||
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
bert_model = MiniLMSequenceClassification()
|
||||
input = torch.randint(2, (1, 128))
|
||||
|
||||
|
||||
# TODO: Pass and update the optimizer fn. Currently, we don't support passing
|
||||
# elemental types.
|
||||
def training_fn(params, buffers, args):
|
||||
params_and_buffers = {**params, **buffers}
|
||||
_stateless.functional_call(bert_model, params_and_buffers, args,
|
||||
{}).sum().backward()
|
||||
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
|
||||
optim.step()
|
||||
return params, buffers
|
||||
|
||||
|
||||
# We need to pass the model parameters, buffers and the inputs respectively in
|
||||
# order.
|
||||
bert_inputs = [i.detach() for i in bert_model.parameters()]
|
||||
for i in bert_model.buffers():
|
||||
bert_inputs.append(i.detach())
|
||||
|
||||
bert_inputs.append(input)
|
||||
|
||||
bert_ts = generate_graph(bert_model, (input, ), training_fn)
|
||||
|
||||
bert_ts_annotated = getAnnotatedModule(bert_ts, bert_inputs)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: bert_ts_annotated)
|
||||
def BERT_training_basic(module, tu: TestUtils):
|
||||
module.forward(*bert_inputs)
|
|
@ -1,263 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
|
||||
from torch_mlir_e2e_test.framework import TestUtils
|
||||
from torch_mlir_e2e_test.registry import register_test_case
|
||||
from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||
import timm
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def getTracedRecursiveScriptModule(module):
|
||||
script_module = torch.jit.script(module)
|
||||
export(script_module.forward)
|
||||
annotate_args_decorator = annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
annotate_args_decorator(script_module.forward)
|
||||
return script_module
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, input):
|
||||
return self.model.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
resnet18_model = models.resnet18(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(resnet18_model)))
|
||||
def Resnet18VisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
resnext50_32x4d_model = models.resnext50_32x4d(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(resnext50_32x4d_model)))
|
||||
def Resnext50_32x4dVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
mnasnet1_0_model = models.mnasnet1_0(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(mnasnet1_0_model)))
|
||||
def Mnasnet1_0VisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
alexnet_model = models.alexnet(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(alexnet_model)))
|
||||
def AlexnetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
shufflenet_model = models.shufflenet_v2_x1_0(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(shufflenet_model)))
|
||||
def ShufflenetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
squeezenet_model = models.squeezenet1_0(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(squeezenet_model)))
|
||||
def SqueezenetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
vgg16_model = models.vgg16(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(vgg16_model)))
|
||||
def Vgg16VisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
wide_resnet_model = models.wide_resnet50_2(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(wide_resnet_model)))
|
||||
def Wide_ResnetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
efficientnet_model = models.efficientnet_b0(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(efficientnet_model)))
|
||||
def EfficientnetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
mobilenet_v2_model = models.mobilenet_v2(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(mobilenet_v2_model)))
|
||||
def Mobilenet_v2VisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
mobilenet_v3_large_model = models.mobilenet_v3_large(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(mobilenet_v3_large_model)))
|
||||
def Mobilenet_v3_largeVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
resnet50_model = models.resnet50(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(resnet50_model)))
|
||||
def Resnet50VisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
densenet121_model = models.densenet121(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(densenet121_model)))
|
||||
def Densenet121VisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
timm_regnet_model = models.regnet_y_1_6gf(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(timm_regnet_model)))
|
||||
def Timm_RegnetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
pytorch_unet_model = torch.hub.load(
|
||||
"mateuszbuda/brain-segmentation-pytorch",
|
||||
"unet",
|
||||
in_channels=3,
|
||||
out_channels=1,
|
||||
init_features=32,
|
||||
pretrained=True,
|
||||
)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(pytorch_unet_model)))
|
||||
def PytorchUnetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
resnest_model = timm.create_model('resnest101e', pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(resnest_model)))
|
||||
def ResnestVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
timm_vit_model = models.vit_b_16(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(timm_vit_model)))
|
||||
def ViTVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
|
@ -10,7 +10,6 @@ import sys
|
|||
from torch_mlir_e2e_test.framework import run_tests
|
||||
from torch_mlir_e2e_test.reporting import report_results
|
||||
from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir_e2e_test.serialization import deserialize_all_tests_from
|
||||
|
||||
|
||||
# Available test configs.
|
||||
|
@ -57,14 +56,6 @@ Regular expression specifying which tests to include in this run.
|
|||
default=False,
|
||||
action='store_true',
|
||||
help='report test results with additional detail')
|
||||
parser.add_argument('--serialized-test-dir', default=None, type=str, help='''
|
||||
The directory containing serialized pre-built tests.
|
||||
Right now, these are additional tests which require heavy Python dependencies
|
||||
to generate (or cannot even be generated with the version of PyTorch used by
|
||||
torch-mlir).
|
||||
See `build_tools/e2e_heavydep_tests/generate_serialized_tests.sh`
|
||||
for more information on building these artifacts.
|
||||
''')
|
||||
parser.add_argument('-s', '--sequential',
|
||||
default=False,
|
||||
action='store_true',
|
||||
|
@ -79,8 +70,6 @@ which make it easier to attach a debugger or get a stack trace.''')
|
|||
def main():
|
||||
args = _get_argparse().parse_args()
|
||||
|
||||
if args.serialized_test_dir:
|
||||
deserialize_all_tests_from(args.serialized_test_dir)
|
||||
all_test_unique_names = set(
|
||||
test.unique_name for test in GLOBAL_TEST_REGISTRY)
|
||||
|
||||
|
|
|
@ -1,173 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
"""
|
||||
# Serialization utilities for the end-to-end test framework.
|
||||
|
||||
It is sometimes useful to be able to serialize tests to disk, such as when
|
||||
multiple tests require mutally incompatible PyTorch versions. This takes
|
||||
advantage of the strong backwards compatibility of serialized TorchScript, which
|
||||
generally allows all such programs to be loaded to an appropriately recent
|
||||
PyTorch.
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Optional, NamedTuple
|
||||
|
||||
import io
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import torch
|
||||
|
||||
from .framework import generate_golden_trace, Test, Trace
|
||||
from .annotations import ArgAnnotation, export, annotate_args, TORCH_MLIR_EXPORT_ATTR_NAME, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
|
||||
from .registry import GLOBAL_TEST_REGISTRY
|
||||
|
||||
# ==============================================================================
|
||||
# Annotation serialization
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SerializableMethodAnnotation(NamedTuple):
|
||||
method_name: str
|
||||
export: Optional[bool]
|
||||
arg_annotations: Optional[List[ArgAnnotation]]
|
||||
|
||||
|
||||
class SerializableModuleAnnotations(NamedTuple):
|
||||
method_annotations: List[SerializableMethodAnnotation]
|
||||
submodule_annotations: List[Tuple[str, "SerializableModuleAnnotations"]]
|
||||
|
||||
|
||||
def extract_serializable_annotations(
|
||||
module: torch.nn.Module) -> SerializableModuleAnnotations:
|
||||
module_annotations = SerializableModuleAnnotations(
|
||||
method_annotations=[], submodule_annotations=[])
|
||||
# Extract information on methods.
|
||||
for method_name, method in module.__dict__.items():
|
||||
# See if it is a method.
|
||||
if not callable(method):
|
||||
continue
|
||||
export = None
|
||||
arg_annotations = None
|
||||
if hasattr(method, TORCH_MLIR_EXPORT_ATTR_NAME):
|
||||
export = method._torch_mlir_export
|
||||
if hasattr(method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME):
|
||||
arg_annotations = method._torch_mlir_arg_annotations
|
||||
if export is not None and arg_annotations is not None:
|
||||
module_annotations.method_annotations.append(
|
||||
SerializableMethodAnnotation(method_name=method_name,
|
||||
export=export,
|
||||
arg_annotations=arg_annotations))
|
||||
|
||||
# Recurse.
|
||||
for name, child in module.named_children():
|
||||
annotations = extract_serializable_annotations(child)
|
||||
module_annotations.submodule_annotations.append((name, annotations))
|
||||
return module_annotations
|
||||
|
||||
|
||||
def apply_serializable_annotations(module: torch.nn.Module,
|
||||
annotations: SerializableModuleAnnotations):
|
||||
# Apply annotations to methods.
|
||||
for method_annotation in annotations.method_annotations:
|
||||
# Imitate use of the decorators to keep a source of truth there.
|
||||
if method_annotation.export is not None:
|
||||
setattr(module, method_annotation.method_name,
|
||||
export(getattr(module, method_annotation.method_name)))
|
||||
if method_annotation.arg_annotations is not None:
|
||||
setattr(
|
||||
module, method_annotation.method_name,
|
||||
annotate_args(method_annotation.arg_annotations)(getattr(
|
||||
module, method_annotation.method_name)))
|
||||
|
||||
# Recurse.
|
||||
for name, submodule_annotations in annotations.submodule_annotations:
|
||||
child = getattr(module, name)
|
||||
apply_serializable_annotations(child, submodule_annotations)
|
||||
|
||||
# ==============================================================================
|
||||
# Serializable test definition
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SerializableTest(NamedTuple):
|
||||
"""A self-contained representation of a test that can be pickled.
|
||||
|
||||
We use serialized TorchScript programs here for two reasons:
|
||||
1. The PyTorch pickling story isn't great, so in order to reliably pickle
|
||||
this class, we rely on having the serialized bytes for the TorchScript
|
||||
module already given to us.
|
||||
2. The choice of a TorchScript module vs `torch.nn.Module` boils down to
|
||||
the fact that `torch.nn.Module` cannot be deserialized without pulling
|
||||
in the same set of Python dependencies that were used to serialize it
|
||||
in the first place. This would defeat one of the
|
||||
main use cases of this class, which is to transport a test from an
|
||||
environment with a set of heavy dependencies to a dependency-light one.
|
||||
Since TorchScript modules are self-contained, they fit the bill
|
||||
perfectly.
|
||||
"""
|
||||
# See unique_name on `Test`.
|
||||
unique_name: str
|
||||
# Serialized TorchScript program.
|
||||
program: bytes
|
||||
# Trace for execution testing.
|
||||
trace: Trace
|
||||
|
||||
def as_test(self) -> Test:
|
||||
"""Create a `Test` from this class."""
|
||||
# Conform the serialized program to the interface expected by Test.
|
||||
# This is a bit of a hack, but it's the only way to keep the layering
|
||||
# straight.
|
||||
def factory():
|
||||
_extra_files = {"annotations.pkl": ""}
|
||||
module = torch.jit.load(io.BytesIO(self.program),
|
||||
_extra_files=_extra_files)
|
||||
# Load the pickled annotations.
|
||||
annotations = pickle.loads(_extra_files["annotations.pkl"])
|
||||
apply_serializable_annotations(module, annotations)
|
||||
return module
|
||||
|
||||
def invoker(module, tu):
|
||||
for item in self.trace:
|
||||
attr = module
|
||||
for part in item.symbol.split("."):
|
||||
attr = getattr(attr, part)
|
||||
attr(*item.inputs)
|
||||
|
||||
return Test(
|
||||
unique_name=self.unique_name,
|
||||
program_factory=factory,
|
||||
program_invoker=invoker,
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Filesystem operations
|
||||
# ==============================================================================
|
||||
|
||||
def serialize_all_tests_to(output_dir: str):
|
||||
serializable_tests = []
|
||||
for test in GLOBAL_TEST_REGISTRY:
|
||||
trace = generate_golden_trace(test)
|
||||
module = torch.jit.script(test.program_factory())
|
||||
torchscript_module_bytes = module.save_to_buffer({
|
||||
"annotations.pkl":
|
||||
pickle.dumps(extract_serializable_annotations(module))
|
||||
})
|
||||
serializable_tests.append(
|
||||
SerializableTest(unique_name=test.unique_name,
|
||||
program=torchscript_module_bytes,
|
||||
trace=trace))
|
||||
for test in serializable_tests:
|
||||
with open(os.path.join(output_dir, f"{test.unique_name}.pkl"),
|
||||
"wb") as f:
|
||||
pickle.dump(test, f)
|
||||
|
||||
|
||||
def deserialize_all_tests_from(serialized_test_dir: str):
|
||||
for root, _, files in os.walk(serialized_test_dir):
|
||||
for filename in files:
|
||||
with open(os.path.join(root, filename), 'rb') as f:
|
||||
GLOBAL_TEST_REGISTRY.append(pickle.load(f).as_test())
|
Loading…
Reference in New Issue