Add BertSequenceClassification model to e2e

Use torch tracing to get the module because the original model is not
TorchScriptable out of box.
pull/345/head
Yi Zhang 2021-09-28 10:56:08 -04:00
parent 649d6e4f28
commit 89225b0cd8
6 changed files with 112 additions and 0 deletions

View File

@ -0,0 +1,70 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# Basic BertForSequenceClassification program to classify the input sentence.
import torch
from transformers import BertForSequenceClassification, BertTokenizer
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
torch.manual_seed(0)
CLS = "[CLS]"
SEP = "[SEP]"
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def _prepare_sentence_tokens(sentence: str):
return torch.tensor([tokenizer.encode(sentence)])
class BasicBertSequenceClassification(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = BertForSequenceClassification.from_pretrained(
"bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
num_labels=
2, # The number of output labels--2 for binary classification.
output_attentions=
False, # Whether the model returns attentions weights.
output_hidden_states=
False, # Whether the model returns all hidden-states.
torchscript=True)
self.model.eval()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, tokens):
return self.model.forward(tokens)[0]
trace_input = {
'forward': _prepare_sentence_tokens("how do you like the project")
}
test_input = _prepare_sentence_tokens("this project is very interesting")
def getTracedRecursiveScriptModule():
traced_module = torch.jit.trace_module(BasicBertSequenceClassification(),
trace_input)
script_module = traced_module._actual_script_module
export(script_module.forward)
annotate_args_decorator = annotate_args([
None,
([-1, -1], torch.int64, True),
])
annotate_args_decorator(script_module.forward)
return script_module
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule())
def BasicBertSequenceClassification_basic(module, tu: TestUtils):
module.forward(test_input)

View File

@ -20,6 +20,9 @@ mkdir -p $venv_dir
mkdir -p $serialized_test_dir mkdir -p $serialized_test_dir
python3 -m venv $venv_dir python3 -m venv $venv_dir
source $venv_dir/bin/activate source $venv_dir/bin/activate
# For bert_seq_classification
python3 -m pip install transformers
# For basic_mt
python3 -m pip install fairseq fvcore sacremoses subword-nmt python3 -m pip install fairseq fvcore sacremoses subword-nmt
cd "$torch_mlir_src_root" cd "$torch_mlir_src_root"

View File

@ -13,6 +13,7 @@ from torch_mlir_e2e_test.torchscript.framework import SerializableTest, generate
from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations
from . import basic_mt from . import basic_mt
from . import bert_seq_classification
def _get_argparse(): def _get_argparse():

View File

@ -563,6 +563,22 @@ def Torch_ConstantStrOp : Torch_Op<"constant.str",
let hasFolder = 1; let hasFolder = 1;
} }
def Torch_ConstantDeviceOp : Torch_Op<"constant.device",
[NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Materialize a constant Device value.";
let description = [{
}];
let arguments = (ins
StrAttr:$value
);
let results = (outs
Torch_DeviceType:$result
);
let assemblyFormat = "$value attr-dict";
}
def Torch_ConstantIntOp : Torch_Op<"constant.int", def Torch_ConstantIntOp : Torch_Op<"constant.int",
[ConstantLike, NoSideEffect, [ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> { DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {

View File

@ -738,6 +738,15 @@ void ConstantStrOp::getAsmResultNames(
setNameFn(getResult(), "str"); setNameFn(getResult(), "str");
} }
//===----------------------------------------------------------------------===//
// ConstantDeviceOp
//===----------------------------------------------------------------------===//
void ConstantDeviceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), value());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ConstantIntOp // ConstantIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -156,6 +156,19 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
toMlirNamedAttribute( toMlirNamedAttribute(
"value", mlirStringAttrGet(context, toMlirStringRef(node->s( "value", mlirStringAttrGet(context, toMlirStringRef(node->s(
c10::attr::value))))); c10::attr::value)))));
} else if (output->type()->cast<c10::TensorType>()) {
MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
op = createMlirOperation(
"torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromAttribute(attr),
toMlirNamedAttribute("value", attr));
} else if (output->type()->cast<c10::DeviceObjType>()) {
op = createMlirOperation(
"torch.constant.device", loc,
getMlirTypeFromTorchType(loc, output->type()),
toMlirNamedAttribute(
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
c10::attr::value)))));
} else if (auto functionType = output->type()->cast<c10::FunctionType>()) { } else if (auto functionType = output->type()->cast<c10::FunctionType>()) {
torch::jit::Function *function = functionType->function(); torch::jit::Function *function = functionType->function();
const std::string &symName = function->qualname().qualifiedName(); const std::string &symName = function->qualname().qualifiedName();