# torch-mlir/e2e_testing/torchscript/arange.py
#
# 251 lines
# 5.7 KiB
# Python

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class ArangeIntModule(torch.nn.Module):
    """Test module whose forward pass evaluates ``torch.arange(5)``.

    Covers the single-argument form of ``torch.arange`` with an integer
    literal as the only argument.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(5)
# Registered with a factory that builds a fresh ArangeIntModule.
@register_test_case(module_factory=lambda: ArangeIntModule())
def ArangeIntModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()
class ArangeFloatModule(torch.nn.Module):
    """Test module whose forward pass evaluates ``torch.arange(5.0)``.

    Covers the single-argument form of ``torch.arange`` with a float
    literal as the only argument.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(5.0)
# Registered with a factory that builds a fresh ArangeFloatModule.
@register_test_case(module_factory=lambda: ArangeFloatModule())
def ArangeFloatModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()
class ArangeZeroElementOutputModule(torch.nn.Module):
    """Test module whose forward pass evaluates ``torch.arange(0)``.

    As the class name indicates, this targets the case where the requested
    range yields an output with zero elements.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(0)
# Registered with a factory that builds a fresh ArangeZeroElementOutputModule.
@register_test_case(module_factory=lambda: ArangeZeroElementOutputModule())
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()
class ArangeStartIntModule(torch.nn.Module):
    """Test module whose forward pass evaluates ``torch.arange(0, 5)``.

    Covers the two-argument (start, end) form of ``torch.arange`` with
    integer literals.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(0, 5)
# Registered with a factory that builds a fresh ArangeStartIntModule.
@register_test_case(module_factory=lambda: ArangeStartIntModule())
def ArangeStartIntModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()
class ArangeStartFloatModule(torch.nn.Module):
    """Test module whose forward pass evaluates ``torch.arange(0.0, 5.0)``.

    Covers the two-argument (start, end) form of ``torch.arange`` with
    float literals.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(0.0, 5.0)
# Registered with a factory that builds a fresh ArangeStartFloatModule.
@register_test_case(module_factory=lambda: ArangeStartFloatModule())
def ArangeStartFloatModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()
class ArangeNegativeStartIntModule(torch.nn.Module):
    """Test module whose forward pass evaluates ``torch.arange(-10, 5)``.

    Covers the two-argument (start, end) form of ``torch.arange`` with a
    negative integer start.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(-10, 5)
# Registered with a factory that builds a fresh ArangeNegativeStartIntModule.
@register_test_case(module_factory=lambda: ArangeNegativeStartIntModule())
def ArangeNegativeStartIntModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()
class ArangeNegativeStartFloatModule(torch.nn.Module):
    """Test module whose forward pass evaluates ``torch.arange(-1.4, 5.7)``.

    Covers the two-argument (start, end) form of ``torch.arange`` with a
    negative, non-integral float start and a non-integral float end.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(-1.4, 5.7)
# Registered with a factory that builds a fresh ArangeNegativeStartFloatModule.
@register_test_case(module_factory=lambda: ArangeNegativeStartFloatModule())
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()
class ArangeStartStepIntModule(torch.nn.Module):
    """Test module whose forward pass evaluates ``torch.arange(0, 5, 1)``.

    Covers the three-argument (start, end, step) form of ``torch.arange``
    with integer literals.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(0, 5, 1)
# Registered with a factory that builds a fresh ArangeStartStepIntModule.
@register_test_case(module_factory=lambda: ArangeStartStepIntModule())
def ArangeStartStepIntModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()
class ArangeStartStepFloatModule(torch.nn.Module):
    """Test module whose forward pass evaluates ``torch.arange(-1, 5, 1.3)``.

    Covers the three-argument (start, end, step) form of ``torch.arange``
    with integer start/end literals and a float step.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(-1, 5, 1.3)
# Registered with a factory that builds a fresh ArangeStartStepFloatModule.
@register_test_case(module_factory=lambda: ArangeStartStepFloatModule())
def ArangeStartStepFloatModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()
class ArangeStartNegativeStepIntModule(torch.nn.Module):
    """Test module whose forward pass evaluates ``torch.arange(10, 1, -2)``.

    Covers the three-argument (start, end, step) form of ``torch.arange``
    with integer literals and a negative step (start greater than end).
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(10, 1, -2)
# Registered with a factory that builds a fresh
# ArangeStartNegativeStepIntModule.
@register_test_case(module_factory=lambda: ArangeStartNegativeStepIntModule())
def ArangeStartNegativeStepIntModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()
class ArangeStartNegativeStepFloatModule(torch.nn.Module):
    """Test module whose forward pass evaluates ``torch.arange(-1, -15, -3.4)``.

    Covers the three-argument (start, end, step) form of ``torch.arange``
    with integer start/end literals and a negative float step.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(-1, -15, -3.4)
# Registered with a factory that builds a fresh
# ArangeStartNegativeStepFloatModule.
@register_test_case(module_factory=lambda: ArangeStartNegativeStepFloatModule())
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()
class ArangeDtypeFloatModule(torch.nn.Module):
    """Test module evaluating ``torch.arange(-1, 15, dtype=torch.float32)``.

    Covers integer start/end literals combined with an explicitly requested
    ``torch.float32`` dtype.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(-1, 15, dtype=torch.float32)
# Registered with a factory that builds a fresh ArangeDtypeFloatModule.
@register_test_case(module_factory=lambda: ArangeDtypeFloatModule())
def ArangeDtypeFloatModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()
class ArangeDtypeIntModule(torch.nn.Module):
    """Test module evaluating ``torch.arange(0.2, 5.0, dtype=torch.int64)``.

    Covers float start/end literals combined with an explicitly requested
    ``torch.int64`` dtype.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(0.2, 5.0, dtype=torch.int64)
# Registered with a factory that builds a fresh ArangeDtypeIntModule.
@register_test_case(module_factory=lambda: ArangeDtypeIntModule())
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()
class ArangeFalsePinMemoryModule(torch.nn.Module):
    """Test module evaluating ``torch.arange`` with ``pin_memory=False``.

    The call passes a float end value (``5.0``), an explicit
    ``torch.int64`` dtype, and ``pin_memory=False`` as a keyword argument.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None])
    def forward(self):
        # `forward` takes no inputs besides `self`, matching the single
        # `None` entry in the annotation list.
        return torch.arange(5.0, dtype=torch.int64, pin_memory=False)
# Registered with a factory that builds a fresh ArangeFalsePinMemoryModule.
@register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
    """Invoke ``module.forward()`` with no arguments; ``tu`` is not used."""
    module.forward()