mirror of https://github.com/llvm/torch-mlir
Bert example and relevant shape inference functions (#831)
@ -0,0 +1,135 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
Trains the BERT model using Lazy Tensor Core (LTC) with the
example Torch-MLIR backend.
Most of the code in this example was copied from the wonderful tutorial
Based on LTC code samples by ramiro050
import argparse
import torch
from datasets import load_dataset
from datasets.dataset_dict import DatasetDict
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, \
BertTokenizer, AdamW, get_scheduler
from typing import List
def tokenize_dataset(dataset: DatasetDict) -> DatasetDict:
    """Tokenize the ``text`` column of ``dataset`` for BERT fine-tuning.

    Args:
        dataset: Dataset dictionary whose splits carry a ``text`` column and
            a ``label`` column.

    Returns:
        The tokenized dataset with ``text`` removed, ``label`` renamed to
        ``labels`` (the keyword the model's forward pass expects) and the
        output format set to torch tensors.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

    def tokenize_function(examples):
        # Pad AND truncate to the model's maximum length: padding alone leaves
        # over-long examples longer than the model can accept.
        return tokenizer(examples["text"], padding="max_length",
                         truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    tokenized_datasets = tokenized_datasets.remove_columns(['text'])
    tokenized_datasets = tokenized_datasets.rename_column('label', 'labels')
    # The training loop calls `.to(device)` on every batch value, so the
    # columns must be served as torch tensors rather than Python lists.
    tokenized_datasets.set_format('torch')
    return tokenized_datasets
def train(model: BertForSequenceClassification,
          num_epochs: int,
          num_training_steps: int,
          train_dataloader: DataLoader,
          device: torch.device,
          do_mark_step: bool) -> List[torch.Tensor]:
    """Fine-tune ``model`` and collect the loss of every training step.

    Args:
        model: Model to train; it is expected to already live on ``device``.
        num_epochs: Number of passes over ``train_dataloader``.
        num_training_steps: Total number of optimizer steps, used to build the
            linear learning-rate schedule.
        train_dataloader: Yields dict batches whose values are tensors.
        device: Device every batch tensor is moved to.
        do_mark_step: When true and the model is on a lazy device, a mark step
            is issued after each optimizer step.

    Returns:
        The loss tensors, one per training step, in execution order.
    """
    optimizer = AdamW(model.parameters(), lr=5e-5)
    lr_scheduler = get_scheduler('linear', optimizer=optimizer,
                                 num_warmup_steps=0,
                                 num_training_steps=num_training_steps)

    model.train()
    losses = []
    for _ in range(num_epochs):
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            # Keep the tensor itself (not a Python float): the caller may hand
            # it to the lazy backend for inspection.
            losses.append(loss)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            if do_mark_step and 'lazy' in str(model.device):
                print("Calling Mark Step")
                # Imported here (as a bare name) so non-lazy runs never load
                # the lazy module and the module-level `torch` is not shadowed.
                from torch._lazy import mark_step
                mark_step()

    return losses
# Script driver: picks the device from the `device` string, tokenizes the IMDB
# dataset, builds a shuffled training DataLoader and a BERT sequence
# classifier, runs `train` for three epochs and then reports on the losses.
# NOTE(review): several statements below look truncated in this view (unclosed
# calls, a dangling line continuation, no visible backend initialization and
# no visible branch structure) -- confirm against the complete file.
def main(device, lower_only):
# "TS" and "MLIR_EXAMPLE" are the two backend choices that import torch._lazy.
if device in ("TS", "MLIR_EXAMPLE"):
import torch._lazy
if device == "TS":
import torch._lazy.ts_backend
elif device == "MLIR_EXAMPLE":
import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
# From here on the device string handed to torch is "lazy".
device = "lazy"
print("Initialized backend")
# NOTE(review): presumably the fallback for the remaining choice
# (e.g. "CPU" -> "cpu"); the enclosing branch is not visible here.
device = device.lower()
tokenized_datasets = tokenize_dataset(load_dataset('imdb'))
# Fixed seed keeps the shuffled training split reproducible.
# NOTE(review): the continuation of the next statement is not visible.
small_train_dataset = tokenized_datasets['train'].shuffle(seed=42) \
train_dataloader = DataLoader(small_train_dataset, shuffle=True,
model = BertForSequenceClassification.from_pretrained('bert-base-cased',
num_epochs = 3
# One scheduler step per batch per epoch.
num_training_steps = num_epochs * len(train_dataloader)
# The last argument is `do_mark_step`: mark steps are skipped when only the
# lowering is requested.
losses = train(model, num_epochs,
num_training_steps, train_dataloader, device, not lower_only)
if lower_only:
print('\nJIT Graph:')
import torch._C
# Asks the lazy backend for its representation of the first recorded loss.
graph_str = torch._C._lazy._get_tensors_backend([losses[0]])
# Execute computation
print('Loss: ', losses)
# Command-line entry point: parses the arguments and forwards `args.device`
# and `args.lower_only` to `main`.
# NOTE(review): the `parser.add_argument(...)` call heads are not visible in
# this view; only some of their keyword arguments remain below -- confirm the
# option names and defaults against the complete file.
if __name__ == "__main__":
parser = argparse.ArgumentParser()
choices=["CPU", "TS", "MLIR_EXAMPLE"],
help="The device type",
help="Only get backend printout -- do not execute computation",
args = parser.parse_args()
main(args.device, args.lower_only)
@ -9,10 +9,127 @@
#include "LazyShapeInference.h"
#include "../utils/exception.h"
#include <cmath>
namespace torch {
namespace lazy {
// TODO(henrytu): Upstream these shape inference functions to PyTorch in the future.
// Maps a possibly negative index onto its non-negative equivalent for a
// container of `dims` elements (e.g. -1 -> dims - 1). Non-negative indices are
// returned untouched. No range check happens here: an index outside
// [-dims, dims) is returned still out of range, so callers must validate the
// result before using it.
int64_t normalize_index(int64_t index, unsigned dims) {
  return index < 0 ? static_cast<int64_t>(dims) + index : index;
}
compute_shape_dropout(const at::Tensor& input, double p, bool train) {
return {Shape(input.scalar_type(), input.sizes().vec())};
// Shape inference for layer_norm: the result has the same dtype and sizes as
// `input`. The remaining arguments (`normalized_shape`, `weight`, `bias`,
// `eps`, `cudnn_enable`) are not consulted when computing the shape.
std::vector<Shape> compute_shape_layer_norm(
    const at::Tensor& input, at::IntArrayRef normalized_shape,
    const c10::optional<at::Tensor>& weight,
    const c10::optional<at::Tensor>& bias, double eps, bool cudnn_enable) {
  return {Shape(input.scalar_type(), input.sizes().vec())};
}
compute_shape_matmul(const at::Tensor& self, const at::Tensor& other) {
std::vector<int64_t> sizes;
auto self_sizes = self.sizes().vec();
auto other_sizes = other.sizes().vec();
// For tensors with dimensions >2, the leading dimensions are for batch info.
// The last 2 (or 1 in the case of a single dim tensor) dimensions are the
// matrix dimensions themselves, which is checked to ensure the matmul op
// is legal.
// Example:
// [1, 2, 3, 4] -> [1, 2] batch dims and [3, 4] matrix
// [1, 4, 5] -> [1] batch dims and [4, 5] matrix
// [4, 5] -> [] batch dims and [4, 5] matrix
// [5] -> [] batch dims and [5] matrix
// We'll start by splitting the shapes as described above.
auto partition_shape = [](at::ArrayRef<int64_t> sizes) {
if (sizes.size() <= 2) {
return std::make_pair(
std::vector<int64_t>(sizes.begin(), sizes.end()));
} else {
std::size_t partition_idx = sizes.size() - 2;
return std::make_pair(
std::vector<int64_t>(sizes.begin(), sizes.begin() + partition_idx),
std::vector<int64_t>(sizes.begin() + partition_idx, sizes.end()));
auto [self_batch_sizes, self_matrix_sizes] = partition_shape(self_sizes);
auto [other_batch_sizes, other_matrix_sizes] = partition_shape(other_sizes);
// Insert batch dimensions.
// The final list of sizes will be based on the tensor w/ more dims.
// Individual dimension sizes are "right justified" as we iterate thru
// to pick the larger dimension between them.
// 0 1 1 3 4
// 5 1 2
// ---------
// 0 1 5 3 4 <- Result
int64_t self_size, other_size;
std::size_t num_batch_dim =
std::max(self_batch_sizes.size(), other_batch_sizes.size());
auto get_batch_dim = [&](std::vector<int64_t> batch_sizes, std::size_t dim) {
long idx = dim - num_batch_dim + batch_sizes.size();
// Negative index means out of bounds, which defaults to a dim size of 1.
return idx < 0 ? 1 : batch_sizes[idx];
for (std::size_t i = 0; i < num_batch_dim; i++) {
self_size = get_batch_dim(self_batch_sizes, i);
other_size = get_batch_dim(other_batch_sizes, i);
self_size == 1 || other_size == 1 || self_size == other_size,
"At trailing dimension ", i, ", expected for dimensions ",
"to either match or have one of them equal one, but got ", self_size,
" and ", other_size, " instead!");
sizes.push_back(std::max(self_size, other_size));
// Keep track of the inner dimensions of matmul to validate op is valid.
std::pair<int64_t, int64_t> inner_sizes;
if (self_matrix_sizes.size() == 1 && other_matrix_sizes.size() == 1) {
// Dot-Product -- scalar output, so no dimensions inserted
inner_sizes = std::make_pair(self_matrix_sizes[0], other_matrix_sizes[0]);
} else if (self_matrix_sizes.size() == 1 && other_matrix_sizes.size() == 2) {
// Vector-Matrix product (m) @ (m, n) -> (n)
inner_sizes = std::make_pair(self_matrix_sizes[0], other_matrix_sizes[0]);
} else if (self_matrix_sizes.size() == 2 && other_matrix_sizes.size() == 1) {
// Matrix-Vector product (m, n) @ (n) -> (m)
inner_sizes = std::make_pair(self_matrix_sizes[1], other_matrix_sizes[0]);
} else if (self_matrix_sizes.size() == 2 && other_matrix_sizes.size() == 2) {
// Matrix-Matrix product (m, n) @ (n, o) -> (m, o)
inner_sizes = std::make_pair(self_matrix_sizes[1], other_matrix_sizes[0]);
} else {
// By this time, self_matrix_sizes and other_matrix_sizes should have at
// most 2 dims, so if this is executed something has gone wrong...
TORCH_CHECK(false, "Invalid matmul shape combination!");
inner_sizes.first == inner_sizes.second, "Inner dimension of matrix (",
inner_sizes.first, ") does not ", "match (", inner_sizes.second, ")!");
return {Shape(self.scalar_type(), sizes)};
std::vector<Shape> compute_shape_native_batch_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
@ -33,5 +150,129 @@ std::vector<Shape> compute_shape_native_batch_norm(
return shapes;
compute_shape_reshape(const at::Tensor& self, at::IntArrayRef shape) {
// Make a copy of the desired output shape.
std::vector<int64_t> sizes(shape.begin(), shape.end());
// Product of all sizes in input shape is the number of entries in tensor.
int64_t num_entries = 1;
for (int64_t i : self.sizes().vec()) {
num_entries *= i;
// Validate the number of entries in the desired shape. If there is a wildcard
// dimension, we need to find it now in order to populate it.
long wildcard_idx = -1;
int64_t num_concrete_entries = 1;
for (std::size_t idx = 0; idx < sizes.size(); idx++) {
if (sizes[idx] != -1) {
num_concrete_entries *= sizes[idx];
} else {
TORCH_CHECK(wildcard_idx == -1, "only one dimension can be inferred");
wildcard_idx = idx;
if (wildcard_idx == -1) {
// No wildcard, the shape should already be known.
num_entries == num_concrete_entries, "shape `[", sizes,
"]` is invalid for input of size ", num_concrete_entries);
} else {
// There is one dimension which is not explicitly declared -- we need to
// infer.
num_entries % num_concrete_entries == 0, "shape `[", sizes,
"]` is invalid for input of size ", num_concrete_entries);
sizes[wildcard_idx] = num_entries / num_concrete_entries;
return {Shape(self.scalar_type(), sizes)};
// Shape inference for rsub with a scalar operand: the result has the same
// dtype and sizes as `self`. `other` and `alpha` are scalars and therefore do
// not affect the shape.
std::vector<Shape> compute_shape_rsub(
    const at::Tensor& self, const at::Scalar& other, const at::Scalar& alpha) {
  // Since other is scalar, the result will match tensor shape.
  return {Shape(self.scalar_type(), self.sizes().vec())};
}
compute_shape_select(const at::Tensor& self, int64_t dim, int64_t index) {
auto original_shape = self.sizes().vec();
std::vector<int64_t> sizes(original_shape.begin(), original_shape.end());
dim < (int64_t)sizes.size(), "Dimension ", dim,
" is out of bounds for tensor with ", sizes.size(), " dimensions!");
index < sizes[dim], "Index ", index,
" is out of bounds for dimension of size ", sizes[dim]);
sizes.erase(sizes.begin() + dim);
return {Shape(self.scalar_type(), sizes)};
// Shape inference for slice: the result is the shape of `self` with dimension
// `dim` shrunk to the number of elements selected by [start, end) with stride
// `step`. The result dtype is that of `self`.
//
// `start` defaults to 0 and `end` to the size of the dimension when absent.
// Negative `dim`, `start` and `end` count from the end; `start` and `end` are
// then clamped to the valid range, so an out-of-range slice yields size 0
// instead of an error. Raises (via TORCH_CHECK) when `dim` is out of bounds or
// `step` is not positive.
std::vector<Shape> compute_shape_slice(
    const at::Tensor& self, int64_t dim, c10::optional<int64_t> start,
    c10::optional<int64_t> end, int64_t step) {
  std::vector<int64_t> sizes = self.sizes().vec();
  const int64_t num_dims = static_cast<int64_t>(sizes.size());

  const int64_t dim_norm = normalize_index(dim, sizes.size());
  TORCH_CHECK(
      0 <= dim_norm && dim_norm < num_dims, "Dimension ", dim,
      " is out of bounds for tensor with ", sizes.size(), " dimensions!");
  TORCH_CHECK(step > 0, "slice step must be positive");

  const int64_t dim_size = sizes[dim_norm];

  // Index may be negative, so we must normalize it. This is done inline to
  // keep the full 64-bit dimension size.
  int64_t start_norm = start.value_or(0);
  int64_t end_norm = end.value_or(dim_size);
  if (start_norm < 0) {
    start_norm += dim_size;
  }
  if (end_norm < 0) {
    end_norm += dim_size;
  }

  // Clamp upper and lower bound to valid indices; an empty or inverted
  // interval collapses to length 0.
  start_norm = std::min(std::max(start_norm, static_cast<int64_t>(0)), dim_size);
  end_norm = std::min(std::max(end_norm, start_norm), dim_size);

  // Final size is ceil(interval / step), computed in integers and written so
  // that a very large step cannot overflow.
  const int64_t interval = end_norm - start_norm;
  sizes[dim_norm] = interval == 0 ? 0 : 1 + (interval - 1) / step;

  return {Shape(self.scalar_type(), sizes)};
}
// Shape inference for softmax: the result has the same sizes as `self`. Its
// dtype is `dtype` when one is supplied and the dtype of `self` otherwise.
// `dim` does not affect the shape.
std::vector<Shape> compute_shape_softmax(
    const at::Tensor& self, int64_t dim, c10::optional<at::ScalarType> dtype) {
  return {Shape(dtype.value_or(self.scalar_type()), self.sizes().vec())};
}
compute_shape_transpose(const at::Tensor& self, int64_t dim0, int64_t dim1) {
auto original_shape = self.sizes().vec();
std::vector<int64_t> sizes{original_shape.begin(), original_shape.end()};
// Index may be negative, so we must normalize it. We create new variables
// instead of replacing the existing ones so that in the case of an error,
// the original values can be printed out.
int64_t dim0_norm = normalize_index(dim0, sizes.size());
int64_t dim1_norm = normalize_index(dim1, sizes.size());
// Verify dimensions are valid.
0 <= dim0_norm && dim0_norm < (int64_t)sizes.size(), "dim0 has value ",
dim0, ", but there are only ", sizes.size(), " tensor dimensions");
0 <= dim1_norm && dim1_norm < (int64_t)sizes.size(), "dim1 has value ",
dim1, ", but there are only ", sizes.size(), " tensor dimensions");
// Swap shapes at dimensions.
std::swap(sizes[dim0_norm], sizes[dim1_norm]);
return {Shape(self.scalar_type(), sizes)};
} // namespace lazy
} // namespace torch
Reference in New Issue