[MLIR][TORCH] Add heavydep tests for torch benchmarks

This commit adds e2e heavydep tests for the torch benchmarks.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/798/head snapshot-20220426.415
Vivek Khandelwal 2022-04-21 18:19:24 +05:30
parent 2877a37ac6
commit 4635d36efb
3 changed files with 377 additions and 31 deletions

View File

@ -3,9 +3,6 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# A pretrained model to classify the input sentence.
# https://huggingface.co/philschmid/MiniLM-L6-H384-uncased-sst2
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
@ -15,6 +12,7 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
torch.manual_seed(0)
def prepare_sentence_tokens(hf_model: str, sentence: str):
tokenizer = AutoTokenizer.from_pretrained(hf_model)
return torch.tensor([tokenizer.encode(sentence)])
@ -24,17 +22,16 @@ def getTracedRecursiveScriptModule(module, trace_input):
traced_module = torch.jit.trace_module(module, trace_input)
script_module = traced_module._actual_script_module
export(script_module.forward)
annotate_args_decorator = annotate_args(
[
None,
([-1, -1], torch.int64, True),
]
)
annotate_args_decorator = annotate_args([
None,
([-1, -1], torch.int64, True),
])
annotate_args_decorator(script_module.forward)
return script_module
class HfSequenceClassification(torch.nn.Module):
def __init__(self, model_name: str):
super().__init__()
self.model = AutoModelForSequenceClassification.from_pretrained(
@ -50,12 +47,10 @@ class HfSequenceClassification(torch.nn.Module):
self.model.eval()
@export
@annotate_args(
[
None,
([-1, -1], torch.int64, True),
]
)
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, tokens):
return self.model.forward(tokens)[0]
@ -65,37 +60,122 @@ class HfSequenceClassification(torch.nn.Module):
hf_minilm_model = "philschmid/MiniLM-L6-H384-uncased-sst2"
trace_input = {
"forward": prepare_sentence_tokens(hf_minilm_model, "how do you like the project")
"forward":
prepare_sentence_tokens(hf_minilm_model, "how do you like the project")
}
test_input = prepare_sentence_tokens(
hf_minilm_model, "this project is very interesting")
test_input = prepare_sentence_tokens(hf_minilm_model,
"this project is very interesting")
@register_test_case(
module_factory=lambda: getTracedRecursiveScriptModule(
HfSequenceClassification(hf_minilm_model), trace_input
)
)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
HfSequenceClassification(hf_minilm_model), trace_input))
def MiniLMSequenceClassification_basic(module, tu: TestUtils):
module.forward(test_input)
# ==============================================================================
hf_albert_model = "albert-base-v2"
trace_input = {
"forward": prepare_sentence_tokens(hf_albert_model, "how do you like the project")
"forward":
prepare_sentence_tokens(hf_albert_model, "how do you like the project")
}
test_input = prepare_sentence_tokens(
hf_albert_model, "this project is very interesting")
test_input = prepare_sentence_tokens(hf_albert_model,
"this project is very interesting")
@register_test_case(
module_factory=lambda: getTracedRecursiveScriptModule(
HfSequenceClassification(hf_albert_model), trace_input
)
)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
HfSequenceClassification(hf_albert_model), trace_input))
def AlbertSequenceClassification_basic(module, tu: TestUtils):
module.forward(test_input)
# ==============================================================================
hf_bart_model = "facebook/bart-base"
trace_input = {
"forward":
prepare_sentence_tokens(hf_bart_model, "how do you like the project")
}
test_input = prepare_sentence_tokens(hf_bart_model,
"this project is very interesting")
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
HfSequenceClassification(hf_bart_model), trace_input))
def BartSequenceClassification_basic(module, tu: TestUtils):
module.forward(test_input)
# ==============================================================================
hf_bert_model = "bert-base-uncased"
trace_input = {
"forward":
prepare_sentence_tokens(hf_bert_model, "how do you like the project")
}
test_input = prepare_sentence_tokens(hf_bert_model,
"this project is very interesting")
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
HfSequenceClassification(hf_bert_model), trace_input))
def BertSequenceClassification_basic(module, tu: TestUtils):
module.forward(test_input)
# ==============================================================================
hf_bigbird_model = "google/bigbird-roberta-base"
trace_input = {
"forward":
prepare_sentence_tokens(hf_bigbird_model, "how do you like the project")
}
test_input = prepare_sentence_tokens(hf_bigbird_model,
"this project is very interesting")
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
HfSequenceClassification(hf_bigbird_model), trace_input))
def BigBirdSequenceClassification_basic(module, tu: TestUtils):
module.forward(test_input)
# ==============================================================================
hf_distilbert_model = "distilbert-base-uncased"
trace_input = {
"forward":
prepare_sentence_tokens(hf_distilbert_model, "how do you like the project")
}
test_input = prepare_sentence_tokens(hf_distilbert_model,
"this project is very interesting")
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
HfSequenceClassification(hf_distilbert_model), trace_input))
def DistilBertSequenceClassification_basic(module, tu: TestUtils):
module.forward(test_input)
# ==============================================================================
hf_gpt2_model = "gpt2"
trace_input = {
"forward":
prepare_sentence_tokens(hf_gpt2_model, "how do you like the project")
}
test_input = prepare_sentence_tokens(hf_gpt2_model,
"this project is very interesting")
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
HfSequenceClassification(hf_gpt2_model), trace_input))
def GPT2SequenceClassification_basic(module, tu: TestUtils):
module.forward(test_input)

View File

@ -10,6 +10,7 @@ from torch_mlir_e2e_test.torchscript.serialization import serialize_all_tests_to
from . import hf_sequence_classification
from . import fully_connected_backward
from . import bert_functorch
from . import vision_models
def _get_argparse():

View File

@ -0,0 +1,265 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
import torchvision.models as models
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
torch.manual_seed(0)
def getTracedRecursiveScriptModule(module):
script_module = torch.jit.script(module)
export(script_module.forward)
annotate_args_decorator = annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
annotate_args_decorator(script_module.forward)
return script_module
class VisionModule(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.train(False)
def forward(self, input):
return self.model.forward(input)
# ==============================================================================
resnet18_model = models.resnet18(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(resnet18_model)))
def Resnet18VisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
resnext50_32x4d_model = models.resnext50_32x4d(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(resnext50_32x4d_model)))
def Resnext50_32x4dVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
mnasnet1_0_model = models.mnasnet1_0(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(mnasnet1_0_model)))
def Mnasnet1_0VisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
alexnet_model = models.alexnet(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(alexnet_model)))
def AlexnetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
shufflenet_model = models.shufflenet_v2_x1_0(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(shufflenet_model)))
def ShufflenetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
squeezenet_model = models.squeezenet1_0(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(squeezenet_model)))
def SqueezenetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
vgg16_model = models.vgg16(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(vgg16_model)))
def Vgg16VisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
wide_resnet_model = models.wide_resnet50_2(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(wide_resnet_model)))
def Wide_ResnetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
efficientnet_model = models.efficientnet_b0(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(efficientnet_model)))
def EfficientnetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
mobilenet_v2_model = models.mobilenet_v2(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(mobilenet_v2_model)))
def Mobilenet_v2VisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
mobilenet_v3_large_model = models.mobilenet_v3_large(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(mobilenet_v3_large_model)))
def Mobilenet_v3_largeVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
resnet50_model = models.resnet50(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(resnet50_model)))
def Resnet50VisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
densenet121_model = models.densenet121(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(densenet121_model)))
def Densenet121VisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
timm_regnet_model = models.regnet_y_1_6gf(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(timm_regnet_model)))
def Timm_RegnetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
pytorch_unet_model = torch.hub.load(
"mateuszbuda/brain-segmentation-pytorch",
"unet",
in_channels=3,
out_channels=1,
init_features=32,
pretrained=True,
)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(pytorch_unet_model)))
def PytorchUnetVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
resnest_model = torch.hub.load('zhanghang1989/ResNeSt',
'resnest50',
pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(resnest_model)))
def ResnestVisionModel_basic(module, tu: TestUtils):
module.forward(input)
# ==============================================================================
timm_vit_model = models.vit_b_16(pretrained=True)
input = torch.randn(1, 3, 224, 224)
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
VisionModule(timm_vit_model)))
def ViTVisionModel_basic(module, tu: TestUtils):
module.forward(input)