From de5b380143ddce097e235a56acc941063aeeacdb Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Tue, 10 May 2022 09:03:41 -0400 Subject: [PATCH] Bert example and relevant shape inference functions (#831) --- examples/ltc_backend_bert.py | 135 ++++++++++ .../base_lazy_backend/LazyShapeInference.cpp | 241 ++++++++++++++++++ 2 files changed, 376 insertions(+) create mode 100644 examples/ltc_backend_bert.py diff --git a/examples/ltc_backend_bert.py b/examples/ltc_backend_bert.py new file mode 100644 index 000000000..9278b1105 --- /dev/null +++ b/examples/ltc_backend_bert.py @@ -0,0 +1,135 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. +""" +Runs a training of the Bert model using the Lazy Tensor Core with the +example Torch MLIR backend. + +Most of the code in this example was copied from the wonderful tutorial + https://huggingface.co/transformers/training.html#fine-tuning-in-native-pytorch + +Based on LTC code samples by ramiro050 + https://github.com/ramiro050/lazy-tensor-samples +""" + +import argparse +import torch +from datasets import load_dataset +from datasets.dataset_dict import DatasetDict +from torch.utils.data import DataLoader +from transformers import BertForSequenceClassification, \ + BertTokenizer, AdamW, get_scheduler +from typing import List + + +def tokenize_dataset(dataset: DatasetDict) -> DatasetDict: + tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + + def tokenize_function(examples): + return tokenizer(examples["text"], padding="max_length", + truncation=True) + + tokenized_datasets = dataset.map(tokenize_function, batched=True) + tokenized_datasets = tokenized_datasets.remove_columns(['text']) + tokenized_datasets = tokenized_datasets.rename_column('label', 'labels') + tokenized_datasets.set_format('torch') + + return tokenized_datasets + + +def train(model: BertForSequenceClassification, + num_epochs: int, + num_training_steps: int, + train_dataloader: DataLoader, + device: torch.device, + do_mark_step: bool) -> List[torch.Tensor]: + optimizer = AdamW(model.parameters(), lr=5e-5) + lr_scheduler = get_scheduler('linear', optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=num_training_steps) + + model.train() + losses = [] + for _ in range(num_epochs): + for batch in train_dataloader: + batch = {k: v.to(device) for k, v in batch.items()} + outputs = model(**batch) + loss = outputs.loss + loss.backward() + losses.append(loss) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if do_mark_step and 'lazy' in str(model.device): + print("Calling Mark Step") + torch._lazy.mark_step() + + return losses + + +def main(device, lower_only): + if device in ("TS", "MLIR_EXAMPLE"): + import torch._lazy + + if device == "TS": + import torch._lazy.ts_backend + + torch._lazy.ts_backend.init() + + elif device == "MLIR_EXAMPLE": + import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend + + ltc_backend._initialize() + + device = "lazy" + print("Initialized backend") + else: + device = device.lower() + + tokenized_datasets = tokenize_dataset(load_dataset('imdb')) + small_train_dataset = tokenized_datasets['train'].shuffle(seed=42) \ + .select(range(2)) + + train_dataloader = DataLoader(small_train_dataset, shuffle=True, + batch_size=8) + model = BertForSequenceClassification.from_pretrained('bert-base-cased', + num_labels=2) + model.to(device) + + num_epochs = 3 + num_training_steps = num_epochs * len(train_dataloader) + losses = train(model, num_epochs, + num_training_steps, train_dataloader, device, not lower_only) + + if lower_only: + print('\nJIT Graph:') + import torch._C + graph_str = torch._C._lazy._get_tensors_backend([losses[0]]) + print(graph_str) + else: + # Execute computation + print('Loss: ', losses) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-d", + "--device", + type=str.upper, + choices=["CPU", "TS", "MLIR_EXAMPLE"], + default="MLIR_EXAMPLE", + help="The device type", + ) + parser.add_argument( + "-l", + "--lower_only", + action='store_true', + default=False, + help="Only get backend printout -- do not execute computation", + ) + args = parser.parse_args() + main(args.device, args.lower_only) diff --git a/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.cpp b/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.cpp index f36692297..d57fca159 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.cpp @@ -9,10 +9,127 @@ #include "LazyShapeInference.h" #include "../utils/exception.h" +#include namespace torch { namespace lazy { +// TODO(henrytu): Upstream these shape inference functions to PyTorch in the future. + +// Turns any negative index positive (assuming it's valid) +int64_t normalize_index(int64_t index, unsigned dims) { + return index < 0 ? (int64_t)dims + index : index; +} + +std::vector +compute_shape_dropout(const at::Tensor& input, double p, bool train) { + return {Shape(input.scalar_type(), input.sizes().vec())}; +} + +std::vector compute_shape_layer_norm( + const at::Tensor& input, at::IntArrayRef normalized_shape, + const c10::optional& weight, + const c10::optional& bias, double eps, bool cudnn_enable) { + return {Shape(input.scalar_type(), input.sizes().vec())}; +} + +std::vector +compute_shape_matmul(const at::Tensor& self, const at::Tensor& other) { + std::vector sizes; + + auto self_sizes = self.sizes().vec(); + auto other_sizes = other.sizes().vec(); + + // For tensors with dimensions >2, the leading dimensions are for batch info. + // The last 2 (or 1 in the case of a single dim tensor) dimensions are the + // matrix dimensions themselves, which is checked to ensure the matmul op + // is legal. + // + // Example: + // [1, 2, 3, 4] -> [1, 2] batch dims and [3, 4] matrix + // [1, 4, 5] -> [1] batch dims and [4, 5] matrix + // [4, 5] -> [] batch dims and [4, 5] matrix + // [5] -> [] batch dims and [5] matrix + // + // We'll start by splitting the shapes as described above. + auto partition_shape = [](at::ArrayRef sizes) { + if (sizes.size() <= 2) { + return std::make_pair( + std::vector(), + std::vector(sizes.begin(), sizes.end())); + } else { + std::size_t partition_idx = sizes.size() - 2; + return std::make_pair( + std::vector(sizes.begin(), sizes.begin() + partition_idx), + std::vector(sizes.begin() + partition_idx, sizes.end())); + } + }; + auto [self_batch_sizes, self_matrix_sizes] = partition_shape(self_sizes); + auto [other_batch_sizes, other_matrix_sizes] = partition_shape(other_sizes); + + // Insert batch dimensions. + // The final list of sizes will be based on the tensor w/ more dims. + // Individual dimension sizes are "right justified" as we iterate thru + // to pick the larger dimension between them. + // 0 1 1 3 4 + // 5 1 2 + // --------- + // 0 1 5 3 4 <- Result + int64_t self_size, other_size; + std::size_t num_batch_dim = + std::max(self_batch_sizes.size(), other_batch_sizes.size()); + auto get_batch_dim = [&](std::vector batch_sizes, std::size_t dim) { + long idx = dim - num_batch_dim + batch_sizes.size(); + // Negative index means out of bounds, which defaults to a dim size of 1. + return idx < 0 ? 1 : batch_sizes[idx]; + }; + for (std::size_t i = 0; i < num_batch_dim; i++) { + self_size = get_batch_dim(self_batch_sizes, i); + other_size = get_batch_dim(other_batch_sizes, i); + + TORCH_CHECK( + self_size == 1 || other_size == 1 || self_size == other_size, + "At trailing dimension ", i, ", expected for dimensions ", + "to either match or have one of them equal one, but got ", self_size, + " and ", other_size, " instead!"); + + sizes.push_back(std::max(self_size, other_size)); + } + + // Keep track of the inner dimensions of matmul to validate op is valid. + std::pair inner_sizes; + if (self_matrix_sizes.size() == 1 && other_matrix_sizes.size() == 1) { + // Dot-Product -- scalar output, so no dimensions inserted + inner_sizes = std::make_pair(self_matrix_sizes[0], other_matrix_sizes[0]); + } else if (self_matrix_sizes.size() == 1 && other_matrix_sizes.size() == 2) { + // Vector-Matrix product (m) @ (m, n) -> (n) + inner_sizes = std::make_pair(self_matrix_sizes[0], other_matrix_sizes[0]); + + sizes.push_back(other_matrix_sizes[1]); + } else if (self_matrix_sizes.size() == 2 && other_matrix_sizes.size() == 1) { + // Matrix-Vector product (m, n) @ (n) -> (m) + inner_sizes = std::make_pair(self_matrix_sizes[1], other_matrix_sizes[0]); + + sizes.push_back(self_matrix_sizes[0]); + } else if (self_matrix_sizes.size() == 2 && other_matrix_sizes.size() == 2) { + // Matrix-Matrix product (m, n) @ (n, o) -> (m, o) + inner_sizes = std::make_pair(self_matrix_sizes[1], other_matrix_sizes[0]); + + sizes.push_back(self_matrix_sizes[0]); + sizes.push_back(other_matrix_sizes[1]); + } else { + // By this time, self_matrix_sizes and other_matrix_sizes should have at + // most 2 dims, so if this is executed something has gone wrong... + TORCH_CHECK(false, "Invalid matmul shape combination!"); + } + + TORCH_CHECK( + inner_sizes.first == inner_sizes.second, "Inner dimension of matrix (", + inner_sizes.first, ") does not ", "match (", inner_sizes.second, ")!"); + + return {Shape(self.scalar_type(), sizes)}; +} + std::vector compute_shape_native_batch_norm( const at::Tensor& input, const c10::optional& weight, const c10::optional& bias, @@ -33,5 +150,129 @@ std::vector compute_shape_native_batch_norm( return shapes; } +std::vector +compute_shape_reshape(const at::Tensor& self, at::IntArrayRef shape) { + // Make a copy of the desired output shape. + std::vector sizes(shape.begin(), shape.end()); + + // Product of all sizes in input shape is the number of entries in tensor. + int64_t num_entries = 1; + for (int64_t i : self.sizes().vec()) { + num_entries *= i; + } + + // Validate the number of entries in the desired shape. If there is a wildcard + // dimension, we need to find it now in order to populate it. + long wildcard_idx = -1; + int64_t num_concrete_entries = 1; + for (std::size_t idx = 0; idx < sizes.size(); idx++) { + if (sizes[idx] != -1) { + num_concrete_entries *= sizes[idx]; + } else { + TORCH_CHECK(wildcard_idx == -1, "only one dimension can be inferred"); + wildcard_idx = idx; + } + } + + if (wildcard_idx == -1) { + // No wildcard, the shape should already be known. + TORCH_CHECK( + num_entries == num_concrete_entries, "shape `[", sizes, + "]` is invalid for input of size ", num_concrete_entries); + } else { + // There is one dimension which is not explicitly declared -- we need to + // infer. + TORCH_CHECK( + num_entries % num_concrete_entries == 0, "shape `[", sizes, + "]` is invalid for input of size ", num_concrete_entries); + + sizes[wildcard_idx] = num_entries / num_concrete_entries; + } + + return {Shape(self.scalar_type(), sizes)}; +} + +std::vector compute_shape_rsub( + const at::Tensor& self, const at::Scalar& other, const at::Scalar& alpha) { + // Since other is scalar, the result will match tensor shape. + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector +compute_shape_select(const at::Tensor& self, int64_t dim, int64_t index) { + auto original_shape = self.sizes().vec(); + std::vector sizes(original_shape.begin(), original_shape.end()); + + TORCH_CHECK( + dim < (int64_t)sizes.size(), "Dimension ", dim, + " is out of bounds for tensor with ", sizes.size(), " dimensions!"); + TORCH_CHECK( + index < sizes[dim], "Index ", index, + " is out of bounds for dimension of size ", sizes[dim]); + sizes.erase(sizes.begin() + dim); + + return {Shape(self.scalar_type(), sizes)}; +} + +std::vector compute_shape_slice( + const at::Tensor& self, int64_t dim, c10::optional start, + c10::optional end, int64_t step) { + auto original_shape = self.sizes().vec(); + std::vector sizes(original_shape.begin(), original_shape.end()); + + int64_t dim_size = sizes[dim]; + + // Index may be negative, so we must normalize it. + int64_t start_norm = normalize_index(start.value(), dim_size); + int64_t end_norm = normalize_index(end.value(), dim_size); + + if (start_norm >= end_norm || start_norm >= dim_size || end_norm <= 0) { + // Slice is out of bounds, nothing in range. + sizes[dim] = 0; + } else { + // Clamp upper and lower bound to valid indices. + start_norm = std::max((int64_t)0, start_norm); + end_norm = std::min(dim_size, end_norm); + + // Final size is determined by step and interval size. + sizes[dim] = std::ceil((double)(end_norm - start_norm) / (double)step); + } + + return {Shape(self.scalar_type(), sizes)}; +} + +std::vector compute_shape_softmax( + const at::Tensor& self, int64_t dim, c10::optional dtype) { + if (dtype.has_value()) { + return {Shape(dtype.value(), self.sizes().vec())}; + } + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector +compute_shape_transpose(const at::Tensor& self, int64_t dim0, int64_t dim1) { + auto original_shape = self.sizes().vec(); + std::vector sizes{original_shape.begin(), original_shape.end()}; + + // Index may be negative, so we must normalize it. We create new variables + // instead of replacing the existing ones so that in the case of an error, + // the original values can be printed out. + int64_t dim0_norm = normalize_index(dim0, sizes.size()); + int64_t dim1_norm = normalize_index(dim1, sizes.size()); + + // Verify dimensions are valid. + TORCH_CHECK( + 0 <= dim0_norm && dim0_norm < (int64_t)sizes.size(), "dim0 has value ", + dim0, ", but there are only ", sizes.size(), " tensor dimensions"); + TORCH_CHECK( + 0 <= dim1_norm && dim1_norm < (int64_t)sizes.size(), "dim1 has value ", + dim1, ", but there are only ", sizes.size(), " tensor dimensions"); + + // Swap shapes at dimensions. + std::swap(sizes[dim0_norm], sizes[dim1_norm]); + + return {Shape(self.scalar_type(), sizes)}; +} + } // namespace lazy } // namespace torch