# Source: torch-mlir/examples/ltc_backend_bert.py (161 lines, 5.1 KiB, Python)

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
Runs training of the BERT model using Lazy Tensor Core (LTC) with the
example Torch-MLIR backend.

Most of the code in this example was copied from the wonderful tutorial
https://huggingface.co/transformers/training.html#fine-tuning-in-native-pytorch

Based on LTC code samples by ramiro050:
https://github.com/ramiro050/lazy-tensor-samples
"""
import argparse
import sys
from typing import List
import torch
import torch._C
import torch._lazy
from datasets import load_dataset
from datasets.dataset_dict import DatasetDict
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, \
BertConfig, BertTokenizer, AdamW, get_scheduler
def tokenize_dataset(dataset: DatasetDict) -> DatasetDict:
    """Tokenize every split of ``dataset`` and prepare it for PyTorch.

    The ``text`` column is tokenized with the ``bert-base-cased`` tokenizer
    using ``padding="max_length"`` and ``truncation=True``. The raw ``text``
    column is then dropped, ``label`` is renamed to ``labels`` so a batch can
    be passed to the model as keyword arguments, and the splits are switched
    to the ``torch`` output format.

    :param dataset: dataset splits that have ``text`` and ``label`` columns
    :return: the tokenized splits
    """
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

    def _encode(batch):
        # map(..., batched=True) hands over many examples at once, so
        # batch["text"] holds a list of strings here.
        return bert_tokenizer(batch["text"], padding="max_length",
                              truncation=True)

    encoded = (dataset.map(_encode, batched=True)
               .remove_columns(['text'])
               .rename_column('label', 'labels'))
    # set_format works in place, so it cannot be chained above.
    encoded.set_format('torch')
    return encoded
def train(model: BertForSequenceClassification,
          num_epochs: int,
          num_training_steps: int,
          train_dataloader: DataLoader,
          device: torch.device) -> List[torch.Tensor]:
    """Fine-tune ``model`` and record the loss of every optimization step.

    Uses AdamW (lr=5e-5) with a linear learning-rate schedule and no warmup.

    :param model: classifier to train; it is switched to training mode here
    :param num_epochs: number of full passes over ``train_dataloader``
    :param num_training_steps: total number of steps, used by the linear
        learning-rate schedule
    :param train_dataloader: yields dicts of tensors that are moved to
        ``device`` and passed to the model as keyword arguments
    :param device: device every batch is moved to before the forward pass
    :return: the loss tensor of each step, in execution order
    """
    opt = AdamW(model.parameters(), lr=5e-5)
    scheduler = get_scheduler('linear', optimizer=opt,
                              num_warmup_steps=0,
                              num_training_steps=num_training_steps)

    model.train()
    step_losses = []
    for _epoch in range(num_epochs):
        for raw_batch in train_dataloader:
            inputs = {name: tensor.to(device)
                      for name, tensor in raw_batch.items()}
            step_loss = model(**inputs).loss
            step_loss.backward()
            step_losses.append(step_loss)

            opt.step()
            scheduler.step()
            opt.zero_grad()

            # On a lazy (LTC) device, end the traced step explicitly with
            # mark_step once per iteration.
            if 'lazy' in str(model.device):
                print("Calling Mark Step")
                torch._lazy.mark_step()

    return step_losses
def main(device='lazy', full_size=False):
    """
    Load model to specified device. Ensure that any backends have been initialized by this point.

    Fine-tunes a BERT sequence classifier on a tiny shuffled subset
    (2 examples) of the IMDB ``train`` split for three epochs, prints debug
    info from the example MLIR backend if that module has been imported, and
    prints the collected losses.

    :param device: name of device to load tensors to
    :param full_size: if true, use a full pretrained bert-base-cased model instead of a smaller variant
    :return: tuple ``(model, losses)``: the trained model and the per-step
        loss tensors returned by ``train``
    """
    torch.manual_seed(0)
    tokenized_datasets = tokenize_dataset(load_dataset('imdb'))
    small_train_dataset = tokenized_datasets['train'].shuffle(seed=42) \
        .select(range(2))
    train_dataloader = DataLoader(small_train_dataset, shuffle=True,
                                  batch_size=8)

    if full_size:
        model = BertForSequenceClassification.from_pretrained('bert-base-cased',
                                                              num_labels=2)
    else:
        # Deliberately tiny configuration so the example traces and compiles
        # quickly; dropout is disabled.
        configuration = BertConfig(
            vocab_size=28996,
            hidden_size=32,
            num_hidden_layers=1,
            num_attention_heads=2,
            intermediate_size=32,
            hidden_act='gelu',
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
            max_position_embeddings=512,
            layer_norm_eps=1.0e-05,
        )
        model = BertForSequenceClassification(configuration)

    model.to(device)

    num_epochs = 3
    num_training_steps = num_epochs * len(train_dataloader)
    losses = train(model, num_epochs, num_training_steps, train_dataloader, device)

    # Get debug information from LTC. Look the example MLIR backend up by its
    # full module name instead of relying on a module-level ``ltc_backend``
    # global. That global is only bound when this file runs as a script, so a
    # bare ``'ltc_backend' in sys.modules`` check raised NameError when main()
    # was called from another module that had imported the backend itself.
    example_backend = sys.modules.get(
        'ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND')
    if example_backend is not None:
        computation = example_backend.get_latest_computation()
        if computation:
            print(computation.debug_string())

    print('Loss: ', losses)
    return model, losses
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-d",
        "--device",
        type=str.upper,
        choices=["CPU", "TS", "MLIR_EXAMPLE"],
        default="MLIR_EXAMPLE",
        help="The device type",
    )
    arg_parser.add_argument(
        "-f",
        "--full_size",
        action='store_true',
        default=False,
        help="Use full sized BERT model instead of one with smaller parameterization",
    )
    cli_args = arg_parser.parse_args()

    if cli_args.device not in ("TS", "MLIR_EXAMPLE"):
        # Eager device: the lower-cased CLI name is used directly.
        device = cli_args.device.lower()
    else:
        # Lazy backends must be initialized before main() moves the model.
        if cli_args.device == "TS":
            import torch._lazy.ts_backend
            torch._lazy.ts_backend.init()
        else:
            # Bound at module scope on purpose: the name ``ltc_backend`` is
            # left as a module global.
            import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
            ltc_backend._initialize()
        device = "lazy"
        print("Initialized backend")

    main(device, cli_args.full_size)