mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add heavydep tests for torch benchmarks
This commit adds e2e heavydep tests for the torch benchmarks. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/798/head snapshot-20220426.415
parent
2877a37ac6
commit
4635d36efb
|
@ -3,9 +3,6 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# A pretrained model to classify the input sentence.
|
||||
# https://huggingface.co/philschmid/MiniLM-L6-H384-uncased-sst2
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
|
||||
|
@ -15,6 +12,7 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def prepare_sentence_tokens(hf_model: str, sentence: str):
|
||||
tokenizer = AutoTokenizer.from_pretrained(hf_model)
|
||||
return torch.tensor([tokenizer.encode(sentence)])
|
||||
|
@ -24,17 +22,16 @@ def getTracedRecursiveScriptModule(module, trace_input):
|
|||
traced_module = torch.jit.trace_module(module, trace_input)
|
||||
script_module = traced_module._actual_script_module
|
||||
export(script_module.forward)
|
||||
annotate_args_decorator = annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
]
|
||||
)
|
||||
annotate_args_decorator = annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
annotate_args_decorator(script_module.forward)
|
||||
return script_module
|
||||
|
||||
|
||||
class HfSequenceClassification(torch.nn.Module):
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
|
@ -50,12 +47,10 @@ class HfSequenceClassification(torch.nn.Module):
|
|||
self.model.eval()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
]
|
||||
)
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
@ -65,37 +60,122 @@ class HfSequenceClassification(torch.nn.Module):
|
|||
hf_minilm_model = "philschmid/MiniLM-L6-H384-uncased-sst2"
|
||||
|
||||
trace_input = {
|
||||
"forward": prepare_sentence_tokens(hf_minilm_model, "how do you like the project")
|
||||
"forward":
|
||||
prepare_sentence_tokens(hf_minilm_model, "how do you like the project")
|
||||
}
|
||||
test_input = prepare_sentence_tokens(
|
||||
hf_minilm_model, "this project is very interesting")
|
||||
test_input = prepare_sentence_tokens(hf_minilm_model,
|
||||
"this project is very interesting")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
HfSequenceClassification(hf_minilm_model), trace_input
|
||||
)
|
||||
)
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
HfSequenceClassification(hf_minilm_model), trace_input))
|
||||
def MiniLMSequenceClassification_basic(module, tu: TestUtils):
|
||||
module.forward(test_input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
hf_albert_model = "albert-base-v2"
|
||||
|
||||
trace_input = {
|
||||
"forward": prepare_sentence_tokens(hf_albert_model, "how do you like the project")
|
||||
"forward":
|
||||
prepare_sentence_tokens(hf_albert_model, "how do you like the project")
|
||||
}
|
||||
test_input = prepare_sentence_tokens(
|
||||
hf_albert_model, "this project is very interesting")
|
||||
test_input = prepare_sentence_tokens(hf_albert_model,
|
||||
"this project is very interesting")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
HfSequenceClassification(hf_albert_model), trace_input
|
||||
)
|
||||
)
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
HfSequenceClassification(hf_albert_model), trace_input))
|
||||
def AlbertSequenceClassification_basic(module, tu: TestUtils):
|
||||
module.forward(test_input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
hf_bart_model = "facebook/bart-base"
|
||||
|
||||
trace_input = {
|
||||
"forward":
|
||||
prepare_sentence_tokens(hf_bart_model, "how do you like the project")
|
||||
}
|
||||
test_input = prepare_sentence_tokens(hf_bart_model,
|
||||
"this project is very interesting")
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
HfSequenceClassification(hf_bart_model), trace_input))
|
||||
def BartSequenceClassification_basic(module, tu: TestUtils):
|
||||
module.forward(test_input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
hf_bert_model = "bert-base-uncased"
|
||||
|
||||
trace_input = {
|
||||
"forward":
|
||||
prepare_sentence_tokens(hf_bert_model, "how do you like the project")
|
||||
}
|
||||
test_input = prepare_sentence_tokens(hf_bert_model,
|
||||
"this project is very interesting")
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
HfSequenceClassification(hf_bert_model), trace_input))
|
||||
def BertSequenceClassification_basic(module, tu: TestUtils):
|
||||
module.forward(test_input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
hf_bigbird_model = "google/bigbird-roberta-base"
|
||||
|
||||
trace_input = {
|
||||
"forward":
|
||||
prepare_sentence_tokens(hf_bigbird_model, "how do you like the project")
|
||||
}
|
||||
test_input = prepare_sentence_tokens(hf_bigbird_model,
|
||||
"this project is very interesting")
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
HfSequenceClassification(hf_bigbird_model), trace_input))
|
||||
def BigBirdSequenceClassification_basic(module, tu: TestUtils):
|
||||
module.forward(test_input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
hf_distilbert_model = "distilbert-base-uncased"
|
||||
|
||||
trace_input = {
|
||||
"forward":
|
||||
prepare_sentence_tokens(hf_distilbert_model, "how do you like the project")
|
||||
}
|
||||
test_input = prepare_sentence_tokens(hf_distilbert_model,
|
||||
"this project is very interesting")
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
HfSequenceClassification(hf_distilbert_model), trace_input))
|
||||
def DistilBertSequenceClassification_basic(module, tu: TestUtils):
|
||||
module.forward(test_input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
hf_gpt2_model = "gpt2"
|
||||
|
||||
trace_input = {
|
||||
"forward":
|
||||
prepare_sentence_tokens(hf_gpt2_model, "how do you like the project")
|
||||
}
|
||||
test_input = prepare_sentence_tokens(hf_gpt2_model,
|
||||
"this project is very interesting")
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
HfSequenceClassification(hf_gpt2_model), trace_input))
|
||||
def GPT2SequenceClassification_basic(module, tu: TestUtils):
|
||||
module.forward(test_input)
|
||||
|
|
|
@ -10,6 +10,7 @@ from torch_mlir_e2e_test.torchscript.serialization import serialize_all_tests_to
|
|||
from . import hf_sequence_classification
|
||||
from . import fully_connected_backward
|
||||
from . import bert_functorch
|
||||
from . import vision_models
|
||||
|
||||
|
||||
def _get_argparse():
|
||||
|
|
|
@ -0,0 +1,265 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def getTracedRecursiveScriptModule(module):
|
||||
script_module = torch.jit.script(module)
|
||||
export(script_module.forward)
|
||||
annotate_args_decorator = annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
annotate_args_decorator(script_module.forward)
|
||||
return script_module
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.train(False)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
resnet18_model = models.resnet18(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(resnet18_model)))
|
||||
def Resnet18VisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
resnext50_32x4d_model = models.resnext50_32x4d(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(resnext50_32x4d_model)))
|
||||
def Resnext50_32x4dVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
mnasnet1_0_model = models.mnasnet1_0(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(mnasnet1_0_model)))
|
||||
def Mnasnet1_0VisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
alexnet_model = models.alexnet(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(alexnet_model)))
|
||||
def AlexnetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
shufflenet_model = models.shufflenet_v2_x1_0(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(shufflenet_model)))
|
||||
def ShufflenetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
squeezenet_model = models.squeezenet1_0(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(squeezenet_model)))
|
||||
def SqueezenetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
vgg16_model = models.vgg16(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(vgg16_model)))
|
||||
def Vgg16VisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
wide_resnet_model = models.wide_resnet50_2(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(wide_resnet_model)))
|
||||
def Wide_ResnetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
efficientnet_model = models.efficientnet_b0(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(efficientnet_model)))
|
||||
def EfficientnetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
mobilenet_v2_model = models.mobilenet_v2(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(mobilenet_v2_model)))
|
||||
def Mobilenet_v2VisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
mobilenet_v3_large_model = models.mobilenet_v3_large(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(mobilenet_v3_large_model)))
|
||||
def Mobilenet_v3_largeVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
resnet50_model = models.resnet50(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(resnet50_model)))
|
||||
def Resnet50VisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
densenet121_model = models.densenet121(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(densenet121_model)))
|
||||
def Densenet121VisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
timm_regnet_model = models.regnet_y_1_6gf(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(timm_regnet_model)))
|
||||
def Timm_RegnetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
pytorch_unet_model = torch.hub.load(
|
||||
"mateuszbuda/brain-segmentation-pytorch",
|
||||
"unet",
|
||||
in_channels=3,
|
||||
out_channels=1,
|
||||
init_features=32,
|
||||
pretrained=True,
|
||||
)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(pytorch_unet_model)))
|
||||
def PytorchUnetVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
resnest_model = torch.hub.load('zhanghang1989/ResNeSt',
|
||||
'resnest50',
|
||||
pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(resnest_model)))
|
||||
def ResnestVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
timm_vit_model = models.vit_b_16(pretrained=True)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||
VisionModule(timm_vit_model)))
|
||||
def ViTVisionModel_basic(module, tu: TestUtils):
|
||||
module.forward(input)
|
Loading…
Reference in New Issue