# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. import torch from torch_mlir_e2e_test.torchscript.framework import TestUtils from torch_mlir_e2e_test.torchscript.registry import register_test_case from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export # ============================================================================== class ElementwiseGtFloatScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.float32, True), ]) def forward(self, x): return torch.gt(x, 0.6) @register_test_case(module_factory=lambda: ElementwiseGtFloatScalarModule()) def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) # ============================================================================== class ElementwiseGtIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.int64, True), ]) def forward(self, x): return torch.gt(x, 10) @register_test_case(module_factory=lambda: ElementwiseGtIntScalarModule()) def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils): module.forward(torch.randint(-10, 15, (3, 4))) # ============================================================================== class ElementwiseGtMixed2ScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.int32, True), ]) def forward(self, x): return torch.gt(x, 7) @register_test_case(module_factory=lambda: ElementwiseGtMixed2ScalarModule()) def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils): module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32)) # ============================================================================== class ElementwiseGeFloatScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.float32, True), ]) def forward(self, x): return torch.ge(x, 0.6) @register_test_case(module_factory=lambda: ElementwiseGeFloatScalarModule()) def ElementwiseGeFloatScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) # ============================================================================== class ElementwiseGeIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.int64, True), ]) def forward(self, x): return torch.ge(x, 10) @register_test_case(module_factory=lambda: ElementwiseGeIntScalarModule()) def ElementwiseGeIntScalarModule_basic(module, tu: TestUtils): module.forward(torch.randint(-10, 15, (3, 4))) # ============================================================================== class ElementwiseGeMixedIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.int32, True), ]) def forward(self, x): return torch.ge(x, 7) @register_test_case(module_factory=lambda: ElementwiseGeMixedIntScalarModule()) def ElementwiseGeMixedIntScalarModule_basic(module, tu: TestUtils): module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32)) # ============================================================================== class ElementwiseGeFloatIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.float32, True), ]) def forward(self, x): return torch.ge(x, 7) @register_test_case(module_factory=lambda: ElementwiseGeFloatIntScalarModule()) def ElementwiseGeFloatIntScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) # ============================================================================== class ElementwiseGtFloatTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.float32, True), ([-1], torch.float32, True), ]) def forward(self, x, y): return torch.gt(x, y) @register_test_case(module_factory=lambda: ElementwiseGtFloatTensorModule()) def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5), tu.rand(5)) # ============================================================================== class ElementwiseGtIntTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.int64, True), ([-1], torch.int64, True), ]) def forward(self, x, y): return torch.gt(x, y) @register_test_case(module_factory=lambda: ElementwiseGtIntTensorModule()) def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils): module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, ))) # ============================================================================== class ElementwiseLtFloatScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.float32, True), ]) def forward(self, x): return torch.lt(x, 0.6) @register_test_case(module_factory=lambda: ElementwiseLtFloatScalarModule()) def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) # ============================================================================== class ElementwiseLtIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.int64, True), ]) def forward(self, x): return torch.lt(x, 0) @register_test_case(module_factory=lambda: ElementwiseLtIntScalarModule()) def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils): module.forward(torch.randint(-10, 15, (3, 4))) # ============================================================================== class ElementwiseLtDiffWidthScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.int32, True), ]) def forward(self, x): return torch.lt(x, 2) @register_test_case( module_factory=lambda: ElementwiseLtDiffWidthScalarModule()) def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils): module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32)) # ============================================================================== class ElementwiseLeFloatScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.float32, True), ]) def forward(self, x): return torch.le(x, 0.6) @register_test_case(module_factory=lambda: ElementwiseLeFloatScalarModule()) def ElementwiseLeFloatScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) # ============================================================================== class ElementwiseLeIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.int64, True), ]) def forward(self, x): return torch.le(x, 10) @register_test_case(module_factory=lambda: ElementwiseLeIntScalarModule()) def ElementwiseLeIntScalarModule_basic(module, tu: TestUtils): module.forward(torch.randint(-10, 15, (3, 4))) # ============================================================================== class ElementwiseLeMixedIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.int32, True), ]) def forward(self, x): return torch.le(x, 7) @register_test_case(module_factory=lambda: ElementwiseLeMixedIntScalarModule()) def ElementwiseLeMixedIntScalarModule_basic(module, tu: TestUtils): module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32)) # ============================================================================== class ElementwiseLeFloatIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.float32, True), ]) def forward(self, x): return torch.le(x, 7) @register_test_case(module_factory=lambda: ElementwiseLeFloatIntScalarModule()) def ElementwiseLeFloatIntScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) # ============================================================================== class ElementwiseLtFloatTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.float32, True), ([-1], torch.float32, True), ]) def forward(self, x, y): return torch.lt(x, y) @register_test_case(module_factory=lambda: ElementwiseLtFloatTensorModule()) def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5), tu.rand(5)) # ============================================================================== class ElementwiseLtIntTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.int64, True), ([-1], torch.int64, True), ]) def forward(self, x, y): return torch.lt(x, y) @register_test_case(module_factory=lambda: ElementwiseLtIntTensorModule()) def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils): module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, ))) # ============================================================================== class ElementwiseEqFloatScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.float32, True), ]) def forward(self, x): return torch.eq(x, 6.0) @register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule()) def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils): module.forward( torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32)) # ============================================================================== class ElementwiseEqIntScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.int64, True), ]) def forward(self, x): return torch.eq(x, 2) @register_test_case(module_factory=lambda: ElementwiseEqIntScalarModule()) def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils): module.forward(torch.randint(2, 4, (5, 8))) # ============================================================================== class ElementwiseEqDiffWidthScalarModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.int32, True), ]) def forward(self, x): return torch.eq(x, 2) @register_test_case( module_factory=lambda: ElementwiseEqDiffWidthScalarModule()) def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils): module.forward(torch.randint(2, 4, (5, 8)).to(torch.int32)) # ============================================================================== class ElementwiseEqFloatTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.float32, True), ([-1], torch.float32, True), ]) def forward(self, x, y): return torch.eq(x, y) @register_test_case(module_factory=lambda: ElementwiseEqFloatTensorModule()) def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils): module.forward( torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32), torch.tensor([1.0, 2.4, 6.0]).to(torch.float32)) # ============================================================================== class ElementwiseEqIntTensorModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([-1, -1], torch.int64, True), ([-1], torch.int64, True), ]) def forward(self, x, y): return torch.eq(x, y) @register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule()) def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils): module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, )))