diff --git a/build_tools/torchscript_e2e_heavydep_tests/main.py b/build_tools/torchscript_e2e_heavydep_tests/main.py
index 05341b92a..c23f28e92 100644
--- a/build_tools/torchscript_e2e_heavydep_tests/main.py
+++ b/build_tools/torchscript_e2e_heavydep_tests/main.py
@@ -14,7 +14,7 @@ from torch_mlir_e2e_test.torchscript.framework import SerializableTest, generate
 from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations
 
 from . import basic_mt
-from . import bert_seq_classification
+from . import minilm_sequence_classification
 
 
 def _get_argparse():
diff --git a/build_tools/torchscript_e2e_heavydep_tests/bert_seq_classification.py b/build_tools/torchscript_e2e_heavydep_tests/minilm_sequence_classification.py
similarity index 74%
rename from build_tools/torchscript_e2e_heavydep_tests/bert_seq_classification.py
rename to build_tools/torchscript_e2e_heavydep_tests/minilm_sequence_classification.py
index 41e7c2bdf..46e6f69a0 100644
--- a/build_tools/torchscript_e2e_heavydep_tests/bert_seq_classification.py
+++ b/build_tools/torchscript_e2e_heavydep_tests/minilm_sequence_classification.py
@@ -3,30 +3,29 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
-# Basic BertForSequenceClassification program to classify the input sentence.
+# A pretrained model to classify the input sentence.
+# https://huggingface.co/philschmid/MiniLM-L6-H384-uncased-sst2
 import torch
-from transformers import BertForSequenceClassification, BertTokenizer
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 from torch_mlir_e2e_test.torchscript.framework import TestUtils
 from torch_mlir_e2e_test.torchscript.registry import register_test_case
 from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 
 torch.manual_seed(0)
 
-CLS = "[CLS]"
-SEP = "[SEP]"
-tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+tokenizer = AutoTokenizer.from_pretrained("philschmid/MiniLM-L6-H384-uncased-sst2")
 
 
 def _prepare_sentence_tokens(sentence: str):
     return torch.tensor([tokenizer.encode(sentence)])
 
 
-class BasicBertSequenceClassification(torch.nn.Module):
+class MiniLMSequenceClassification(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.model = BertForSequenceClassification.from_pretrained(
-            "bert-base-uncased",  # Use the 12-layer BERT model, with an uncased vocab.
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            "philschmid/MiniLM-L6-H384-uncased-sst2",  # The pretrained model.
            num_labels=
            2,  # The number of output labels--2 for binary classification.
            output_attentions=
@@ -53,7 +52,7 @@ test_input = _prepare_sentence_tokens("this project is very interesting")
 
 
 def getTracedRecursiveScriptModule():
-    traced_module = torch.jit.trace_module(BasicBertSequenceClassification(),
+    traced_module = torch.jit.trace_module(MiniLMSequenceClassification(),
                                            trace_input)
     script_module = traced_module._actual_script_module
     export(script_module.forward)
@@ -66,5 +65,5 @@ def getTracedRecursiveScriptModule():
 
 
 @register_test_case(module_factory=lambda: getTracedRecursiveScriptModule())
-def BasicBertSequenceClassification_basic(module, tu: TestUtils):
+def MiniLMSequenceClassification_basic(module, tu: TestUtils):
     module.forward(test_input)
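
For anyone who wants to exercise the new test's model outside the e2e harness, below is a minimal standalone sketch (not part of the diff) that loads the same checkpoint and classifies a sentence. It assumes torch and transformers are installed and that the Hugging Face hub is reachable; the torchscript=True flag is an assumption added here so the model returns plain tuples suitable for jit tracing, and is not something the test itself passes.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL = "philschmid/MiniLM-L6-H384-uncased-sst2"

# AutoTokenizer/AutoModelForSequenceClassification dispatch on the checkpoint's
# config, so the same two calls work for BERT, MiniLM, or any other encoder.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL,
    num_labels=2,             # SST-2 is binary: negative (0) / positive (1).
    output_attentions=False,
    output_hidden_states=False,
    torchscript=True,         # Assumption: return tuples, which tracing needs.
)
model.eval()

# Mirrors _prepare_sentence_tokens in the test. encode() inserts [CLS]/[SEP]
# itself, which is why the hand-written CLS/SEP constants could be deleted.
tokens = torch.tensor([tokenizer.encode("this project is very interesting")])
with torch.no_grad():
    logits = model(tokens)[0]
print("predicted label:", int(logits.argmax(dim=-1)))

The switch to the Auto* classes is what makes the rename cheap: the model name is now just a string, and the checkpoint's own config decides the architecture, so no Bert-specific imports survive in the test.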