# torch-mlir/e2e_testing/torchscript/table_batch_embedding.py

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================

# Reference: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py#L270

# Global parameters.
NUM_TABLES = 2
NUM_EMBEDDINGS = 10
EMBEDDING_DIM = 4
BATCH_SIZE = 4
BAG_SIZE = 3
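
# The test below drives the module with NUM_TABLES * BATCH_SIZE * BAG_SIZE
# (= 24) indices in total: BATCH_SIZE bags of BAG_SIZE indices per table.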

class TableBatchEmbeddingModule(torch.nn.Module):
    def __init__(self):
        super(TableBatchEmbeddingModule, self).__init__()
        self.num_tables = NUM_TABLES
        self.num_embeddings = NUM_EMBEDDINGS
        self.embedding_dim = EMBEDDING_DIM
        self.batch_size = BATCH_SIZE
        self.bag_size = BAG_SIZE
        # Currently, pooling_mode is fixed to 'sum'.
        self.nn_embedding_list = torch.nn.ModuleList([
            torch.nn.EmbeddingBag(
                self.num_embeddings, self.embedding_dim, mode="sum", sparse=False)
            for i in range(self.num_tables)
        ])
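
    # `indices` is the flat (NUM_TABLES * BATCH_SIZE * BAG_SIZE,) id tensor and
    # `offsets` (length BATCH_SIZE) marks where each bag starts within a single
    # table's slice of `indices`; the same offsets are reused for every table.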
    @export
    @annotate_args([
        None,
        ([-1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, indices, offsets):
        # Split the flat indices into one (BATCH_SIZE, BAG_SIZE) block per table.
        indices_list = indices.view(self.num_tables, self.batch_size, self.bag_size)
        final_output = torch.tensor([])
        for i, nn_embedding in enumerate(self.nn_embedding_list):
            # Pool each table's slice with the shared offsets, then concatenate
            # the per-table results along the feature dimension.
            indices = indices_list[i].view(-1)
            output = nn_embedding(indices, offsets).view(self.batch_size, -1)
            final_output = torch.cat((final_output, output), dim=1)
        return final_output

@register_test_case(module_factory=lambda: TableBatchEmbeddingModule())
def TableBatchEmbeddingModule_basic(module, tu: TestUtils):
    indices = torch.randint(0, NUM_EMBEDDINGS, (NUM_TABLES * BATCH_SIZE * BAG_SIZE,))
    # Bag start offsets for one table: [0, BAG_SIZE, ..., (BATCH_SIZE-1)*BAG_SIZE].
    offsets = torch.cumsum(
        torch.tensor([0] + [BAG_SIZE for _ in range(BATCH_SIZE - 1)], dtype=torch.int64), 0)
    module.forward(indices, offsets)
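
# A minimal eager-mode sanity check of the shapes above; this guarded block is
# an illustrative addition, not part of the e2e harness, which invokes
# `forward` via `register_test_case` instead.
if __name__ == "__main__":
    module = TableBatchEmbeddingModule()
    indices = torch.randint(0, NUM_EMBEDDINGS, (NUM_TABLES * BATCH_SIZE * BAG_SIZE,))
    offsets = torch.cumsum(
        torch.tensor([0] + [BAG_SIZE] * (BATCH_SIZE - 1), dtype=torch.int64), 0)
    out = module(indices, offsets)
    # One row per batch element; per-table outputs are concatenated along dim=1.
    assert out.shape == (BATCH_SIZE, NUM_TABLES * EMBEDDING_DIM)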