# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
Runs a training loop for a BERT model using the Lazy Tensor Core (LTC) with
the example Torch-MLIR backend.

Most of the code in this example was copied from the wonderful tutorial at
https://huggingface.co/transformers/training.html#fine-tuning-in-native-pytorch

Based on LTC code samples by ramiro050:
https://github.com/ramiro050/lazy-tensor-samples
"""

import argparse
import sys
from typing import List

import torch
import torch._C
import torch._lazy
from datasets import load_dataset
from datasets.dataset_dict import DatasetDict
from torch.utils.data import DataLoader
from transformers import (
    BertForSequenceClassification,
    BertConfig,
    BertTokenizer,
    AdamW,
    get_scheduler,
)
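
# NOTE: newer releases of `transformers` deprecate `AdamW` in favor of
# `torch.optim.AdamW`; the original import is kept here, but the torch
# optimizer should be a drop-in replacement for this example.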


def tokenize_dataset(dataset: DatasetDict) -> DatasetDict:
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    tokenized_datasets = tokenized_datasets.remove_columns(["text"])
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")

    return tokenized_datasets
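
# After tokenization the dataset holds "labels", "input_ids",
# "token_type_ids", and "attention_mask" columns stored as torch tensors,
# which the DataLoader collates into the keyword arguments expected by
# BertForSequenceClassification.forward.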


def train(
    model: BertForSequenceClassification,
    num_epochs: int,
    num_training_steps: int,
    train_dataloader: DataLoader,
    device: torch.device,
) -> List[torch.Tensor]:
    optimizer = AdamW(model.parameters(), lr=5e-5)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    model.train()
    losses = []
    for _ in range(num_epochs):
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            losses.append(loss)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # On a lazy device, explicitly end the current trace so the
            # backend compiles and executes the graph accumulated this step.
            if "lazy" in str(model.device):
                print("Calling Mark Step")
                torch._lazy.mark_step()

    return losses
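
# Note: on the lazy device the returned losses are lazy tensors that only
# materialize once the backend executes the corresponding graph (e.g. when
# they are printed in main below).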


def main(device="lazy", full_size=False):
    """
    Load the model onto the specified device and run a short fine-tuning loop.
    Ensure that any backend has been initialized before this point.

    :param device: name of the device to load tensors onto
    :param full_size: if true, use the full pretrained bert-base-cased model
        instead of a smaller variant
    """
    torch.manual_seed(0)

    tokenized_datasets = tokenize_dataset(load_dataset("imdb"))
    small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(2))

    train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
    if full_size:
        model = BertForSequenceClassification.from_pretrained(
            "bert-base-cased", num_labels=2
        )
    else:
        # A deliberately tiny BERT configuration: the bert-base-cased
        # vocabulary, but a single layer and small hidden dimensions, which
        # keeps the traced graph small and compilation fast.
        configuration = BertConfig(
            vocab_size=28996,
            hidden_size=32,
            num_hidden_layers=1,
            num_attention_heads=2,
            intermediate_size=32,
            hidden_act="gelu",
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
            max_position_embeddings=512,
            layer_norm_eps=1.0e-05,
        )
        model = BertForSequenceClassification(configuration)
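
    # With only two training examples selected above, this run exercises the
    # LTC tracing and compilation path rather than training a useful
    # classifier.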

    model.to(device)

    num_epochs = 3
    num_training_steps = num_epochs * len(train_dataloader)
    losses = train(model, num_epochs, num_training_steps, train_dataloader, device)

    # Get debug information from LTC.
    if "torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND" in sys.modules:
        # Look the backend module up via sys.modules so this also works when
        # main() is called from an importer rather than the __main__ block.
        lazy_backend = sys.modules["torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND"]
        computation = lazy_backend.get_latest_computation()
        if computation:
            print(computation.debug_string())

    print("Loss: ", losses)

    return model, losses


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--device",
        type=str.upper,
        choices=["CPU", "TS", "MLIR_EXAMPLE"],
        default="MLIR_EXAMPLE",
        help="The device type",
    )
    parser.add_argument(
        "-f",
        "--full_size",
        action="store_true",
        default=False,
        help="Use the full-sized BERT model instead of one with a smaller parameterization",
    )
    args = parser.parse_args()

    if args.device in ("TS", "MLIR_EXAMPLE"):
        if args.device == "TS":
            import torch._lazy.ts_backend

            torch._lazy.ts_backend.init()

        elif args.device == "MLIR_EXAMPLE":
            import torch_mlir._mlir_libs._REFERENCE_LAZY_BACKEND as lazy_backend

            lazy_backend._initialize()

        # Both lazy backends register themselves under the "lazy" device name.
        device = "lazy"
        print("Initialized backend")
    else:
        device = args.device.lower()

    main(device, args.full_size)