E2e for MiniLM-L6-H384-uncased-sst2

Replace the original BertSequenceClassification e2e test with this new one.
The set of ops that needs to be supported is identical.
pull/351/head
Yi Zhang 2021-10-04 15:18:07 -04:00
parent c3e0a1e1dc
commit fadd76e9b8
2 changed files with 10 additions and 11 deletions
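To back up the claim that the op set is unchanged, one can trace both checkpoints and diff the op kinds appearing in their TorchScript graphs. A minimal sketch, not part of this commit, assuming both checkpoints download and trace cleanly (torchscript=True makes the models return tuples, which torch.jit.trace accepts):

    import torch
    from transformers import AutoModelForSequenceClassification

    def traced_op_kinds(name):
        model = AutoModelForSequenceClassification.from_pretrained(
            name, torchscript=True).eval()
        input_ids = torch.randint(0, 1000, (1, 16))
        traced = torch.jit.trace(model, input_ids)
        # Collect every op kind (e.g. "aten::matmul") in the traced graph.
        return {node.kind() for node in traced.graph.nodes()}

    old_ops = traced_op_kinds("bert-base-uncased")
    new_ops = traced_op_kinds("philschmid/MiniLM-L6-H384-uncased-sst2")
    print(new_ops - old_ops)  # expected: empty set, i.e. no new ops to support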


@@ -14,7 +14,7 @@ from torch_mlir_e2e_test.torchscript.framework import SerializableTest, generate
 from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations
 from . import basic_mt
-from . import bert_seq_classification
+from . import minilm_sequence_classification
 def _get_argparse():
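Swapping this import is all that is needed because each test module registers its tests as a side effect of being imported. A hypothetical sketch of that decorator pattern (illustrative names only, not the actual torch_mlir_e2e_test internals):

    from typing import Callable, Dict

    # Hypothetical module-level registry: test name -> module factory.
    TEST_REGISTRY: Dict[str, Callable] = {}

    def register_test_case(module_factory: Callable):
        def decorator(test_func: Callable):
            # Registration happens at import time, when the decorator runs.
            TEST_REGISTRY[test_func.__name__] = module_factory
            return test_func
        return decorator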


@@ -3,30 +3,29 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
-# Basic BertForSequenceClassification program to classify the input sentence.
+# A pretrained model to classify the input sentence.
+# https://huggingface.co/philschmid/MiniLM-L6-H384-uncased-sst2
 import torch
-from transformers import BertForSequenceClassification, BertTokenizer
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from torch_mlir_e2e_test.torchscript.framework import TestUtils
 from torch_mlir_e2e_test.torchscript.registry import register_test_case
 from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 torch.manual_seed(0)
-CLS = "[CLS]"
-SEP = "[SEP]"
-tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+tokenizer = AutoTokenizer.from_pretrained("philschmid/MiniLM-L6-H384-uncased-sst2")
 def _prepare_sentence_tokens(sentence: str):
     return torch.tensor([tokenizer.encode(sentence)])
-class BasicBertSequenceClassification(torch.nn.Module):
+class MiniLMSequenceClassification(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.model = BertForSequenceClassification.from_pretrained(
-            "bert-base-uncased",  # Use the 12-layer BERT model, with an uncased vocab.
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            "philschmid/MiniLM-L6-H384-uncased-sst2",  # The pretrained model.
             num_labels=
             2,  # The number of output labels--2 for binary classification.
             output_attentions=
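The deleted CLS and SEP constants are redundant because tokenizer.encode() inserts the special tokens itself. A quick check, assuming the checkpoint downloads cleanly:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("philschmid/MiniLM-L6-H384-uncased-sst2")
    ids = tokenizer.encode("this project is very interesting")
    # decode() shows the [CLS] ... [SEP] wrapping that encode() added automatically.
    print(tokenizer.decode(ids))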
@@ -53,7 +52,7 @@ test_input = _prepare_sentence_tokens("this project is very interesting")
 def getTracedRecursiveScriptModule():
-    traced_module = torch.jit.trace_module(BasicBertSequenceClassification(),
+    traced_module = torch.jit.trace_module(MiniLMSequenceClassification(),
                                            trace_input)
     script_module = traced_module._actual_script_module
     export(script_module.forward)
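torch.jit.trace_module takes a dict mapping method names to example inputs, so trace_input here is presumably of the form {"forward": ...}. A standalone sketch of the same call on a toy module:

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x + 1

    # Each listed method is traced with its example input into a ScriptModule.
    traced = torch.jit.trace_module(Toy(), {"forward": torch.zeros(1, 4)})
    print(traced(torch.ones(1, 4)))  # tensor of 2s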
@@ -66,5 +65,5 @@ def getTracedRecursiveScriptModule():
 @register_test_case(module_factory=lambda: getTracedRecursiveScriptModule())
-def BasicBertSequenceClassification_basic(module, tu: TestUtils):
+def MiniLMSequenceClassification_basic(module, tu: TestUtils):
     module.forward(test_input)