mirror of https://github.com/llvm/torch-mlir
E2E test for MiniLM-L6-H384-uncased-sst2

Replace the original BertSequenceClassification test with this new one. The ops needed to support it are identical.

Branch: pull/351/head
Parent: c3e0a1e1dc
Commit: fadd76e9b8
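For reference, here is a minimal standalone sketch (not part of the commit) of the inference flow the new test exercises, assuming `torch` and `transformers` are installed; the sentence and variable names are illustrative:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Same checkpoint the new test loads; SST-2 is binary sentiment
# classification, hence num_labels=2 in the diff below.
tokenizer = AutoTokenizer.from_pretrained("philschmid/MiniLM-L6-H384-uncased-sst2")
model = AutoModelForSequenceClassification.from_pretrained(
    "philschmid/MiniLM-L6-H384-uncased-sst2", num_labels=2)
model.eval()

# Mirrors _prepare_sentence_tokens: tokenizer.encode() inserts the
# [CLS]/[SEP] special tokens itself, which is why the diff below drops
# the hand-maintained CLS/SEP constants.
tokens = torch.tensor([tokenizer.encode("this project is very interesting")])

with torch.no_grad():
    logits = model(tokens).logits  # shape (1, 2), one logit per label
print("predicted label:", int(logits.argmax(dim=-1)))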
@@ -14,7 +14,7 @@ from torch_mlir_e2e_test.torchscript.framework import SerializableTest, generate
 from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations
 
 from . import basic_mt
-from . import bert_seq_classification
+from . import minilm_sequence_classification
 
 
 def _get_argparse():
@@ -3,30 +3,29 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
-# Basic BertForSequenceClassification program to classify the input sentence.
+# A pretrained model to classify the input sentence.
+# https://huggingface.co/philschmid/MiniLM-L6-H384-uncased-sst2
 
 import torch
-from transformers import BertForSequenceClassification, BertTokenizer
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 from torch_mlir_e2e_test.torchscript.framework import TestUtils
 from torch_mlir_e2e_test.torchscript.registry import register_test_case
 from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 
 torch.manual_seed(0)
-CLS = "[CLS]"
-SEP = "[SEP]"
-tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+tokenizer = AutoTokenizer.from_pretrained("philschmid/MiniLM-L6-H384-uncased-sst2")
 
 
 def _prepare_sentence_tokens(sentence: str):
     return torch.tensor([tokenizer.encode(sentence)])
 
 
-class BasicBertSequenceClassification(torch.nn.Module):
+class MiniLMSequenceClassification(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.model = BertForSequenceClassification.from_pretrained(
-            "bert-base-uncased",  # Use the 12-layer BERT model, with an uncased vocab.
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            "philschmid/MiniLM-L6-H384-uncased-sst2",  # The pretrained model.
            num_labels=
            2,  # The number of output labels--2 for binary classification.
            output_attentions=
@@ -53,7 +52,7 @@ test_input = _prepare_sentence_tokens("this project is very interesting")
 
 
 def getTracedRecursiveScriptModule():
-    traced_module = torch.jit.trace_module(BasicBertSequenceClassification(),
+    traced_module = torch.jit.trace_module(MiniLMSequenceClassification(),
                                            trace_input)
     script_module = traced_module._actual_script_module
     export(script_module.forward)
@@ -66,5 +65,5 @@ def getTracedRecursiveScriptModule():
 
 
 @register_test_case(module_factory=lambda: getTracedRecursiveScriptModule())
-def BasicBertSequenceClassification_basic(module, tu: TestUtils):
+def MiniLMSequenceClassification_basic(module, tu: TestUtils):
     module.forward(test_input)
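The unchanged harness logic above is worth spelling out: getTracedRecursiveScriptModule traces the model and then unwraps the underlying script module. Below is a minimal sketch of that mechanism on a toy module, assuming only stock PyTorch; Toy and its example input are illustrative names, while _actual_script_module is the same internal attribute the test relies on:

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 1

# trace_module takes a dict mapping method names to example inputs,
# the same call shape the test uses with trace_input.
traced = torch.jit.trace_module(Toy(), {"forward": torch.ones(1, 4)})

# The test unwraps the traced wrapper to reach the underlying
# RecursiveScriptModule, which export() then marks for the test harness.
script_module = traced._actual_script_module

# Sanity check: the traced module agrees with eager execution.
x = torch.ones(1, 4)
assert torch.equal(script_module(x), Toy()(x))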