mirror of https://github.com/llvm/torch-mlir
Add BertSequenceClassification model to e2e
Use torch tracing to get the module because the original model is not TorchScriptable out of the box.
pull/345/head
parent
649d6e4f28
commit
89225b0cd8
|
@ -0,0 +1,70 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
# Basic BertForSequenceClassification program to classify the input sentence.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import BertForSequenceClassification, BertTokenizer
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||||
|
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||||
|
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
|
# Seed PyTorch's global RNG with a fixed value at import time.
torch.manual_seed(0)

# BERT special-token strings.
# NOTE(review): neither constant is referenced by the code visible in this
# file -- confirm whether they are still needed.
CLS = "[CLS]"
SEP = "[SEP]"

# Module-level tokenizer loaded from the "bert-base-uncased" pretrained name;
# used by _prepare_sentence_tokens below.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_sentence_tokens(sentence: str):
    """Encode ``sentence`` with the module-level tokenizer.

    Returns a ``torch.tensor`` built from a one-element list holding the
    encoded ids, i.e. the ids are wrapped in an extra leading dimension.
    """
    token_ids = tokenizer.encode(sentence)
    return torch.tensor([token_ids])
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBertSequenceClassification(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around a pretrained BERT classifier.

    The wrapped ``BertForSequenceClassification`` is loaded from the
    "bert-base-uncased" pretrained name and switched to eval mode.
    """

    def __init__(self):
        super().__init__()
        pretrained_options = dict(
            # The number of output labels--2 for binary classification.
            num_labels=2,
            # Whether the model returns attentions weights.
            output_attentions=False,
            # Whether the model returns all hidden-states.
            output_hidden_states=False,
            torchscript=True,
        )
        # Use the 12-layer BERT model, with an uncased vocab.
        self.model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased", **pretrained_options)
        self.model.eval()

    # NOTE(review): the annotation presumably describes a rank-2 int64 tensor
    # argument with dynamic dimensions -- confirm against annotate_args docs.
    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, tokens):
        """Run the wrapped model on ``tokens`` and return element 0 of its output."""
        model_outputs = self.model.forward(tokens)
        return model_outputs[0]
|
||||||
|
|
||||||
|
|
||||||
|
# Example inputs passed to torch.jit.trace_module, keyed by the name of the
# method to trace.
trace_input = {
    'forward': _prepare_sentence_tokens("how do you like the project")
}

# Token tensor that the registered test case feeds to the traced module.
test_input = _prepare_sentence_tokens("this project is very interesting")
|
||||||
|
|
||||||
|
|
||||||
|
def getTracedRecursiveScriptModule():
    """Trace the BERT wrapper and return its underlying script module.

    A fresh ``BasicBertSequenceClassification`` is traced with ``trace_input``;
    the traced module's ``_actual_script_module`` is then taken, its
    ``forward`` is passed through ``export`` and ``annotate_args``, and that
    script module is returned.
    """
    bert_wrapper = BasicBertSequenceClassification()
    traced = torch.jit.trace_module(bert_wrapper, trace_input)
    recursive_script_module = traced._actual_script_module

    # Apply the e2e annotations again on the script module's forward.
    export(recursive_script_module.forward)
    forward_arg_annotations = [
        None,
        ([-1, -1], torch.int64, True),
    ]
    annotate_args(forward_arg_annotations)(recursive_script_module.forward)
    return recursive_script_module
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: getTracedRecursiveScriptModule())
def BasicBertSequenceClassification_basic(module, tu: TestUtils):
    """Call ``forward`` on the traced BERT module with the fixed ``test_input``.

    ``tu`` is accepted as part of the test-case signature but is not used here.
    """
    module.forward(test_input)
|
|
@ -20,6 +20,9 @@ mkdir -p $venv_dir
|
||||||
mkdir -p $serialized_test_dir
|
mkdir -p $serialized_test_dir
|
||||||
python3 -m venv $venv_dir
|
python3 -m venv $venv_dir
|
||||||
source $venv_dir/bin/activate
|
source $venv_dir/bin/activate
|
||||||
|
# For bert_seq_classification
|
||||||
|
python3 -m pip install transformers
|
||||||
|
# For basic_mt
|
||||||
python3 -m pip install fairseq fvcore sacremoses subword-nmt
|
python3 -m pip install fairseq fvcore sacremoses subword-nmt
|
||||||
|
|
||||||
cd "$torch_mlir_src_root"
|
cd "$torch_mlir_src_root"
|
||||||
|
|
|
@ -13,6 +13,7 @@ from torch_mlir_e2e_test.torchscript.framework import SerializableTest, generate
|
||||||
from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations
|
from torch_mlir_e2e_test.torchscript.annotations import extract_serializable_annotations
|
||||||
|
|
||||||
from . import basic_mt
|
from . import basic_mt
|
||||||
|
from . import bert_seq_classification
|
||||||
|
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
|
|
|
@ -563,6 +563,22 @@ def Torch_ConstantStrOp : Torch_Op<"constant.str",
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// `constant.device`: takes a string attribute `value` and yields a single
// Torch_DeviceType result. The op is marked NoSideEffect and declares
// `getAsmResultNames` from OpAsmOpInterface (implemented in C++).
def Torch_ConstantDeviceOp : Torch_Op<"constant.device",
    [NoSideEffect,
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
  let summary = "Materialize a constant Device value.";
  let description = [{

  }];
  let arguments = (ins
    StrAttr:$value
  );
  let results = (outs
    Torch_DeviceType:$result
  );
  // Printed/parsed as the `value` attribute followed by the attribute dict.
  let assemblyFormat = "$value attr-dict";
}
|
||||||
|
|
||||||
def Torch_ConstantIntOp : Torch_Op<"constant.int",
|
def Torch_ConstantIntOp : Torch_Op<"constant.int",
|
||||||
[ConstantLike, NoSideEffect,
|
[ConstantLike, NoSideEffect,
|
||||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
||||||
|
|
|
@ -738,6 +738,15 @@ void ConstantStrOp::getAsmResultNames(
|
||||||
setNameFn(getResult(), "str");
|
setNameFn(getResult(), "str");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConstantDeviceOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Suggests a name for this op's result: the op's `value` string is passed to
// `setNameFn` together with the result.
void ConstantDeviceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  StringRef deviceName = value();
  setNameFn(getResult(), deviceName);
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ConstantIntOp
|
// ConstantIntOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -156,6 +156,19 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
|
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
|
||||||
c10::attr::value)))));
|
c10::attr::value)))));
|
||||||
|
} else if (output->type()->cast<c10::TensorType>()) {
|
||||||
|
MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
|
||||||
|
op = createMlirOperation(
|
||||||
|
"torch.tensor.literal", loc,
|
||||||
|
torchMlirTorchNonValueTensorTypeGetFromAttribute(attr),
|
||||||
|
toMlirNamedAttribute("value", attr));
|
||||||
|
} else if (output->type()->cast<c10::DeviceObjType>()) {
|
||||||
|
op = createMlirOperation(
|
||||||
|
"torch.constant.device", loc,
|
||||||
|
getMlirTypeFromTorchType(loc, output->type()),
|
||||||
|
toMlirNamedAttribute(
|
||||||
|
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
|
||||||
|
c10::attr::value)))));
|
||||||
} else if (auto functionType = output->type()->cast<c10::FunctionType>()) {
|
} else if (auto functionType = output->type()->cast<c10::FunctionType>()) {
|
||||||
torch::jit::Function *function = functionType->function();
|
torch::jit::Function *function = functionType->function();
|
||||||
const std::string &symName = function->qualname().qualifiedName();
|
const std::string &symName = function->qualname().qualifiedName();
|
||||||
|
|
Loading…
Reference in New Issue