# Mirror of https://github.com/llvm/torch-mlir (Python source, 251 lines, 5.7 KiB).
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
import torch
|
|
|
|
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
|
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
|
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|
|
|
# ==============================================================================
|
|
|
|
|
|
class ArangeIntModule(torch.nn.Module):
    """Exercise ``torch.arange(end)`` with a single integer ``end``.

    Expected eager result: ``tensor([0, 1, 2, 3, 4])``. The dtype is int64
    because every argument is an integer (per ``torch.arange`` dtype inference).
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(5)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeIntModule())
def ArangeIntModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|
|
|
|
|
|
class ArangeFloatModule(torch.nn.Module):
    """Exercise ``torch.arange(end)`` with a floating-point ``end``.

    Expected eager result: ``[0., 1., 2., 3., 4.]``. Because ``end`` is a
    float, the result takes the default floating dtype (float32 unless
    ``torch.set_default_dtype`` changed it).
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(5.0)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeFloatModule())
def ArangeFloatModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|
|
|
|
|
|
class ArangeZeroElementOutputModule(torch.nn.Module):
    """Exercise ``torch.arange(0)``, whose result has zero elements.

    Expected eager result: an empty 1-D int64 tensor of shape ``(0,)``. This
    covers the edge case where the computed length is zero.
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(0)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeZeroElementOutputModule())
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|
|
|
|
|
|
class ArangeStartIntModule(torch.nn.Module):
    """Exercise the ``torch.arange(start, end)`` form with integers.

    Expected eager result: ``tensor([0, 1, 2, 3, 4])`` in int64.
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(0, 5)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeStartIntModule())
def ArangeStartIntModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|
|
|
|
|
|
class ArangeStartFloatModule(torch.nn.Module):
    """Exercise the ``torch.arange(start, end)`` form with floats.

    Expected eager result: ``[0., 1., 2., 3., 4.]`` in the default floating
    dtype.
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(0.0, 5.0)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeStartFloatModule())
def ArangeStartFloatModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|
|
|
|
|
|
class ArangeNegativeStartIntModule(torch.nn.Module):
    """Exercise ``torch.arange(start, end)`` with a negative integer ``start``.

    Expected eager result: the 15 int64 values ``-10, -9, ..., 4``.
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(-10, 5)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeNegativeStartIntModule())
def ArangeNegativeStartIntModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|
|
|
|
|
|
class ArangeNegativeStartFloatModule(torch.nn.Module):
    """Exercise ``torch.arange(start, end)`` with a negative fractional ``start``.

    Length is ``ceil((5.7 - (-1.4)) / 1) = 8``, so the expected eager result
    is roughly ``-1.4, -0.4, ..., 5.6`` in the default floating dtype.
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(-1.4, 5.7)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeNegativeStartFloatModule())
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|
|
|
|
|
|
class ArangeStartStepIntModule(torch.nn.Module):
    """Exercise the three-argument ``torch.arange(start, end, step)`` with integers.

    Expected eager result: ``tensor([0, 1, 2, 3, 4])`` in int64. The explicit
    unit step makes this call use the start/end/step form rather than
    start/end.
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(0, 5, 1)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeStartStepIntModule())
def ArangeStartStepIntModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|
|
|
|
|
|
class ArangeStartStepFloatModule(torch.nn.Module):
    """Exercise ``torch.arange(start, end, step)`` with a fractional ``step``.

    Integer bounds with a float step: length ``ceil(6 / 1.3) = 5``. The
    expected eager result is about ``[-1.0, 0.3, 1.6, 2.9, 4.2]`` in the
    default floating dtype.
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(-1, 5, 1.3)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeStartStepFloatModule())
def ArangeStartStepFloatModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|
|
|
|
|
|
class ArangeStartNegativeStepIntModule(torch.nn.Module):
    """Exercise a descending ``torch.arange`` with a negative integer ``step``.

    Length ``ceil((1 - 10) / -2) = 5``. Expected eager result:
    ``tensor([10, 8, 6, 4, 2])`` in int64.
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(10, 1, -2)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeStartNegativeStepIntModule())
def ArangeStartNegativeStepIntModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|
|
|
|
|
|
class ArangeStartNegativeStepFloatModule(torch.nn.Module):
    """Exercise a descending ``torch.arange`` with a negative fractional ``step``.

    Length ``ceil((-15 - (-1)) / -3.4) = 5``. The expected eager result is
    about ``[-1.0, -4.4, -7.8, -11.2, -14.6]`` in the default floating dtype.
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(-1, -15, -3.4)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeStartNegativeStepFloatModule())
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|
|
|
|
|
|
class ArangeDtypeFloatModule(torch.nn.Module):
    """Exercise an explicit ``dtype=torch.float32`` overriding integer bounds.

    Expected eager result: the 16 float32 values ``-1.0, 0.0, ..., 14.0``.
    Without ``dtype`` these integer arguments would produce int64.
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(-1, 15, dtype=torch.float32)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeDtypeFloatModule())
def ArangeDtypeFloatModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|
|
|
|
|
|
class ArangeDtypeIntModule(torch.nn.Module):
    """Exercise an explicit ``dtype=torch.int64`` with fractional float bounds.

    The output has ``ceil(5.0 - 0.2) = 5`` int64 elements. Values are
    presumably ``[0, 1, 2, 3, 4]`` because the fractional start is cast to an
    integer; NOTE(review): confirm against eager PyTorch.
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(0.2, 5.0, dtype=torch.int64)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeDtypeIntModule())
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|
|
|
|
|
|
class ArangeFalsePinMemoryModule(torch.nn.Module):
    """Exercise ``torch.arange`` with an explicit ``pin_memory=False`` keyword.

    Float ``end`` with ``dtype=torch.int64``. Expected eager result:
    ``tensor([0, 1, 2, 3, 4])``. The point is that a literal ``False`` for
    ``pin_memory`` is accepted alongside ``dtype``.
    """

    @export
    @annotate_args([None])  # lone entry pairs with ``self``; no tensor inputs
    def forward(self):
        return torch.arange(5.0, dtype=torch.int64, pin_memory=False)
|
|
|
|
@register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
    """Call ``forward`` once; it takes no inputs, so ``tu`` goes unused."""
    module.forward()
|