mirror of https://github.com/llvm/torch-mlir
Bert example and relevant shape inference functions (#831)
parent
406d1e7538
commit
de5b380143
|
@ -0,0 +1,135 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
"""
|
||||||
|
Runs a training of the Bert model using the Lazy Tensor Core with the
|
||||||
|
example Torch MLIR backend.
|
||||||
|
|
||||||
|
Most of the code in this example was copied from the wonderful tutorial
|
||||||
|
https://huggingface.co/transformers/training.html#fine-tuning-in-native-pytorch
|
||||||
|
|
||||||
|
Based on LTC code samples by ramiro050
|
||||||
|
https://github.com/ramiro050/lazy-tensor-samples
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
|
from datasets.dataset_dict import DatasetDict
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from transformers import BertForSequenceClassification, \
|
||||||
|
BertTokenizer, AdamW, get_scheduler
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_dataset(dataset: DatasetDict) -> DatasetDict:
    """Tokenize the ``text`` column of ``dataset`` for BERT.

    Uses the pretrained ``bert-base-cased`` tokenizer with
    ``padding="max_length"`` and ``truncation=True``. After tokenizing, the
    raw ``text`` column is removed, ``label`` is renamed to ``labels`` and the
    dataset format is switched to ``'torch'``.

    Returns the tokenized dataset.
    """
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

    def encode_batch(examples):
        # Called by `dataset.map` with a batch of examples (batched=True).
        return bert_tokenizer(examples["text"],
                              padding="max_length",
                              truncation=True)

    encoded = dataset.map(encode_batch, batched=True)
    encoded = encoded.remove_columns(['text'])
    encoded = encoded.rename_column('label', 'labels')
    encoded.set_format('torch')
    return encoded
|
||||||
|
|
||||||
|
|
||||||
|
def train(model: BertForSequenceClassification,
          num_epochs: int,
          num_training_steps: int,
          train_dataloader: DataLoader,
          device: torch.device,
          do_mark_step: bool) -> List[torch.Tensor]:
    """Run ``num_epochs`` passes over ``train_dataloader`` and collect losses.

    An ``AdamW`` optimizer (lr=5e-5) and a ``'linear'`` scheduler with no
    warmup steps are created for the model. For every batch the tensors are
    moved to ``device``, the loss is back-propagated, and the optimizer and
    scheduler are stepped. When ``do_mark_step`` is set and the model's
    device string contains ``'lazy'``, ``torch._lazy.mark_step()`` is called
    after each batch.

    Returns the list of per-batch loss tensors, in iteration order.
    """
    adamw = AdamW(model.parameters(), lr=5e-5)
    schedule = get_scheduler('linear',
                             optimizer=adamw,
                             num_warmup_steps=0,
                             num_training_steps=num_training_steps)

    model.train()
    step_losses = []
    for _epoch in range(num_epochs):
        for raw_batch in train_dataloader:
            inputs = {
                name: tensor.to(device)
                for name, tensor in raw_batch.items()
            }
            step_loss = model(**inputs).loss
            step_loss.backward()
            step_losses.append(step_loss)

            adamw.step()
            schedule.step()
            adamw.zero_grad()

            # Only lazy devices need an explicit mark step.
            if not do_mark_step or 'lazy' not in str(model.device):
                continue
            print("Calling Mark Step")
            torch._lazy.mark_step()

    return step_losses
|
||||||
|
|
||||||
|
|
||||||
|
def main(device, lower_only):
    """Initialize the requested backend, fine-tune BERT briefly, and report.

    ``device`` is one of ``"CPU"``, ``"TS"`` or ``"MLIR_EXAMPLE"``. The two
    lazy choices initialize their backend and train on the ``"lazy"`` device;
    any other value is lower-cased and used as the device name directly.

    When ``lower_only`` is true no mark step is requested during training and
    the backend printout for the first loss tensor is printed; otherwise the
    list of losses is printed.
    """
    if device not in ("TS", "MLIR_EXAMPLE"):
        device = device.lower()
    else:
        import torch._lazy

        if device == "TS":
            import torch._lazy.ts_backend

            torch._lazy.ts_backend.init()
        elif device == "MLIR_EXAMPLE":
            import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend

            ltc_backend._initialize()

        device = "lazy"
        print("Initialized backend")

    # Two shuffled IMDB training examples are used for this example run.
    imdb_tokenized = tokenize_dataset(load_dataset('imdb'))
    shuffled_train = imdb_tokenized['train'].shuffle(seed=42)
    small_train_dataset = shuffled_train.select(range(2))

    train_dataloader = DataLoader(small_train_dataset,
                                  shuffle=True,
                                  batch_size=8)
    model = BertForSequenceClassification.from_pretrained('bert-base-cased',
                                                          num_labels=2)
    model.to(device)

    num_epochs = 3
    total_steps = num_epochs * len(train_dataloader)
    losses = train(model, num_epochs, total_steps, train_dataloader, device,
                   not lower_only)

    if not lower_only:
        # Execute computation
        print('Loss: ', losses)
        return

    print('\nJIT Graph:')
    import torch._C
    backend_dump = torch._C._lazy._get_tensors_backend([losses[0]])
    print(backend_dump)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: parse the device choice and the lower-only
    # flag, then hand both to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("-d", "--device",
                     type=str.upper,
                     choices=["CPU", "TS", "MLIR_EXAMPLE"],
                     default="MLIR_EXAMPLE",
                     help="The device type")
    cli.add_argument("-l", "--lower_only",
                     action='store_true',
                     default=False,
                     help="Only get backend printout -- do not execute computation")
    parsed = cli.parse_args()
    main(parsed.device, parsed.lower_only)
|
|
@ -9,10 +9,127 @@
|
||||||
|
|
||||||
#include "LazyShapeInference.h"
|
#include "LazyShapeInference.h"
|
||||||
#include "../utils/exception.h"
|
#include "../utils/exception.h"
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
|
// TODO(henrytu): Upstream these shape inference functions to PyTorch in the future.
|
||||||
|
|
||||||
|
// Maps a negative index to `dims + index`; non-negative indices are returned
// unchanged. No bounds checking is done here, so the caller must validate
// the result.
int64_t normalize_index(int64_t index, unsigned dims) {
  if (index >= 0) {
    return index;
  }
  return static_cast<int64_t>(dims) + index;
}
|
||||||
|
|
||||||
|
std::vector<Shape>
|
||||||
|
compute_shape_dropout(const at::Tensor& input, double p, bool train) {
|
||||||
|
return {Shape(input.scalar_type(), input.sizes().vec())};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> compute_shape_layer_norm(
|
||||||
|
const at::Tensor& input, at::IntArrayRef normalized_shape,
|
||||||
|
const c10::optional<at::Tensor>& weight,
|
||||||
|
const c10::optional<at::Tensor>& bias, double eps, bool cudnn_enable) {
|
||||||
|
return {Shape(input.scalar_type(), input.sizes().vec())};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Shape>
|
||||||
|
compute_shape_matmul(const at::Tensor& self, const at::Tensor& other) {
|
||||||
|
std::vector<int64_t> sizes;
|
||||||
|
|
||||||
|
auto self_sizes = self.sizes().vec();
|
||||||
|
auto other_sizes = other.sizes().vec();
|
||||||
|
|
||||||
|
// For tensors with dimensions >2, the leading dimensions are for batch info.
|
||||||
|
// The last 2 (or 1 in the case of a single dim tensor) dimensions are the
|
||||||
|
// matrix dimensions themselves, which is checked to ensure the matmul op
|
||||||
|
// is legal.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// [1, 2, 3, 4] -> [1, 2] batch dims and [3, 4] matrix
|
||||||
|
// [1, 4, 5] -> [1] batch dims and [4, 5] matrix
|
||||||
|
// [4, 5] -> [] batch dims and [4, 5] matrix
|
||||||
|
// [5] -> [] batch dims and [5] matrix
|
||||||
|
//
|
||||||
|
// We'll start by splitting the shapes as described above.
|
||||||
|
auto partition_shape = [](at::ArrayRef<int64_t> sizes) {
|
||||||
|
if (sizes.size() <= 2) {
|
||||||
|
return std::make_pair(
|
||||||
|
std::vector<int64_t>(),
|
||||||
|
std::vector<int64_t>(sizes.begin(), sizes.end()));
|
||||||
|
} else {
|
||||||
|
std::size_t partition_idx = sizes.size() - 2;
|
||||||
|
return std::make_pair(
|
||||||
|
std::vector<int64_t>(sizes.begin(), sizes.begin() + partition_idx),
|
||||||
|
std::vector<int64_t>(sizes.begin() + partition_idx, sizes.end()));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
auto [self_batch_sizes, self_matrix_sizes] = partition_shape(self_sizes);
|
||||||
|
auto [other_batch_sizes, other_matrix_sizes] = partition_shape(other_sizes);
|
||||||
|
|
||||||
|
// Insert batch dimensions.
|
||||||
|
// The final list of sizes will be based on the tensor w/ more dims.
|
||||||
|
// Individual dimension sizes are "right justified" as we iterate thru
|
||||||
|
// to pick the larger dimension between them.
|
||||||
|
// 0 1 1 3 4
|
||||||
|
// 5 1 2
|
||||||
|
// ---------
|
||||||
|
// 0 1 5 3 4 <- Result
|
||||||
|
int64_t self_size, other_size;
|
||||||
|
std::size_t num_batch_dim =
|
||||||
|
std::max(self_batch_sizes.size(), other_batch_sizes.size());
|
||||||
|
auto get_batch_dim = [&](std::vector<int64_t> batch_sizes, std::size_t dim) {
|
||||||
|
long idx = dim - num_batch_dim + batch_sizes.size();
|
||||||
|
// Negative index means out of bounds, which defaults to a dim size of 1.
|
||||||
|
return idx < 0 ? 1 : batch_sizes[idx];
|
||||||
|
};
|
||||||
|
for (std::size_t i = 0; i < num_batch_dim; i++) {
|
||||||
|
self_size = get_batch_dim(self_batch_sizes, i);
|
||||||
|
other_size = get_batch_dim(other_batch_sizes, i);
|
||||||
|
|
||||||
|
TORCH_CHECK(
|
||||||
|
self_size == 1 || other_size == 1 || self_size == other_size,
|
||||||
|
"At trailing dimension ", i, ", expected for dimensions ",
|
||||||
|
"to either match or have one of them equal one, but got ", self_size,
|
||||||
|
" and ", other_size, " instead!");
|
||||||
|
|
||||||
|
sizes.push_back(std::max(self_size, other_size));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep track of the inner dimensions of matmul to validate op is valid.
|
||||||
|
std::pair<int64_t, int64_t> inner_sizes;
|
||||||
|
if (self_matrix_sizes.size() == 1 && other_matrix_sizes.size() == 1) {
|
||||||
|
// Dot-Product -- scalar output, so no dimensions inserted
|
||||||
|
inner_sizes = std::make_pair(self_matrix_sizes[0], other_matrix_sizes[0]);
|
||||||
|
} else if (self_matrix_sizes.size() == 1 && other_matrix_sizes.size() == 2) {
|
||||||
|
// Vector-Matrix product (m) @ (m, n) -> (n)
|
||||||
|
inner_sizes = std::make_pair(self_matrix_sizes[0], other_matrix_sizes[0]);
|
||||||
|
|
||||||
|
sizes.push_back(other_matrix_sizes[1]);
|
||||||
|
} else if (self_matrix_sizes.size() == 2 && other_matrix_sizes.size() == 1) {
|
||||||
|
// Matrix-Vector product (m, n) @ (n) -> (m)
|
||||||
|
inner_sizes = std::make_pair(self_matrix_sizes[1], other_matrix_sizes[0]);
|
||||||
|
|
||||||
|
sizes.push_back(self_matrix_sizes[0]);
|
||||||
|
} else if (self_matrix_sizes.size() == 2 && other_matrix_sizes.size() == 2) {
|
||||||
|
// Matrix-Matrix product (m, n) @ (n, o) -> (m, o)
|
||||||
|
inner_sizes = std::make_pair(self_matrix_sizes[1], other_matrix_sizes[0]);
|
||||||
|
|
||||||
|
sizes.push_back(self_matrix_sizes[0]);
|
||||||
|
sizes.push_back(other_matrix_sizes[1]);
|
||||||
|
} else {
|
||||||
|
// By this time, self_matrix_sizes and other_matrix_sizes should have at
|
||||||
|
// most 2 dims, so if this is executed something has gone wrong...
|
||||||
|
TORCH_CHECK(false, "Invalid matmul shape combination!");
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_CHECK(
|
||||||
|
inner_sizes.first == inner_sizes.second, "Inner dimension of matrix (",
|
||||||
|
inner_sizes.first, ") does not ", "match (", inner_sizes.second, ")!");
|
||||||
|
|
||||||
|
return {Shape(self.scalar_type(), sizes)};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<Shape> compute_shape_native_batch_norm(
|
std::vector<Shape> compute_shape_native_batch_norm(
|
||||||
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
|
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
|
||||||
const c10::optional<at::Tensor>& bias,
|
const c10::optional<at::Tensor>& bias,
|
||||||
|
@ -33,5 +150,129 @@ std::vector<Shape> compute_shape_native_batch_norm(
|
||||||
return shapes;
|
return shapes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<Shape>
|
||||||
|
compute_shape_reshape(const at::Tensor& self, at::IntArrayRef shape) {
|
||||||
|
// Make a copy of the desired output shape.
|
||||||
|
std::vector<int64_t> sizes(shape.begin(), shape.end());
|
||||||
|
|
||||||
|
// Product of all sizes in input shape is the number of entries in tensor.
|
||||||
|
int64_t num_entries = 1;
|
||||||
|
for (int64_t i : self.sizes().vec()) {
|
||||||
|
num_entries *= i;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the number of entries in the desired shape. If there is a wildcard
|
||||||
|
// dimension, we need to find it now in order to populate it.
|
||||||
|
long wildcard_idx = -1;
|
||||||
|
int64_t num_concrete_entries = 1;
|
||||||
|
for (std::size_t idx = 0; idx < sizes.size(); idx++) {
|
||||||
|
if (sizes[idx] != -1) {
|
||||||
|
num_concrete_entries *= sizes[idx];
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(wildcard_idx == -1, "only one dimension can be inferred");
|
||||||
|
wildcard_idx = idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (wildcard_idx == -1) {
|
||||||
|
// No wildcard, the shape should already be known.
|
||||||
|
TORCH_CHECK(
|
||||||
|
num_entries == num_concrete_entries, "shape `[", sizes,
|
||||||
|
"]` is invalid for input of size ", num_concrete_entries);
|
||||||
|
} else {
|
||||||
|
// There is one dimension which is not explicitly declared -- we need to
|
||||||
|
// infer.
|
||||||
|
TORCH_CHECK(
|
||||||
|
num_entries % num_concrete_entries == 0, "shape `[", sizes,
|
||||||
|
"]` is invalid for input of size ", num_concrete_entries);
|
||||||
|
|
||||||
|
sizes[wildcard_idx] = num_entries / num_concrete_entries;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {Shape(self.scalar_type(), sizes)};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> compute_shape_rsub(
|
||||||
|
const at::Tensor& self, const at::Scalar& other, const at::Scalar& alpha) {
|
||||||
|
// Since other is scalar, the result will match tensor shape.
|
||||||
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Shape>
|
||||||
|
compute_shape_select(const at::Tensor& self, int64_t dim, int64_t index) {
|
||||||
|
auto original_shape = self.sizes().vec();
|
||||||
|
std::vector<int64_t> sizes(original_shape.begin(), original_shape.end());
|
||||||
|
|
||||||
|
TORCH_CHECK(
|
||||||
|
dim < (int64_t)sizes.size(), "Dimension ", dim,
|
||||||
|
" is out of bounds for tensor with ", sizes.size(), " dimensions!");
|
||||||
|
TORCH_CHECK(
|
||||||
|
index < sizes[dim], "Index ", index,
|
||||||
|
" is out of bounds for dimension of size ", sizes[dim]);
|
||||||
|
sizes.erase(sizes.begin() + dim);
|
||||||
|
|
||||||
|
return {Shape(self.scalar_type(), sizes)};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> compute_shape_slice(
|
||||||
|
const at::Tensor& self, int64_t dim, c10::optional<int64_t> start,
|
||||||
|
c10::optional<int64_t> end, int64_t step) {
|
||||||
|
auto original_shape = self.sizes().vec();
|
||||||
|
std::vector<int64_t> sizes(original_shape.begin(), original_shape.end());
|
||||||
|
|
||||||
|
int64_t dim_size = sizes[dim];
|
||||||
|
|
||||||
|
// Index may be negative, so we must normalize it.
|
||||||
|
int64_t start_norm = normalize_index(start.value(), dim_size);
|
||||||
|
int64_t end_norm = normalize_index(end.value(), dim_size);
|
||||||
|
|
||||||
|
if (start_norm >= end_norm || start_norm >= dim_size || end_norm <= 0) {
|
||||||
|
// Slice is out of bounds, nothing in range.
|
||||||
|
sizes[dim] = 0;
|
||||||
|
} else {
|
||||||
|
// Clamp upper and lower bound to valid indices.
|
||||||
|
start_norm = std::max((int64_t)0, start_norm);
|
||||||
|
end_norm = std::min(dim_size, end_norm);
|
||||||
|
|
||||||
|
// Final size is determined by step and interval size.
|
||||||
|
sizes[dim] = std::ceil((double)(end_norm - start_norm) / (double)step);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {Shape(self.scalar_type(), sizes)};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> compute_shape_softmax(
|
||||||
|
const at::Tensor& self, int64_t dim, c10::optional<at::ScalarType> dtype) {
|
||||||
|
if (dtype.has_value()) {
|
||||||
|
return {Shape(dtype.value(), self.sizes().vec())};
|
||||||
|
}
|
||||||
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Shape>
|
||||||
|
compute_shape_transpose(const at::Tensor& self, int64_t dim0, int64_t dim1) {
|
||||||
|
auto original_shape = self.sizes().vec();
|
||||||
|
std::vector<int64_t> sizes{original_shape.begin(), original_shape.end()};
|
||||||
|
|
||||||
|
// Index may be negative, so we must normalize it. We create new variables
|
||||||
|
// instead of replacing the existing ones so that in the case of an error,
|
||||||
|
// the original values can be printed out.
|
||||||
|
int64_t dim0_norm = normalize_index(dim0, sizes.size());
|
||||||
|
int64_t dim1_norm = normalize_index(dim1, sizes.size());
|
||||||
|
|
||||||
|
// Verify dimensions are valid.
|
||||||
|
TORCH_CHECK(
|
||||||
|
0 <= dim0_norm && dim0_norm < (int64_t)sizes.size(), "dim0 has value ",
|
||||||
|
dim0, ", but there are only ", sizes.size(), " tensor dimensions");
|
||||||
|
TORCH_CHECK(
|
||||||
|
0 <= dim1_norm && dim1_norm < (int64_t)sizes.size(), "dim1 has value ",
|
||||||
|
dim1, ", but there are only ", sizes.size(), " tensor dimensions");
|
||||||
|
|
||||||
|
// Swap shapes at dimensions.
|
||||||
|
std::swap(sizes[dim0_norm], sizes[dim1_norm]);
|
||||||
|
|
||||||
|
return {Shape(self.scalar_type(), sizes)};
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace lazy
|
} // namespace lazy
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
Loading…
Reference in New Issue