E2e for MiniLM-L6-H384-uncased-sst2

Replace the original BertSequenceClassification with this new one.
The ops needed to support are identical.
pull/351/head
Yi Zhang 2021-10-04 15:18:07 -04:00
parent c3e0a1e1dc
commit fadd76e9b8
2 changed files with 10 additions and 11 deletions

View File

@ -14,7 +14,7 @@ from torch_mlir_e2e_test.torchscript.framework import SerializableTest, generate
from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations
from . import basic_mt
from . import bert_seq_classification
from . import minilm_sequence_classification
def _get_argparse():

View File

@ -3,30 +3,29 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# Basic BertForSequenceClassification program to classify the input sentence.
# A pretrained model to classify the input sentence.
# https://huggingface.co/philschmid/MiniLM-L6-H384-uncased-sst2
import torch
from transformers import BertForSequenceClassification, BertTokenizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
torch.manual_seed(0)
CLS = "[CLS]"
SEP = "[SEP]"
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("philschmid/MiniLM-L6-H384-uncased-sst2")
def _prepare_sentence_tokens(sentence: str):
return torch.tensor([tokenizer.encode(sentence)])
class BasicBertSequenceClassification(torch.nn.Module):
class MiniLMSequenceClassification(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = BertForSequenceClassification.from_pretrained(
"bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
self.model = AutoModelForSequenceClassification.from_pretrained(
"philschmid/MiniLM-L6-H384-uncased-sst2", # The pretrained model.
num_labels=
2, # The number of output labels--2 for binary classification.
output_attentions=
@ -53,7 +52,7 @@ test_input = _prepare_sentence_tokens("this project is very interesting")
def getTracedRecursiveScriptModule():
traced_module = torch.jit.trace_module(BasicBertSequenceClassification(),
traced_module = torch.jit.trace_module(MiniLMSequenceClassification(),
trace_input)
script_module = traced_module._actual_script_module
export(script_module.forward)
@ -66,5 +65,5 @@ def getTracedRecursiveScriptModule():
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule())
def BasicBertSequenceClassification_basic(module, tu: TestUtils):
def MiniLMSequenceClassification_basic(module, tu: TestUtils):
module.forward(test_input)