[TBE] Add a test module for table batch embedding

This commit adds a test module specifically for the table batch embedding
algorithm. The test case is written in reference to the FBGEMM table
batched embedding benchmark:
https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py#L270

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
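For context, table batched embedding evaluates several embedding tables in a single batched call: each table sum-pools fixed-size bags of indices, and the per-table results are concatenated along the feature dimension. Below is a minimal standalone sketch of that computation, using the same table/bag sizes as the test; it is illustrative only and not part of the commit:

import torch

# Two tables of 10 embeddings (dim 4); a batch of 4 bags of 3 indices each.
tables = [torch.nn.EmbeddingBag(10, 4, mode="sum") for _ in range(2)]
indices = torch.randint(0, 10, (2, 4 * 3))
offsets = torch.tensor([0, 3, 6, 9])  # start offset of each bag

# Each table yields a (4, 4) pooled tensor; concatenation gives (4, 8).
out = torch.cat([t(indices[i], offsets) for i, t in enumerate(tables)], dim=1)
print(out.shape)  # torch.Size([4, 8])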
Gaurav Shukla 2022-01-21 22:51:11 +05:30
parent eb06d21765
commit 13b9fd62c6
3 changed files with 63 additions and 0 deletions


@@ -49,6 +49,7 @@ from . import arange
from . import constant_alloc
from . import threshold
from . import histogram_binning_calibration
from . import table_batch_embedding

def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']


@@ -0,0 +1,61 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================
# Reference: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py#L270

# Global parameters.
NUM_TABLES = 2
NUM_EMBEDDINGS = 10
EMBEDDING_DIM = 4
BATCH_SIZE = 4
BAG_SIZE = 3


class TableBatchEmbeddingModule(torch.nn.Module):
    def __init__(self):
        super(TableBatchEmbeddingModule, self).__init__()
        self.num_tables = NUM_TABLES
        self.num_embeddings = NUM_EMBEDDINGS
        self.embedding_dim = EMBEDDING_DIM
        self.batch_size = BATCH_SIZE
        self.bag_size = BAG_SIZE
        # Currently, pooling_mode is fixed to 'sum'.
        self.nn_embedding_list = torch.nn.ModuleList([
            torch.nn.EmbeddingBag(
                self.num_embeddings, self.embedding_dim, mode="sum", sparse=False)
            for _ in range(self.num_tables)
        ])

    @export
    @annotate_args([
        None,
        ([-1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, indices, offsets):
        # Reshape the flat indices into one (batch_size, bag_size) block of
        # indices per table.
        indices_list = indices.view(self.num_tables, self.batch_size, self.bag_size)
        final_output = torch.tensor([])
        for i, nn_embedding in enumerate(self.nn_embedding_list):
            # Sum-pool the bags of table `i`, then concatenate the per-table
            # results along the embedding dimension.
            table_indices = indices_list[i].view(-1)
            output = nn_embedding(table_indices, offsets).view(self.batch_size, -1)
            final_output = torch.cat((final_output, output), dim=1)
        return final_output


@register_test_case(module_factory=lambda: TableBatchEmbeddingModule())
def TableBatchEmbeddingModule_basic(module, tu: TestUtils):
    indices = torch.randint(0, NUM_EMBEDDINGS, (NUM_TABLES * BATCH_SIZE * BAG_SIZE,))
    # Offsets mark the start of each bag within a table's flattened indices:
    # [0, BAG_SIZE, 2 * BAG_SIZE, ...], one entry per bag in the batch.
    offsets = torch.cumsum(
        torch.tensor([0] + [BAG_SIZE for _ in range(BATCH_SIZE - 1)],
                     dtype=torch.int64), 0)
    module.forward(indices, offsets)
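For reference, outside the e2e test harness the module can be exercised directly; with the global parameters above the concatenated output has shape (BATCH_SIZE, NUM_TABLES * EMBEDDING_DIM), i.e. (4, 8). A minimal usage sketch, illustrative and not part of the commit:

module = TableBatchEmbeddingModule()
indices = torch.randint(0, NUM_EMBEDDINGS, (NUM_TABLES * BATCH_SIZE * BAG_SIZE,))
offsets = torch.arange(0, BATCH_SIZE * BAG_SIZE, BAG_SIZE)  # [0, 3, 6, 9]
print(module(indices, offsets).shape)  # torch.Size([4, 8])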


@@ -16,6 +16,7 @@
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"QuantizedMLP_basic",
"IouOfModule_basic",
"TableBatchEmbeddingModule_basic",
}
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS