mirror of https://github.com/llvm/torch-mlir
236 lines
6.5 KiB
Python
236 lines
6.5 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
import torch
|
|
|
|
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
|
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
|
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|
|
|
# ==============================================================================
|
|
|
|
class SliceModule(torch.nn.Module):
    """Slice a rank-3 tensor with explicit start, end and unit step per dim.

    `forward` returns `x[0:5:1, 1:3:1, 2:4:1]`. The registered test case
    calls it with the tensor built by `tu.rand(6, 4, 7)`, for which every
    bound is within range.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # Keep the explicit `start:end:step` spelling: the slice form itself
        # is what this test case exercises.
        return x[0:5:1, 1:3:1, 2:4:1]
|
|
|
|
|
|
@register_test_case(module_factory=lambda: SliceModule())
def SliceModule_basic(module, tu: TestUtils):
    """Call `SliceModule.forward` with the tensor from `tu.rand(6, 4, 7)`."""
    module.forward(tu.rand(6, 4, 7))
|
|
|
|
|
|
|
|
# ==============================================================================
|
|
|
|
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
    """Slice with indices that exceed the upper bound of the input dims.

    `forward` evaluates `x[:8, :5, 8:]`. For the (6, 4, 7) input used by the
    registered test case, the end indices 8 and 5 are past the ends of dims
    0 and 1, and the start index 8 is past the end of dim 2, so the slice is
    empty along dim 2. A ones tensor of shape (6, 4, 1) is concatenated
    along dim 2 so the returned tensor has no zero-size dimension.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # TODO: remove hacky cat tensor once refbackend supports 0 size dim
        result = x[:8, :5, 8:]
        cat_tensor = torch.ones((6, 4, 1), dtype=torch.float32)
        return torch.cat((result, cat_tensor), dim=2)
|
|
|
|
|
|
@register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexModule())
def SliceOutOfUpperBoundIndexModule_basic(module, tu: TestUtils):
    """Call the module's `forward` with the tensor from `tu.rand(6, 4, 7)`."""
    module.forward(tu.rand(6, 4, 7))
|
|
|
|
# ==============================================================================
|
|
|
|
class SliceOutOfLowerBoundEndIndexModule(torch.nn.Module):
    """Slice with a negative end index that falls below the lower bound.

    `forward` returns `x[:-8, -7:, :]`. For the (6, 4, 7) input used by the
    registered test case, the end index -8 on dim 0 and the start index -7
    on dim 1 both have a magnitude larger than the dimension size.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # The out-of-range negative indices are deliberate; they are the
        # behavior under test.
        return x[:-8, -7:, :]
|
|
|
|
|
|
@register_test_case(module_factory=lambda: SliceOutOfLowerBoundEndIndexModule())
def SliceOutOfLowerBoundEndIndexModule_basic(module, tu: TestUtils):
    """Call the module's `forward` with the tensor from `tu.rand(6, 4, 7)`."""
    module.forward(tu.rand(6, 4, 7))
|
|
|
|
# ==============================================================================
|
|
|
|
class SliceOutOfLowerBoundStartIndexModule(torch.nn.Module):
    """Slice with a negative start index that falls below the lower bound.

    `forward` returns `x[-8:3:1, 1:3:1, 2:4:1]`. For the (6, 4, 7) input
    used by the registered test case, the start index -8 on dim 0 has a
    magnitude larger than the dimension size.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # The out-of-range start index on dim 0 is deliberate; it is the
        # behavior under test.
        return x[-8:3:1, 1:3:1, 2:4:1]
|
|
|
|
|
|
@register_test_case(module_factory=lambda: SliceOutOfLowerBoundStartIndexModule())
def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils):
    """Call the module's `forward` with the tensor from `tu.rand(6, 4, 7)`."""
    module.forward(tu.rand(6, 4, 7))
|
|
|
|
# ==============================================================================
|
|
|
|
|
|
class SliceEndSleStartModule(torch.nn.Module):
    """Slice whose end index is smaller than its start index.

    `forward` evaluates `x[:, 4:3, :]`, which is empty along dim 1 because
    the end (3) precedes the start (4). A ones tensor of shape (6, 1, 7) is
    concatenated along dim 1 so the returned tensor has no zero-size
    dimension; that shape matches the (6, 4, 7) input used by the
    registered test case.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # TODO: remove hacky cat tensor once refbackend supports 0 size dim
        result = x[:, 4:3, :]
        cat_tensor = torch.ones((6, 1, 7), dtype=torch.float32)
        return torch.cat((result, cat_tensor), dim=1)
|
|
|
|
|
|
@register_test_case(module_factory=lambda: SliceEndSleStartModule())
def SliceEndSleStartModule_basic(module, tu: TestUtils):
    """Call the module's `forward` with the tensor from `tu.rand(6, 4, 7)`."""
    module.forward(tu.rand(6, 4, 7))
|
|
|
|
# ==============================================================================
|
|
|
|
|
|
class SliceStartEqEndModule(torch.nn.Module):
    """Slice whose start index equals its end index.

    `forward` evaluates `x[5:5, :, :]`, which is empty along dim 0 because
    start and end coincide. A ones tensor of shape (1, 4, 7) is concatenated
    along dim 0 so the returned tensor has no zero-size dimension; that
    shape matches the (6, 4, 7) input used by the registered test case.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # TODO: remove hacky cat tensor once refbackend supports 0 size dim
        result = x[5:5, :, :]
        cat_tensor = torch.ones((1, 4, 7), dtype=torch.float32)
        return torch.cat((result, cat_tensor), dim=0)
|
|
|
|
|
|
@register_test_case(module_factory=lambda: SliceStartEqEndModule())
def SliceStartEqEndModule_basic(module, tu: TestUtils):
    """Call the module's `forward` with the tensor from `tu.rand(6, 4, 7)`."""
    module.forward(tu.rand(6, 4, 7))
|
|
|
|
# ==============================================================================
|
|
|
|
class SliceSizeTwoStepModule(torch.nn.Module):
    """Slice a rank-3 tensor using a step of two on every dimension.

    `forward` returns `x[0:5:2, 0:3:2, 0:4:2]`. The registered test case
    calls it with the tensor built by `tu.rand(10, 5, 17)`.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # The non-unit step is the behavior under test.
        return x[0:5:2, 0:3:2, 0:4:2]
|
|
|
|
|
|
@register_test_case(module_factory=lambda: SliceSizeTwoStepModule())
def SliceSizeTwoStepModule_basic(module, tu: TestUtils):
    """Call the module's `forward` with the tensor from `tu.rand(10, 5, 17)`."""
    module.forward(tu.rand(10, 5, 17))
|
|
|
|
# ==============================================================================
|
|
|
|
class SliceNegIdxModule(torch.nn.Module):
    """Slice a rank-2 tensor using negative start and end indices.

    `forward` returns `x[:-1, -2:-1]`. The registered test case calls it
    with the tensor built by `tu.rand(3, 9)`.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # Negative indices are the behavior under test.
        return x[:-1, -2:-1]
|
|
|
|
|
|
@register_test_case(module_factory=lambda: SliceNegIdxModule())
def SliceNegIdxModule_basic(module, tu: TestUtils):
    """Call the module's `forward` with the tensor from `tu.rand(3, 9)`."""
    module.forward(tu.rand(3, 9))
|
|
|
|
# ==============================================================================
|
|
|
|
class SliceSingleIdxModule(torch.nn.Module):
    """Index a rank-2 tensor with a single integer.

    `forward` returns `x[0]`. The registered test case calls it with the
    tensor built by `tu.rand(6, 8)`.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # Plain integer subscript (not a slice) is the behavior under test.
        return x[0]
|
|
|
|
|
|
@register_test_case(module_factory=lambda: SliceSingleIdxModule())
def SliceSingleIdxModule_basic(module, tu: TestUtils):
    """Call the module's `forward` with the tensor from `tu.rand(6, 8)`."""
    module.forward(tu.rand(6, 8))
|
|
|
|
# ==============================================================================
|
|
|
|
class SliceWholeTensorModule(torch.nn.Module):
    """Slice a rank-2 tensor with a full `:` range on both dimensions.

    `forward` returns `x[:, :]`. The registered test case calls it with the
    tensor built by `tu.rand(6, 8)`.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        # Both subscripts omit start and end, selecting every element.
        return x[:, :]
|
|
|
|
|
|
@register_test_case(module_factory=lambda: SliceWholeTensorModule())
def SliceWholeTensorModule_basic(module, tu: TestUtils):
    """Call the module's `forward` with the tensor from `tu.rand(6, 8)`."""
    module.forward(tu.rand(6, 8))
|
|
|
|
# ==============================================================================
|
|
|
|
class SelectIntModule(torch.nn.Module):
    """Select index 0 along dim 0 of a rank-2 int64 tensor.

    `forward` returns `x.select(0, 0)`. The registered test case calls it
    with `torch.randint(10, (5, 5))`.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, x):
        # Explicit `select(dim, index)` call rather than subscript syntax.
        return x.select(0, 0)
|
|
|
|
|
|
@register_test_case(module_factory=lambda: SelectIntModule())
def SelectIntModule_basic(module, tu: TestUtils):
    """Call the module's `forward` with `torch.randint(10, (5, 5))`.

    The input is built with `torch.randint` directly; `tu` is not used here.
    """
    module.forward(torch.randint(10, (5, 5)))
|
|
|
|
# ==============================================================================
|
|
|