diff --git a/build_tools/torchscript_e2e_heavydep_tests/hf_sequence_classification.py b/build_tools/torchscript_e2e_heavydep_tests/hf_sequence_classification.py
index 99f5eedfc..c6cf2d310 100644
--- a/build_tools/torchscript_e2e_heavydep_tests/hf_sequence_classification.py
+++ b/build_tools/torchscript_e2e_heavydep_tests/hf_sequence_classification.py
@@ -3,9 +3,6 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
-# A pretrained model to classify the input sentence.
-# https://huggingface.co/philschmid/MiniLM-L6-H384-uncased-sst2
-
 import torch
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
@@ -15,6 +12,7 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 
 torch.manual_seed(0)
 
+
 def prepare_sentence_tokens(hf_model: str, sentence: str):
     tokenizer = AutoTokenizer.from_pretrained(hf_model)
     return torch.tensor([tokenizer.encode(sentence)])
@@ -24,17 +22,16 @@ def getTracedRecursiveScriptModule(module, trace_input):
     traced_module = torch.jit.trace_module(module, trace_input)
     script_module = traced_module._actual_script_module
     export(script_module.forward)
-    annotate_args_decorator = annotate_args(
-        [
-            None,
-            ([-1, -1], torch.int64, True),
-        ]
-    )
+    annotate_args_decorator = annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
     annotate_args_decorator(script_module.forward)
     return script_module
 
 
 class HfSequenceClassification(torch.nn.Module):
+
     def __init__(self, model_name: str):
         super().__init__()
         self.model = AutoModelForSequenceClassification.from_pretrained(
@@ -50,12 +47,10 @@ class HfSequenceClassification(torch.nn.Module):
         self.model.eval()
 
     @export
-    @annotate_args(
-        [
-            None,
-            ([-1, -1], torch.int64, True),
-        ]
-    )
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
     def forward(self, tokens):
         return self.model.forward(tokens)[0]
 
@@ -65,37 +60,122 @@ class HfSequenceClassification(torch.nn.Module):
 hf_minilm_model = "philschmid/MiniLM-L6-H384-uncased-sst2"
 
 trace_input = {
-    "forward": prepare_sentence_tokens(hf_minilm_model, "how do you like the project")
+    "forward":
+        prepare_sentence_tokens(hf_minilm_model, "how do you like the project")
 }
-test_input = prepare_sentence_tokens(
-    hf_minilm_model, "this project is very interesting")
+test_input = prepare_sentence_tokens(hf_minilm_model,
+                                     "this project is very interesting")
 
 
-@register_test_case(
-    module_factory=lambda: getTracedRecursiveScriptModule(
-        HfSequenceClassification(hf_minilm_model), trace_input
-    )
-)
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    HfSequenceClassification(hf_minilm_model), trace_input))
 def MiniLMSequenceClassification_basic(module, tu: TestUtils):
     module.forward(test_input)
 
+
 # ==============================================================================
 
 hf_albert_model = "albert-base-v2"
 
 trace_input = {
-    "forward": prepare_sentence_tokens(hf_albert_model, "how do you like the project")
+    "forward":
+        prepare_sentence_tokens(hf_albert_model, "how do you like the project")
 }
-test_input = prepare_sentence_tokens(
-    hf_albert_model, "this project is very interesting")
+test_input = prepare_sentence_tokens(hf_albert_model,
+                                     "this project is very interesting")
 
 
-@register_test_case(
-    module_factory=lambda: getTracedRecursiveScriptModule(
-        HfSequenceClassification(hf_albert_model), trace_input
-    )
-)
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    HfSequenceClassification(hf_albert_model), trace_input))
 def AlbertSequenceClassification_basic(module, tu: TestUtils):
     module.forward(test_input)
 
+
 # ==============================================================================
+
+hf_bart_model = "facebook/bart-base"
+
+trace_input = {
+    "forward":
+        prepare_sentence_tokens(hf_bart_model, "how do you like the project")
+}
+test_input = prepare_sentence_tokens(hf_bart_model,
+                                     "this project is very interesting")
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    HfSequenceClassification(hf_bart_model), trace_input))
+def BartSequenceClassification_basic(module, tu: TestUtils):
+    module.forward(test_input)
+
+
+# ==============================================================================
+
+hf_bert_model = "bert-base-uncased"
+
+trace_input = {
+    "forward":
+        prepare_sentence_tokens(hf_bert_model, "how do you like the project")
+}
+test_input = prepare_sentence_tokens(hf_bert_model,
+                                     "this project is very interesting")
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    HfSequenceClassification(hf_bert_model), trace_input))
+def BertSequenceClassification_basic(module, tu: TestUtils):
+    module.forward(test_input)
+
+
+# ==============================================================================
+
+hf_bigbird_model = "google/bigbird-roberta-base"
+
+trace_input = {
+    "forward":
+        prepare_sentence_tokens(hf_bigbird_model, "how do you like the project")
+}
+test_input = prepare_sentence_tokens(hf_bigbird_model,
+                                     "this project is very interesting")
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    HfSequenceClassification(hf_bigbird_model), trace_input))
+def BigBirdSequenceClassification_basic(module, tu: TestUtils):
+    module.forward(test_input)
+
+
+# ==============================================================================
+
+hf_distilbert_model = "distilbert-base-uncased"
+
+trace_input = {
+    "forward":
+        prepare_sentence_tokens(hf_distilbert_model, "how do you like the project")
+}
+test_input = prepare_sentence_tokens(hf_distilbert_model,
+                                     "this project is very interesting")
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    HfSequenceClassification(hf_distilbert_model), trace_input))
+def DistilBertSequenceClassification_basic(module, tu: TestUtils):
+    module.forward(test_input)
+
+
+# ==============================================================================
+
+hf_gpt2_model = "gpt2"
+
+trace_input = {
+    "forward":
+        prepare_sentence_tokens(hf_gpt2_model, "how do you like the project")
+}
+test_input = prepare_sentence_tokens(hf_gpt2_model,
+                                     "this project is very interesting")
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    HfSequenceClassification(hf_gpt2_model), trace_input))
+def GPT2SequenceClassification_basic(module, tu: TestUtils):
+    module.forward(test_input)
diff --git a/build_tools/torchscript_e2e_heavydep_tests/main.py b/build_tools/torchscript_e2e_heavydep_tests/main.py
index 35df5b0c3..2f4dbdd67 100644
--- a/build_tools/torchscript_e2e_heavydep_tests/main.py
+++ b/build_tools/torchscript_e2e_heavydep_tests/main.py
@@ -10,6 +10,7 @@ from torch_mlir_e2e_test.torchscript.serialization import serialize_all_tests_to
 from . import hf_sequence_classification
 from . import fully_connected_backward
 from . import bert_functorch
+from . import vision_models
 
 
 def _get_argparse():
diff --git a/build_tools/torchscript_e2e_heavydep_tests/vision_models.py b/build_tools/torchscript_e2e_heavydep_tests/vision_models.py
new file mode 100644
index 000000000..54ad129af
--- /dev/null
+++ b/build_tools/torchscript_e2e_heavydep_tests/vision_models.py
@@ -0,0 +1,265 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+import torchvision.models as models
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+torch.manual_seed(0)
+
+
+def getTracedRecursiveScriptModule(module):
+    script_module = torch.jit.script(module)
+    export(script_module.forward)
+    annotate_args_decorator = annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    annotate_args_decorator(script_module.forward)
+    return script_module
+
+
+class VisionModule(torch.nn.Module):
+
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.train(False)
+
+    def forward(self, input):
+        return self.model.forward(input)
+
+
+# ==============================================================================
+
+resnet18_model = models.resnet18(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(resnet18_model)))
+def Resnet18VisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+resnext50_32x4d_model = models.resnext50_32x4d(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(resnext50_32x4d_model)))
+def Resnext50_32x4dVisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+mnasnet1_0_model = models.mnasnet1_0(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(mnasnet1_0_model)))
+def Mnasnet1_0VisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+alexnet_model = models.alexnet(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(alexnet_model)))
+def AlexnetVisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+shufflenet_model = models.shufflenet_v2_x1_0(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(shufflenet_model)))
+def ShufflenetVisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+squeezenet_model = models.squeezenet1_0(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(squeezenet_model)))
+def SqueezenetVisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+vgg16_model = models.vgg16(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(vgg16_model)))
+def Vgg16VisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+wide_resnet_model = models.wide_resnet50_2(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(wide_resnet_model)))
+def Wide_ResnetVisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+efficientnet_model = models.efficientnet_b0(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(efficientnet_model)))
+def EfficientnetVisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+mobilenet_v2_model = models.mobilenet_v2(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(mobilenet_v2_model)))
+def Mobilenet_v2VisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+mobilenet_v3_large_model = models.mobilenet_v3_large(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(mobilenet_v3_large_model)))
+def Mobilenet_v3_largeVisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+resnet50_model = models.resnet50(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(resnet50_model)))
+def Resnet50VisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+densenet121_model = models.densenet121(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(densenet121_model)))
+def Densenet121VisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+timm_regnet_model = models.regnet_y_1_6gf(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(timm_regnet_model)))
+def Timm_RegnetVisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+pytorch_unet_model = torch.hub.load(
+    "mateuszbuda/brain-segmentation-pytorch",
+    "unet",
+    in_channels=3,
+    out_channels=1,
+    init_features=32,
+    pretrained=True,
+)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(pytorch_unet_model)))
+def PytorchUnetVisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+resnest_model = torch.hub.load('zhanghang1989/ResNeSt',
+                               'resnest50',
+                               pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(resnest_model)))
+def ResnestVisionModel_basic(module, tu: TestUtils):
+    module.forward(input)
+
+
+# ==============================================================================
+
+timm_vit_model = models.vit_b_16(pretrained=True)
+
+input = torch.randn(1, 3, 224, 224)
+
+
+@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule(
+    VisionModule(timm_vit_model)))
+def ViTVisionModel_basic(module, tu: TestUtils):
+    module.forward(input)