From 13b9fd62c61ac32f5599b283035ad1685a7a5074 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Fri, 21 Jan 2022 22:51:11 +0530 Subject: [PATCH] [TBE] Add a test module for table batch embedding This commit adds a test module specifically for table batch embedding algorithm. This test case is in reference to the FBGEMM table batch embedding: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py#L270 Signed-Off-by: Gaurav Shukla --- e2e_testing/torchscript/main.py | 1 + .../torchscript/table_batch_embedding.py | 61 +++++++++++++++++++ e2e_testing/torchscript/xfail_sets.py | 1 + 3 files changed, 63 insertions(+) create mode 100644 e2e_testing/torchscript/table_batch_embedding.py diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py index f08572de3..4cabc0bb1 100644 --- a/e2e_testing/torchscript/main.py +++ b/e2e_testing/torchscript/main.py @@ -49,6 +49,7 @@ from . import arange from . import constant_alloc from . import threshold from . import histogram_binning_calibration +from . import table_batch_embedding def _get_argparse(): config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external'] diff --git a/e2e_testing/torchscript/table_batch_embedding.py b/e2e_testing/torchscript/table_batch_embedding.py new file mode 100644 index 000000000..1f74c9dc8 --- /dev/null +++ b/e2e_testing/torchscript/table_batch_embedding.py @@ -0,0 +1,61 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import torch + +from torch_mlir_e2e_test.torchscript.framework import TestUtils +from torch_mlir_e2e_test.torchscript.registry import register_test_case +from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export + +# ============================================================================== + +# Reference: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py#L270 + +# Global parameters. +NUM_TABLES = 2 +NUM_EMBEDDINGS = 10 +EMBEDDING_DIM = 4 +BATCH_SIZE = 4 +BAG_SIZE = 3 + + +class TableBatchEmbeddingModule(torch.nn.Module): + def __init__(self): + super(TableBatchEmbeddingModule, self).__init__() + self.num_tables = NUM_TABLES + self.num_embeddings = NUM_EMBEDDINGS + self.embedding_dim = EMBEDDING_DIM + self.batch_size = BATCH_SIZE + self.bag_size = BAG_SIZE + # Currently, pooling_mode is fixed to 'sum'. + self.nn_embedding_list = torch.nn.ModuleList([ + torch.nn.EmbeddingBag( + self.num_embeddings, self.embedding_dim, mode="sum", sparse=False) + for i in range(self.num_tables) + ]) + + @export + @annotate_args([ + None, + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, indices, offsets): + indices_list = indices.view(self.num_tables, self.batch_size, self.bag_size) + final_output = torch.tensor([]) + for i, nn_embedding in enumerate(self.nn_embedding_list): + indices = indices_list[i].view(-1) + output = nn_embedding(indices, offsets).view(self.batch_size, -1) + final_output = torch.cat((final_output, output), dim=1) + return final_output + + +@register_test_case(module_factory=lambda: TableBatchEmbeddingModule()) +def TableBatchEmbeddingModule_basic(module, tu: TestUtils): + indices = torch.randint(0, NUM_EMBEDDINGS, (NUM_TABLES * BATCH_SIZE * BAG_SIZE,)) + offsets = torch.cumsum( + torch.tensor([0] + [BAG_SIZE for _ in range(BATCH_SIZE - 1)], dtype=torch.int64), 0) + module.forward(indices, offsets) + diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 07f2dda1d..21dba05d4 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -16,6 +16,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = { "QuantizedMLP_basic", "IouOfModule_basic", + "TableBatchEmbeddingModule_basic", } REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS