mirror of https://github.com/llvm/torch-mlir
Add BertSequenceClassification model to e2e
Use torch tracing to get the module because the original model is not TorchScriptable out of box.pull/345/head
parent
649d6e4f28
commit
89225b0cd8
|
@ -0,0 +1,70 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# Basic BertForSequenceClassification program to classify the input sentence.
|
||||
|
||||
import torch
|
||||
from transformers import BertForSequenceClassification, BertTokenizer
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
torch.manual_seed(0)
|
||||
CLS = "[CLS]"
|
||||
SEP = "[SEP]"
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
|
||||
def _prepare_sentence_tokens(sentence: str):
|
||||
return torch.tensor([tokenizer.encode(sentence)])
|
||||
|
||||
|
||||
class BasicBertSequenceClassification(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = BertForSequenceClassification.from_pretrained(
|
||||
"bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
|
||||
num_labels=
|
||||
2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=
|
||||
False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=
|
||||
False, # Whether the model returns all hidden-states.
|
||||
torchscript=True)
|
||||
self.model.eval()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
trace_input = {
|
||||
'forward': _prepare_sentence_tokens("how do you like the project")
|
||||
}
|
||||
|
||||
test_input = _prepare_sentence_tokens("this project is very interesting")
|
||||
|
||||
|
||||
def getTracedRecursiveScriptModule():
|
||||
traced_module = torch.jit.trace_module(BasicBertSequenceClassification(),
|
||||
trace_input)
|
||||
script_module = traced_module._actual_script_module
|
||||
export(script_module.forward)
|
||||
annotate_args_decorator = annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
annotate_args_decorator(script_module.forward)
|
||||
return script_module
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule())
|
||||
def BasicBertSequenceClassification_basic(module, tu: TestUtils):
|
||||
module.forward(test_input)
|
|
@ -20,6 +20,9 @@ mkdir -p $venv_dir
|
|||
mkdir -p $serialized_test_dir
|
||||
python3 -m venv $venv_dir
|
||||
source $venv_dir/bin/activate
|
||||
# For bert_seq_classification
|
||||
python3 -m pip install transformers
|
||||
# For basic_mt
|
||||
python3 -m pip install fairseq fvcore sacremoses subword-nmt
|
||||
|
||||
cd "$torch_mlir_src_root"
|
||||
|
|
|
@ -13,6 +13,7 @@ from torch_mlir_e2e_test.torchscript.framework import SerializableTest, generate
|
|||
from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations
|
||||
|
||||
from . import basic_mt
|
||||
from . import bert_seq_classification
|
||||
|
||||
|
||||
def _get_argparse():
|
||||
|
|
|
@ -563,6 +563,22 @@ def Torch_ConstantStrOp : Torch_Op<"constant.str",
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_ConstantDeviceOp : Torch_Op<"constant.device",
|
||||
[NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
||||
let summary = "Materialize a constant Device value.";
|
||||
let description = [{
|
||||
|
||||
}];
|
||||
let arguments = (ins
|
||||
StrAttr:$value
|
||||
);
|
||||
let results = (outs
|
||||
Torch_DeviceType:$result
|
||||
);
|
||||
let assemblyFormat = "$value attr-dict";
|
||||
}
|
||||
|
||||
def Torch_ConstantIntOp : Torch_Op<"constant.int",
|
||||
[ConstantLike, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
||||
|
|
|
@ -738,6 +738,15 @@ void ConstantStrOp::getAsmResultNames(
|
|||
setNameFn(getResult(), "str");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantDeviceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ConstantDeviceOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), value());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -156,6 +156,19 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
toMlirNamedAttribute(
|
||||
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
|
||||
c10::attr::value)))));
|
||||
} else if (output->type()->cast<c10::TensorType>()) {
|
||||
MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
|
||||
op = createMlirOperation(
|
||||
"torch.tensor.literal", loc,
|
||||
torchMlirTorchNonValueTensorTypeGetFromAttribute(attr),
|
||||
toMlirNamedAttribute("value", attr));
|
||||
} else if (output->type()->cast<c10::DeviceObjType>()) {
|
||||
op = createMlirOperation(
|
||||
"torch.constant.device", loc,
|
||||
getMlirTypeFromTorchType(loc, output->type()),
|
||||
toMlirNamedAttribute(
|
||||
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
|
||||
c10::attr::value)))));
|
||||
} else if (auto functionType = output->type()->cast<c10::FunctionType>()) {
|
||||
torch::jit::Function *function = functionType->function();
|
||||
const std::string &symName = function->qualname().qualifiedName();
|
||||
|
|
Loading…
Reference in New Issue