mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add heavydep tests for torch benchmarks
This commit adds e2e heavydep tests for the torch benchmarks. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/798/head snapshot-20220426.415
parent
2877a37ac6
commit
4635d36efb
|
@ -3,9 +3,6 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
# Also available under a BSD-style license. See LICENSE.
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
# A pretrained model to classify the input sentence.
|
|
||||||
# https://huggingface.co/philschmid/MiniLM-L6-H384-uncased-sst2
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||||
|
|
||||||
|
@ -15,6 +12,7 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
|
||||||
def prepare_sentence_tokens(hf_model: str, sentence: str):
|
def prepare_sentence_tokens(hf_model: str, sentence: str):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(hf_model)
|
tokenizer = AutoTokenizer.from_pretrained(hf_model)
|
||||||
return torch.tensor([tokenizer.encode(sentence)])
|
return torch.tensor([tokenizer.encode(sentence)])
|
||||||
|
@ -24,17 +22,16 @@ def getTracedRecursiveScriptModule(module, trace_input):
|
||||||
traced_module = torch.jit.trace_module(module, trace_input)
|
traced_module = torch.jit.trace_module(module, trace_input)
|
||||||
script_module = traced_module._actual_script_module
|
script_module = traced_module._actual_script_module
|
||||||
export(script_module.forward)
|
export(script_module.forward)
|
||||||
annotate_args_decorator = annotate_args(
|
annotate_args_decorator = annotate_args([
|
||||||
[
|
None,
|
||||||
None,
|
([-1, -1], torch.int64, True),
|
||||||
([-1, -1], torch.int64, True),
|
])
|
||||||
]
|
|
||||||
)
|
|
||||||
annotate_args_decorator(script_module.forward)
|
annotate_args_decorator(script_module.forward)
|
||||||
return script_module
|
return script_module
|
||||||
|
|
||||||
|
|
||||||
class HfSequenceClassification(torch.nn.Module):
|
class HfSequenceClassification(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
@ -50,12 +47,10 @@ class HfSequenceClassification(torch.nn.Module):
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@annotate_args(
|
@annotate_args([
|
||||||
[
|
None,
|
||||||
None,
|
([-1, -1], torch.int64, True),
|
||||||
([-1, -1], torch.int64, True),
|
])
|
||||||
]
|
|
||||||
)
|
|
||||||
def forward(self, tokens):
|
def forward(self, tokens):
|
||||||
return self.model.forward(tokens)[0]
|
return self.model.forward(tokens)[0]
|
||||||
|
|
||||||
|
@ -65,37 +60,122 @@ class HfSequenceClassification(torch.nn.Module):
|
||||||
hf_minilm_model = "philschmid/MiniLM-L6-H384-uncased-sst2"
|
hf_minilm_model = "philschmid/MiniLM-L6-H384-uncased-sst2"
|
||||||
|
|
||||||
trace_input = {
|
trace_input = {
|
||||||
"forward": prepare_sentence_tokens(hf_minilm_model, "how do you like the project")
|
"forward":
|
||||||
|
prepare_sentence_tokens(hf_minilm_model, "how do you like the project")
|
||||||
}
|
}
|
||||||
test_input = prepare_sentence_tokens(
|
test_input = prepare_sentence_tokens(hf_minilm_model,
|
||||||
hf_minilm_model, "this project is very interesting")
|
"this project is very interesting")
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
module_factory=lambda: getTracedRecursiveScriptModule(
|
HfSequenceClassification(hf_minilm_model), trace_input))
|
||||||
HfSequenceClassification(hf_minilm_model), trace_input
|
|
||||||
)
|
|
||||||
)
|
|
||||||
def MiniLMSequenceClassification_basic(module, tu: TestUtils):
|
def MiniLMSequenceClassification_basic(module, tu: TestUtils):
|
||||||
module.forward(test_input)
|
module.forward(test_input)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
hf_albert_model = "albert-base-v2"
|
hf_albert_model = "albert-base-v2"
|
||||||
|
|
||||||
trace_input = {
|
trace_input = {
|
||||||
"forward": prepare_sentence_tokens(hf_albert_model, "how do you like the project")
|
"forward":
|
||||||
|
prepare_sentence_tokens(hf_albert_model, "how do you like the project")
|
||||||
}
|
}
|
||||||
test_input = prepare_sentence_tokens(
|
test_input = prepare_sentence_tokens(hf_albert_model,
|
||||||
hf_albert_model, "this project is very interesting")
|
"this project is very interesting")
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
module_factory=lambda: getTracedRecursiveScriptModule(
|
HfSequenceClassification(hf_albert_model), trace_input))
|
||||||
HfSequenceClassification(hf_albert_model), trace_input
|
|
||||||
)
|
|
||||||
)
|
|
||||||
def AlbertSequenceClassification_basic(module, tu: TestUtils):
|
def AlbertSequenceClassification_basic(module, tu: TestUtils):
|
||||||
module.forward(test_input)
|
module.forward(test_input)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
hf_bart_model = "facebook/bart-base"
|
||||||
|
|
||||||
|
trace_input = {
|
||||||
|
"forward":
|
||||||
|
prepare_sentence_tokens(hf_bart_model, "how do you like the project")
|
||||||
|
}
|
||||||
|
test_input = prepare_sentence_tokens(hf_bart_model,
|
||||||
|
"this project is very interesting")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
HfSequenceClassification(hf_bart_model), trace_input))
|
||||||
|
def BartSequenceClassification_basic(module, tu: TestUtils):
|
||||||
|
module.forward(test_input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
hf_bert_model = "bert-base-uncased"
|
||||||
|
|
||||||
|
trace_input = {
|
||||||
|
"forward":
|
||||||
|
prepare_sentence_tokens(hf_bert_model, "how do you like the project")
|
||||||
|
}
|
||||||
|
test_input = prepare_sentence_tokens(hf_bert_model,
|
||||||
|
"this project is very interesting")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
HfSequenceClassification(hf_bert_model), trace_input))
|
||||||
|
def BertSequenceClassification_basic(module, tu: TestUtils):
|
||||||
|
module.forward(test_input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
hf_bigbird_model = "google/bigbird-roberta-base"
|
||||||
|
|
||||||
|
trace_input = {
|
||||||
|
"forward":
|
||||||
|
prepare_sentence_tokens(hf_bigbird_model, "how do you like the project")
|
||||||
|
}
|
||||||
|
test_input = prepare_sentence_tokens(hf_bigbird_model,
|
||||||
|
"this project is very interesting")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
HfSequenceClassification(hf_bigbird_model), trace_input))
|
||||||
|
def BigBirdSequenceClassification_basic(module, tu: TestUtils):
|
||||||
|
module.forward(test_input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
hf_distilbert_model = "distilbert-base-uncased"
|
||||||
|
|
||||||
|
trace_input = {
|
||||||
|
"forward":
|
||||||
|
prepare_sentence_tokens(hf_distilbert_model, "how do you like the project")
|
||||||
|
}
|
||||||
|
test_input = prepare_sentence_tokens(hf_distilbert_model,
|
||||||
|
"this project is very interesting")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
HfSequenceClassification(hf_distilbert_model), trace_input))
|
||||||
|
def DistilBertSequenceClassification_basic(module, tu: TestUtils):
|
||||||
|
module.forward(test_input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
hf_gpt2_model = "gpt2"
|
||||||
|
|
||||||
|
trace_input = {
|
||||||
|
"forward":
|
||||||
|
prepare_sentence_tokens(hf_gpt2_model, "how do you like the project")
|
||||||
|
}
|
||||||
|
test_input = prepare_sentence_tokens(hf_gpt2_model,
|
||||||
|
"this project is very interesting")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
HfSequenceClassification(hf_gpt2_model), trace_input))
|
||||||
|
def GPT2SequenceClassification_basic(module, tu: TestUtils):
|
||||||
|
module.forward(test_input)
|
||||||
|
|
|
@ -10,6 +10,7 @@ from torch_mlir_e2e_test.torchscript.serialization import serialize_all_tests_to
|
||||||
from . import hf_sequence_classification
|
from . import hf_sequence_classification
|
||||||
from . import fully_connected_backward
|
from . import fully_connected_backward
|
||||||
from . import bert_functorch
|
from . import bert_functorch
|
||||||
|
from . import vision_models
|
||||||
|
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
|
|
|
@ -0,0 +1,265 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision.models as models
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||||
|
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||||
|
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
|
||||||
|
def getTracedRecursiveScriptModule(module):
|
||||||
|
script_module = torch.jit.script(module)
|
||||||
|
export(script_module.forward)
|
||||||
|
annotate_args_decorator = annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
annotate_args_decorator(script_module.forward)
|
||||||
|
return script_module
|
||||||
|
|
||||||
|
|
||||||
|
class VisionModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, model):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.train(False)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return self.model.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
resnet18_model = models.resnet18(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(resnet18_model)))
|
||||||
|
def Resnet18VisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
resnext50_32x4d_model = models.resnext50_32x4d(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(resnext50_32x4d_model)))
|
||||||
|
def Resnext50_32x4dVisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
mnasnet1_0_model = models.mnasnet1_0(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(mnasnet1_0_model)))
|
||||||
|
def Mnasnet1_0VisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
alexnet_model = models.alexnet(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(alexnet_model)))
|
||||||
|
def AlexnetVisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
shufflenet_model = models.shufflenet_v2_x1_0(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(shufflenet_model)))
|
||||||
|
def ShufflenetVisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
squeezenet_model = models.squeezenet1_0(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(squeezenet_model)))
|
||||||
|
def SqueezenetVisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
vgg16_model = models.vgg16(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(vgg16_model)))
|
||||||
|
def Vgg16VisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
wide_resnet_model = models.wide_resnet50_2(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(wide_resnet_model)))
|
||||||
|
def Wide_ResnetVisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
efficientnet_model = models.efficientnet_b0(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(efficientnet_model)))
|
||||||
|
def EfficientnetVisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
mobilenet_v2_model = models.mobilenet_v2(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(mobilenet_v2_model)))
|
||||||
|
def Mobilenet_v2VisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
mobilenet_v3_large_model = models.mobilenet_v3_large(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(mobilenet_v3_large_model)))
|
||||||
|
def Mobilenet_v3_largeVisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
resnet50_model = models.resnet50(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(resnet50_model)))
|
||||||
|
def Resnet50VisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
densenet121_model = models.densenet121(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(densenet121_model)))
|
||||||
|
def Densenet121VisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
timm_regnet_model = models.regnet_y_1_6gf(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(timm_regnet_model)))
|
||||||
|
def Timm_RegnetVisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
pytorch_unet_model = torch.hub.load(
|
||||||
|
"mateuszbuda/brain-segmentation-pytorch",
|
||||||
|
"unet",
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=1,
|
||||||
|
init_features=32,
|
||||||
|
pretrained=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(pytorch_unet_model)))
|
||||||
|
def PytorchUnetVisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
resnest_model = torch.hub.load('zhanghang1989/ResNeSt',
|
||||||
|
'resnest50',
|
||||||
|
pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(resnest_model)))
|
||||||
|
def ResnestVisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
timm_vit_model = models.vit_b_16(pretrained=True)
|
||||||
|
|
||||||
|
input = torch.randn(1, 3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
|
||||||
|
VisionModule(timm_vit_model)))
|
||||||
|
def ViTVisionModel_basic(module, tu: TestUtils):
|
||||||
|
module.forward(input)
|
Loading…
Reference in New Issue