mirror of https://github.com/llvm/torch-mlir
E2e for MiniLM-L6-H384-uncased-sst2
Replace the original BertSequenceClassification with this new one. The ops needed to support are identical.pull/351/head
parent
c3e0a1e1dc
commit
fadd76e9b8
|
@ -14,7 +14,7 @@ from torch_mlir_e2e_test.torchscript.framework import SerializableTest, generate
|
|||
from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations
|
||||
|
||||
from . import basic_mt
|
||||
from . import bert_seq_classification
|
||||
from . import minilm_sequence_classification
|
||||
|
||||
|
||||
def _get_argparse():
|
||||
|
|
|
@ -3,30 +3,29 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# Basic BertForSequenceClassification program to classify the input sentence.
|
||||
# A pretrained model to classify the input sentence.
|
||||
# https://huggingface.co/philschmid/MiniLM-L6-H384-uncased-sst2
|
||||
|
||||
import torch
|
||||
from transformers import BertForSequenceClassification, BertTokenizer
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
torch.manual_seed(0)
|
||||
CLS = "[CLS]"
|
||||
SEP = "[SEP]"
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
tokenizer = AutoTokenizer.from_pretrained("philschmid/MiniLM-L6-H384-uncased-sst2")
|
||||
|
||||
|
||||
def _prepare_sentence_tokens(sentence: str):
|
||||
return torch.tensor([tokenizer.encode(sentence)])
|
||||
|
||||
|
||||
class BasicBertSequenceClassification(torch.nn.Module):
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = BertForSequenceClassification.from_pretrained(
|
||||
"bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"philschmid/MiniLM-L6-H384-uncased-sst2", # The pretrained model.
|
||||
num_labels=
|
||||
2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=
|
||||
|
@ -53,7 +52,7 @@ test_input = _prepare_sentence_tokens("this project is very interesting")
|
|||
|
||||
|
||||
def getTracedRecursiveScriptModule():
|
||||
traced_module = torch.jit.trace_module(BasicBertSequenceClassification(),
|
||||
traced_module = torch.jit.trace_module(MiniLMSequenceClassification(),
|
||||
trace_input)
|
||||
script_module = traced_module._actual_script_module
|
||||
export(script_module.forward)
|
||||
|
@ -66,5 +65,5 @@ def getTracedRecursiveScriptModule():
|
|||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule())
|
||||
def BasicBertSequenceClassification_basic(module, tu: TestUtils):
|
||||
def MiniLMSequenceClassification_basic(module, tu: TestUtils):
|
||||
module.forward(test_input)
|
Loading…
Reference in New Issue