[LINALG] Add E2E support for `aten.[le|ge].Scalar` ops

- This commit adds lowering of the `aten.le.Scalar` and `aten.ge.Scalar` ops
  as part of the `convert-torch-to-linalg` pass (see the sketch below).
- It also adds a new test script, `elementwise_comparison.py`, covering all
  element-wise comparison ops.
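
A minimal eager-mode sketch of what the two newly supported ops compute
(the tensor and threshold values are illustrative only, not taken from the
test suite):

    import torch

    x = torch.tensor([[0.25, 0.5, 1.5]])

    # Elementwise comparison against a scalar; both return a boolean tensor,
    # which the Torch dialect models with a 1-bit integer (i1) element type.
    print(torch.le(x, 0.5))  # tensor([[ True,  True, False]])
    print(torch.ge(x, 0.5))  # tensor([[False,  True,  True]])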

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Branch: pull/600/head
Author: Gaurav Shukla, 2022-02-15 01:16:38 +05:30
Parent: 413e6000d2
Commit: 41acde599b
7 changed files with 607 additions and 326 deletions

View File

@@ -18,7 +18,6 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class ElementwiseUnaryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -36,10 +35,8 @@ class ElementwiseUnaryModule(torch.nn.Module):
def ElementwiseUnaryModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseBinaryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -58,10 +55,8 @@ class ElementwiseBinaryModule(torch.nn.Module):
def ElementwiseBinaryModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4), tu.rand(4))
# ==============================================================================
class ElementwiseBinaryStaticShapeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -75,15 +70,14 @@ class ElementwiseBinaryStaticShapeModule(torch.nn.Module):
def forward(self, a, b):
return a * b
@register_test_case(
module_factory=lambda: ElementwiseBinaryStaticShapeModule())
def ElementwiseBinaryStaticShapeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3, 3, 1), tu.rand(4, 3, 1, 2))
# ==============================================================================
class ElementwiseTernaryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -103,10 +97,8 @@ class ElementwiseTernaryModule(torch.nn.Module):
def ElementwiseTernaryModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.rand(4, 5), tu.rand(5))
# ==============================================================================
class ElementwiseWhereSelfModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -126,10 +118,8 @@ class ElementwiseWhereSelfModule(torch.nn.Module):
def ElementwiseWhereSelfModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.rand(4, 5), tu.rand(5))
# ==============================================================================
# Addition is an interesting special case of a binary op, because under the hood
# it carries a third scalar "alpha" parameter, which needs special handling.
class ElementwiseAddModule(torch.nn.Module):
@@ -150,10 +140,8 @@ class ElementwiseAddModule(torch.nn.Module):
def ElementwiseAddModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand())
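
A quick eager-mode illustration of the "alpha" remark above (values are made
up and this is not part of the test suite): torch.add treats its extra scalar
as a multiplier on the second operand, so the lowering must scale that operand
before adding.

    import torch

    a = torch.tensor([1.0, 2.0])
    b = torch.tensor([10.0, 20.0])

    # Result is a + alpha * b.
    print(torch.add(a, b, alpha=3.0))  # tensor([31., 62.])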
# ==============================================================================
class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -173,10 +161,8 @@ class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand())
# ==============================================================================
class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -197,10 +183,8 @@ class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module):
def ElementwiseUnsqueezeNegDimsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 3))
# ==============================================================================
class ElementwiseFlattenBroadcastModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -221,7 +205,6 @@ def ElementwiseFlattenBroadcastModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseReluModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -240,6 +223,7 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 2) - 0.5)
# ==============================================================================
class ElementwiseLeakyReluModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -259,7 +243,6 @@ def ElementwiseLeakyReluModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseGeluModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -278,10 +261,8 @@ class ElementwiseGeluModule(torch.nn.Module):
def ElementwiseGeluModule_basic(module, tu: TestUtils):
module.forward(2 * tu.rand(5, 3) - 0.5)
# ==============================================================================
class ElementwiseSigmoidModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -301,7 +282,6 @@ def ElementwiseSigmoidModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseMinimumModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -321,10 +301,8 @@ def ElementwiseMinimumModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5), tu.rand(3, 5))
module.forward(tu.nans(3, 5), tu.rand(3, 5))
# ==============================================================================
class ElementwiseMaximumModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -346,297 +324,6 @@ def ElementwiseMaximumModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseGtFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.gt(x, 0.6)
@register_test_case(module_factory=lambda: ElementwiseGtFloatScalarModule())
def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
class ElementwiseGtIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.gt(x, 10)
@register_test_case(module_factory=lambda: ElementwiseGtIntScalarModule())
def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3, 4)))
class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, x):
return torch.gt(x, 7)
@register_test_case(module_factory=lambda: ElementwiseGtMixed2ScalarModule())
def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
class ElementwiseGtFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, y):
return torch.gt(x, y)
@register_test_case(module_factory=lambda: ElementwiseGtFloatTensorModule())
def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5), tu.rand(5))
class ElementwiseGtIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, x, y):
return torch.gt(x, y)
@register_test_case(module_factory=lambda: ElementwiseGtIntTensorModule())
def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, )))
# ==============================================================================
class ElementwiseLtFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.lt(x, 0.6)
@register_test_case(module_factory=lambda: ElementwiseLtFloatScalarModule())
def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
class ElementwiseLtIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.lt(x, 0)
@register_test_case(module_factory=lambda: ElementwiseLtIntScalarModule())
def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3, 4)))
class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, x):
return torch.lt(x, 2)
@register_test_case(
module_factory=lambda: ElementwiseLtDiffWidthScalarModule())
def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
class ElementwiseLtFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, y):
return torch.lt(x, y)
@register_test_case(module_factory=lambda: ElementwiseLtFloatTensorModule())
def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5), tu.rand(5))
class ElementwiseLtIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, x, y):
return torch.lt(x, y)
@register_test_case(module_factory=lambda: ElementwiseLtIntTensorModule())
def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, )))
# ==============================================================================
class ElementwiseEqFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.eq(x, 6.0)
@register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule())
def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32))
class ElementwiseEqIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.eq(x, 2)
@register_test_case(module_factory=lambda: ElementwiseEqIntScalarModule())
def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(2, 4, (5, 8)))
class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, x):
return torch.eq(x, 2)
@register_test_case(
module_factory=lambda: ElementwiseEqDiffWidthScalarModule())
def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(2, 4, (5, 8)).to(torch.int32))
class ElementwiseEqFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, y):
return torch.eq(x, y)
@register_test_case(module_factory=lambda: ElementwiseEqFloatTensorModule())
def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
torch.tensor([1.0, 2.4, 6.0]).to(torch.float32))
class ElementwiseEqIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, x, y):
return torch.eq(x, y)
@register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule())
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, )))
# ==============================================================================
class ElementwiseClampModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -664,6 +351,7 @@ def ElementwiseClampModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5, low=-10, high=10))
# ==============================================================================
class RsubModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -676,10 +364,13 @@ class RsubModule(torch.nn.Module):
def forward(self, x):
return torch.rsub(x, 3.0, alpha=1.0)
@register_test_case(module_factory=lambda: RsubModule())
def RsubModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class RsubModule_noalpha(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -692,6 +383,7 @@ class RsubModule_noalpha(torch.nn.Module):
def forward(self, x):
return torch.rsub(x, 2.0)
@register_test_case(module_factory=lambda: RsubModule_noalpha())
def RsubModule_noalpha_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
@@ -710,10 +402,12 @@ class ElementwiseMulScalarIntModule(torch.nn.Module):
def forward(self, x):
return torch.mul(x, 4)
@register_test_case(module_factory=lambda: ElementwiseMulScalarIntModule())
def ElementwiseMulScalarModule_int(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 4)))
# ==============================================================================
class ElementwiseMulScalarFloatModule(torch.nn.Module):
def __init__(self):
@@ -727,10 +421,12 @@ class ElementwiseMulScalarFloatModule(torch.nn.Module):
def forward(self, x):
return torch.mul(x, 100.0)
@register_test_case(module_factory=lambda: ElementwiseMulScalarFloatModule())
def ElementwiseMulScalarModule_float(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseMulScalarModule(torch.nn.Module):
def __init__(self):
@@ -744,10 +440,12 @@ class ElementwiseMulScalarModule(torch.nn.Module):
def forward(self, x):
return torch.mul(x, 8.0)
@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
# ==============================================================================
class ElementwiseMulTensorFloatModule(torch.nn.Module):
def __init__(self):
@@ -767,6 +465,7 @@ class ElementwiseMulTensorFloatModule(torch.nn.Module):
def ElementwiseMulTensorFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
# ==============================================================================
class ElementwiseMulTensorIntModule(torch.nn.Module):
def __init__(self):
@@ -787,8 +486,8 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
module.forward(
torch.randint(10, [4]).type(torch.int32), torch.randint(10, [4]))
# ==============================================================================
class ElementwiseLogModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -806,6 +505,7 @@ class ElementwiseLogModule(torch.nn.Module):
def ElementwiseLogModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseSqrtModule(torch.nn.Module):
def __init__(self):
@@ -820,10 +520,13 @@ class ElementwiseSqrtModule(torch.nn.Module):
def forward(self, a):
return torch.sqrt(a)
@register_test_case(module_factory=lambda: ElementwiseSqrtModule())
def ElementwiseSqrtModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseFloorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -836,10 +539,13 @@ class ElementwiseFloorModule(torch.nn.Module):
def forward(self, a):
return torch.floor(a)
@register_test_case(module_factory=lambda: ElementwiseFloorModule())
def ElementwiseFloorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseCeilModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -852,10 +558,13 @@ class ElementwiseCeilModule(torch.nn.Module):
def forward(self, a):
return torch.ceil(a)
@register_test_case(module_factory=lambda: ElementwiseCeilModule())
def ElementwiseCeilModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwisePowModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -868,6 +577,7 @@ class ElementwisePowModule(torch.nn.Module):
def forward(self, a):
return torch.pow(a, 2.0)
@register_test_case(module_factory=lambda: ElementwisePowModule())
def ElementwisePowModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
@@ -886,10 +596,13 @@ class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
def forward(self, x):
return x.to(torch.int64)
@register_test_case(module_factory=lambda: ElementwiseToDtypeF32ToI64Module())
def ElementwiseToDtypeF32ToI64Module_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseToDtypeIdentityModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -902,10 +615,13 @@ class ElementwiseToDtypeIdentityModule(torch.nn.Module):
def forward(self, x):
return x.to(torch.float32, False, False)
@register_test_case(module_factory=lambda: ElementwiseToDtypeIdentityModule())
def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseLog2Module(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -918,10 +634,13 @@ class ElementwiseLog2Module(torch.nn.Module):
def forward(self, a):
return torch.log2(a)
@register_test_case(module_factory=lambda: ElementwiseLog2Module())
def ElementwiseLog2Module_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseRsqrtModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -935,6 +654,7 @@ class ElementwiseRsqrtModule(torch.nn.Module):
def forward(self, a):
return torch.rsqrt(a)
@register_test_case(module_factory=lambda: ElementwiseRsqrtModule())
def ElementwiseRsqrtModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
@@ -953,6 +673,7 @@ class ElementwiseAbsModule(torch.nn.Module):
def forward(self, a):
return torch.abs(a)
@register_test_case(module_factory=lambda: ElementwiseAbsModule())
def ElementwiseAbsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5, low=-1.0, high=1.0))
@@ -971,6 +692,7 @@ class ElementwiseReciprocalModule(torch.nn.Module):
def forward(self, a):
return torch.reciprocal(a)
@register_test_case(module_factory=lambda: ElementwiseReciprocalModule())
def ElementwiseReciprocalModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4))
@@ -994,6 +716,7 @@ class ElementwiseDivScalarModule(torch.nn.Module):
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseDivTensorFloatModule(torch.nn.Module):
def __init__(self):
@@ -1013,10 +736,8 @@ class ElementwiseDivTensorFloatModule(torch.nn.Module):
def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
# ==============================================================================
class ElementwiseAndIntegerModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1037,6 +758,7 @@ def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
torch.randint(-10, 10, (3, 4)).to(torch.int32),
torch.randint(-10, 10, (3, 4)))
# ==============================================================================
class ElementwiseSubScalarIntModule(torch.nn.Module):
def __init__(self):
@@ -1055,6 +777,7 @@ class ElementwiseSubScalarIntModule(torch.nn.Module):
def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
# ==============================================================================
class ElementwiseSubScalarFloatModule(torch.nn.Module):
def __init__(self):
@@ -1068,10 +791,13 @@ class ElementwiseSubScalarFloatModule(torch.nn.Module):
def forward(self, x):
return torch.sub(x, 2.1)
@register_test_case(module_factory=lambda: ElementwiseSubScalarFloatModule())
def ElementwiseSubScalarFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseAddScalarInt64Module(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1084,10 +810,12 @@ class ElementwiseAddScalarInt64Module(torch.nn.Module):
def forward(self, x):
return torch.add(x, 3.0)
@register_test_case(module_factory=lambda: ElementwiseAddScalarInt64Module())
def ElementwiseAddScalarInt64Module_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 4)))
# ==============================================================================
class ElementwiseAddScalarIntModule(torch.nn.Module):
def __init__(self):
@@ -1101,10 +829,12 @@ class ElementwiseAddScalarIntModule(torch.nn.Module):
def forward(self, x):
return torch.add(x, 3.0)
@register_test_case(module_factory=lambda: ElementwiseAddScalarIntModule())
def ElementwiseAddScalarIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (2, 3), dtype=torch.int32))
# ==============================================================================
class ElementwiseAddScalarFloatModule(torch.nn.Module):
def __init__(self):
@@ -1123,6 +853,7 @@ class ElementwiseAddScalarFloatModule(torch.nn.Module):
def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseCloneModule(torch.nn.Module):
def __init__(self):

View File

@@ -0,0 +1,459 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class ElementwiseGtFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.gt(x, 0.6)
@register_test_case(module_factory=lambda: ElementwiseGtFloatScalarModule())
def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseGtIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.gt(x, 10)
@register_test_case(module_factory=lambda: ElementwiseGtIntScalarModule())
def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3, 4)))
# ==============================================================================
class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, x):
return torch.gt(x, 7)
@register_test_case(module_factory=lambda: ElementwiseGtMixed2ScalarModule())
def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
# ==============================================================================
class ElementwiseGeFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ge(x, 0.6)
@register_test_case(module_factory=lambda: ElementwiseGeFloatScalarModule())
def ElementwiseGeFloatScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseGeIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.ge(x, 10)
@register_test_case(module_factory=lambda: ElementwiseGeIntScalarModule())
def ElementwiseGeIntScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3, 4)))
# ==============================================================================
class ElementwiseGeMixedIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, x):
return torch.ge(x, 7)
@register_test_case(module_factory=lambda: ElementwiseGeMixedIntScalarModule())
def ElementwiseGeMixedIntScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
# ==============================================================================
class ElementwiseGeFloatIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ge(x, 7)
@register_test_case(module_factory=lambda: ElementwiseGeFloatIntScalarModule())
def ElementwiseGeFloatIntScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseGtFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, y):
return torch.gt(x, y)
@register_test_case(module_factory=lambda: ElementwiseGtFloatTensorModule())
def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5), tu.rand(5))
# ==============================================================================
class ElementwiseGtIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, x, y):
return torch.gt(x, y)
@register_test_case(module_factory=lambda: ElementwiseGtIntTensorModule())
def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, )))
# ==============================================================================
class ElementwiseLtFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.lt(x, 0.6)
@register_test_case(module_factory=lambda: ElementwiseLtFloatScalarModule())
def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseLtIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.lt(x, 0)
@register_test_case(module_factory=lambda: ElementwiseLtIntScalarModule())
def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3, 4)))
# ==============================================================================
class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, x):
return torch.lt(x, 2)
@register_test_case(
module_factory=lambda: ElementwiseLtDiffWidthScalarModule())
def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
# ==============================================================================
class ElementwiseLeFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.le(x, 0.6)
@register_test_case(module_factory=lambda: ElementwiseLeFloatScalarModule())
def ElementwiseLeFloatScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseLeIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.le(x, 10)
@register_test_case(module_factory=lambda: ElementwiseLeIntScalarModule())
def ElementwiseLeIntScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3, 4)))
# ==============================================================================
class ElementwiseLeMixedIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, x):
return torch.le(x, 7)
@register_test_case(module_factory=lambda: ElementwiseLeMixedIntScalarModule())
def ElementwiseLeMixedIntScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
# ==============================================================================
class ElementwiseLeFloatIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.le(x, 7)
@register_test_case(module_factory=lambda: ElementwiseLeFloatIntScalarModule())
def ElementwiseLeFloatIntScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseLtFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, y):
return torch.lt(x, y)
@register_test_case(module_factory=lambda: ElementwiseLtFloatTensorModule())
def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5), tu.rand(5))
# ==============================================================================
class ElementwiseLtIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, x, y):
return torch.lt(x, y)
@register_test_case(module_factory=lambda: ElementwiseLtIntTensorModule())
def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, )))
# ==============================================================================
class ElementwiseEqFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.eq(x, 6.0)
@register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule())
def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32))
# ==============================================================================
class ElementwiseEqIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.eq(x, 2)
@register_test_case(module_factory=lambda: ElementwiseEqIntScalarModule())
def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(2, 4, (5, 8)))
# ==============================================================================
class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, x):
return torch.eq(x, 2)
@register_test_case(
module_factory=lambda: ElementwiseEqDiffWidthScalarModule())
def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(2, 4, (5, 8)).to(torch.int32))
# ==============================================================================
class ElementwiseEqFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, x, y):
return torch.eq(x, y)
@register_test_case(module_factory=lambda: ElementwiseEqFloatTensorModule())
def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
torch.tensor([1.0, 2.4, 6.0]).to(torch.float32))
# ==============================================================================
class ElementwiseEqIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, x, y):
return torch.eq(x, y)
@register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule())
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, )))

View File

@@ -42,6 +42,7 @@ from . import matmul
from . import reshape_like
from . import scalar
from . import scalar_comparison
from . import elementwise_comparison
from . import squeeze
from . import slice_like
from . import nll_loss

View File

@@ -904,6 +904,36 @@ def Torch_AtenLt_ScalarOp : Torch_Op<"aten.lt_.Scalar", [
let assemblyFormat = "$self `,` $other attr-dict `:` qualified(type($self)) `,` qualified(type($other)) `->` qualified(type($result))";
}
def Torch_AtenLeScalarOp : Torch_Op<"aten.le.Scalar", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::le.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $other attr-dict `:` qualified(type($self)) `,` qualified(type($other)) `->` qualified(type($result))";
}
def Torch_AtenLe_ScalarOp : Torch_Op<"aten.le_.Scalar", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::le_.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $other attr-dict `:` qualified(type($self)) `,` qualified(type($other)) `->` qualified(type($result))";
}
def Torch_AtenFmodScalarOp : Torch_Op<"aten.fmod.Scalar", [
AllowsTypeRefinement,
HasValueSemantics

View File

@@ -1990,6 +1990,36 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return nullptr;
}
if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
Type dtype = geScalar.self().getType().cast<BaseTensorType>().getDtype();
// TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code that
// can be refactored.
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE,
payloadArgs[0], otherPromoted);
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor args from integer to float.
geScalar.emitError(
"unimplemented: type promotion from tensor to scalar.");
return nullptr;
}
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
payloadArgs[0], otherPromoted);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
payloadArgs[0], otherPromoted);
}
geScalar.emitError("unimplemented: dtype isn't supported.");
return nullptr;
}
if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
Type dtype = eqScalar.self().getType().cast<BaseTensorType>().getDtype();
Value otherPromoted =
@@ -2040,6 +2070,34 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return nullptr;
}
if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
Type dtype = leScalar.self().getType().cast<BaseTensorType>().getDtype();
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
// TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code that
// can be refactored.
if (dtype.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
payloadArgs[0], otherPromoted);
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float.
leScalar.emitError(
"unimplemented: type promotion from tensor to scalar");
return nullptr;
}
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
payloadArgs[0], otherPromoted);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
payloadArgs[0], otherPromoted);
}
leScalar.emitError("unimplemented: dtype isn't supported.");
return nullptr;
}
if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
Type dtype = converter->convertType(whereSelf.getType())
.cast<RankedTensorType>()
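
Restated as a hedged Python-level reference (the function below is purely
illustrative, not part of the pass), the `AtenGeScalarOp`/`AtenLeScalarOp`
payloads above convert the scalar operand to the tensor's element type and
then compare elementwise, with integer-tensor-versus-float-scalar promotion
left as a TODO:

    import torch

    def le_scalar_reference(tensor: torch.Tensor, scalar) -> torch.Tensor:
        # Mirrors convertScalarToDtype: bring the scalar to the tensor's
        # dtype before comparing.
        if tensor.dtype.is_floating_point:
            other = float(scalar)
        elif isinstance(scalar, int):
            other = scalar
        else:
            # Matches the TODO above: an integer tensor compared against a
            # float scalar would need tensor-to-float promotion.
            raise NotImplementedError("type promotion from tensor to scalar")
        # Elementwise boolean (i1) result, analogous to arith.cmpf/arith.cmpi.
        return tensor <= other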
@@ -2455,10 +2513,11 @@ struct ConvertElementwiseOp : ConversionPattern {
AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp,
AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp>(op))
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenCloneOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -4575,9 +4634,9 @@ public:
AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenCloneOp>();
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeOp>();
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);

View File

@@ -235,7 +235,7 @@ public:
// These comparison ops return a tensor with 1-bit integer dtype.
if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
AtenNeScalarOp>(op)) {
AtenLeScalarOp, AtenNeScalarOp>(op)) {
auto operand = operands[0]->getValue();
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
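
A quick sanity check of the dtype behaviour this refinement encodes (inputs
are arbitrary): the scalar comparison ops produce torch.bool tensors, which
correspond to the 1-bit integer (i1) element type mentioned in the comment
above.

    import torch

    x = torch.randint(-10, 15, (3, 4))
    print(torch.le(x, 10).dtype)  # torch.bool
    print(torch.ge(x, 10).dtype)  # torch.bool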

View File

@@ -475,6 +475,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
"aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::le.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",