mirror of https://github.com/llvm/torch-mlir
[LINALG] Add E2E support for `aten.[le|ge].Scalar` ops
- This commit adds lowering of `aten.le.Scalar` and `aten.ge.Scalar` ops as a part of `convert-torch-to-linalg` pass. - It also creates a new test script `elementwise_comparison.py` for all element-wise comparison ops. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/600/head
parent
413e6000d2
commit
41acde599b
|
@ -18,7 +18,6 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseUnaryModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -36,10 +35,8 @@ class ElementwiseUnaryModule(torch.nn.Module):
|
|||
def ElementwiseUnaryModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseBinaryModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -58,10 +55,8 @@ class ElementwiseBinaryModule(torch.nn.Module):
|
|||
def ElementwiseBinaryModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4), tu.rand(4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseBinaryStaticShapeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -75,15 +70,14 @@ class ElementwiseBinaryStaticShapeModule(torch.nn.Module):
|
|||
def forward(self, a, b):
|
||||
return a * b
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseBinaryStaticShapeModule())
|
||||
def ElementwiseBinaryStaticShapeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3, 3, 1), tu.rand(4, 3, 1, 2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseTernaryModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -103,10 +97,8 @@ class ElementwiseTernaryModule(torch.nn.Module):
|
|||
def ElementwiseTernaryModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5), tu.rand(4, 5), tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseWhereSelfModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -126,10 +118,8 @@ class ElementwiseWhereSelfModule(torch.nn.Module):
|
|||
def ElementwiseWhereSelfModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5), tu.rand(4, 5), tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
# Addition is an interesting special case of a binary op, because under the hood
|
||||
# it carries a third scalar "alpha" parameter, which needs special handling.
|
||||
class ElementwiseAddModule(torch.nn.Module):
|
||||
|
@ -150,10 +140,8 @@ class ElementwiseAddModule(torch.nn.Module):
|
|||
def ElementwiseAddModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -173,10 +161,8 @@ class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
|
|||
def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -197,10 +183,8 @@ class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module):
|
|||
def ElementwiseUnsqueezeNegDimsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseFlattenBroadcastModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -221,7 +205,6 @@ def ElementwiseFlattenBroadcastModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseReluModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -240,6 +223,7 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(4, 2) - 0.5)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLeakyReluModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -259,7 +243,6 @@ def ElementwiseLeakyReluModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGeluModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -278,10 +261,8 @@ class ElementwiseGeluModule(torch.nn.Module):
|
|||
def ElementwiseGeluModule_basic(module, tu: TestUtils):
|
||||
module.forward(2 * tu.rand(5, 3) - 0.5)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSigmoidModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -301,7 +282,6 @@ def ElementwiseSigmoidModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMinimumModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -321,10 +301,8 @@ def ElementwiseMinimumModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(3, 5), tu.rand(3, 5))
|
||||
module.forward(tu.nans(3, 5), tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMaximumModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -346,297 +324,6 @@ def ElementwiseMaximumModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGtFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.gt(x, 0.6)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGtFloatScalarModule())
|
||||
def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
class ElementwiseGtIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.gt(x, 10)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGtIntScalarModule())
|
||||
def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)))
|
||||
|
||||
|
||||
class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.gt(x, 7)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGtMixed2ScalarModule())
|
||||
def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
|
||||
|
||||
|
||||
class ElementwiseGtFloatTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.gt(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGtFloatTensorModule())
|
||||
def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5), tu.rand(5))
|
||||
|
||||
|
||||
class ElementwiseGtIntTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.gt(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGtIntTensorModule())
|
||||
def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, )))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLtFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.lt(x, 0.6)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLtFloatScalarModule())
|
||||
def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
class ElementwiseLtIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.lt(x, 0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLtIntScalarModule())
|
||||
def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)))
|
||||
|
||||
|
||||
class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.lt(x, 2)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseLtDiffWidthScalarModule())
|
||||
def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
|
||||
|
||||
|
||||
class ElementwiseLtFloatTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.lt(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLtFloatTensorModule())
|
||||
def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5), tu.rand(5))
|
||||
|
||||
|
||||
class ElementwiseLtIntTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.lt(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLtIntTensorModule())
|
||||
def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, )))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseEqFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.eq(x, 6.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule())
|
||||
def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32))
|
||||
|
||||
|
||||
class ElementwiseEqIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.eq(x, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseEqIntScalarModule())
|
||||
def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(2, 4, (5, 8)))
|
||||
|
||||
|
||||
class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.eq(x, 2)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseEqDiffWidthScalarModule())
|
||||
def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(2, 4, (5, 8)).to(torch.int32))
|
||||
|
||||
|
||||
class ElementwiseEqFloatTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.eq(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseEqFloatTensorModule())
|
||||
def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||
torch.tensor([1.0, 2.4, 6.0]).to(torch.float32))
|
||||
|
||||
|
||||
class ElementwiseEqIntTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.eq(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule())
|
||||
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, )))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseClampModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -664,6 +351,7 @@ def ElementwiseClampModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(3, 5, low=-10, high=10))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class RsubModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -676,10 +364,13 @@ class RsubModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.rsub(x, 3.0, alpha=1.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RsubModule())
|
||||
def RsubModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class RsubModule_noalpha(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -692,6 +383,7 @@ class RsubModule_noalpha(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.rsub(x, 2.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RsubModule_noalpha())
|
||||
def RsubModule_noalpha_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
@ -710,10 +402,12 @@ class ElementwiseMulScalarIntModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.mul(x, 4)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMulScalarIntModule())
|
||||
def ElementwiseMulScalarModule_int(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseMulScalarFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -727,10 +421,12 @@ class ElementwiseMulScalarFloatModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.mul(x, 100.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMulScalarFloatModule())
|
||||
def ElementwiseMulScalarModule_float(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseMulScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -744,10 +440,12 @@ class ElementwiseMulScalarModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.mul(x, 8.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
|
||||
def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseMulTensorFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -767,6 +465,7 @@ class ElementwiseMulTensorFloatModule(torch.nn.Module):
|
|||
def ElementwiseMulTensorFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseMulTensorIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -787,8 +486,8 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
|
|||
module.forward(
|
||||
torch.randint(10, [4]).type(torch.int32), torch.randint(10, [4]))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLogModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -806,6 +505,7 @@ class ElementwiseLogModule(torch.nn.Module):
|
|||
def ElementwiseLogModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseSqrtModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -820,10 +520,13 @@ class ElementwiseSqrtModule(torch.nn.Module):
|
|||
def forward(self, a):
|
||||
return torch.sqrt(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseSqrtModule())
|
||||
def ElementwiseSqrtModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseFloorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -836,10 +539,13 @@ class ElementwiseFloorModule(torch.nn.Module):
|
|||
def forward(self, a):
|
||||
return torch.floor(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseFloorModule())
|
||||
def ElementwiseFloorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseCeilModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -852,10 +558,13 @@ class ElementwiseCeilModule(torch.nn.Module):
|
|||
def forward(self, a):
|
||||
return torch.ceil(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseCeilModule())
|
||||
def ElementwiseCeilModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwisePowModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -868,6 +577,7 @@ class ElementwisePowModule(torch.nn.Module):
|
|||
def forward(self, a):
|
||||
return torch.pow(a, 2.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwisePowModule())
|
||||
def ElementwisePowModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
@ -886,10 +596,13 @@ class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return x.to(torch.int64)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseToDtypeF32ToI64Module())
|
||||
def ElementwiseToDtypeF32ToI64Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseToDtypeIdentityModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -902,10 +615,13 @@ class ElementwiseToDtypeIdentityModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return x.to(torch.float32, False, False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseToDtypeIdentityModule())
|
||||
def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLog2Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -918,10 +634,13 @@ class ElementwiseLog2Module(torch.nn.Module):
|
|||
def forward(self, a):
|
||||
return torch.log2(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLog2Module())
|
||||
def ElementwiseLog2Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseRsqrtModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -935,6 +654,7 @@ class ElementwiseRsqrtModule(torch.nn.Module):
|
|||
def forward(self, a):
|
||||
return torch.rsqrt(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRsqrtModule())
|
||||
def ElementwiseRsqrtModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
@ -953,6 +673,7 @@ class ElementwiseAbsModule(torch.nn.Module):
|
|||
def forward(self, a):
|
||||
return torch.abs(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAbsModule())
|
||||
def ElementwiseAbsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5, low=-1.0, high=1.0))
|
||||
|
@ -971,6 +692,7 @@ class ElementwiseReciprocalModule(torch.nn.Module):
|
|||
def forward(self, a):
|
||||
return torch.reciprocal(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseReciprocalModule())
|
||||
def ElementwiseReciprocalModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4))
|
||||
|
@ -994,6 +716,7 @@ class ElementwiseDivScalarModule(torch.nn.Module):
|
|||
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseDivTensorFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -1013,10 +736,8 @@ class ElementwiseDivTensorFloatModule(torch.nn.Module):
|
|||
def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAndIntegerModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -1037,6 +758,7 @@ def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
|
|||
torch.randint(-10, 10, (3, 4)).to(torch.int32),
|
||||
torch.randint(-10, 10, (3, 4)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseSubScalarIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -1055,6 +777,7 @@ class ElementwiseSubScalarIntModule(torch.nn.Module):
|
|||
def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseSubScalarFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -1068,10 +791,13 @@ class ElementwiseSubScalarFloatModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.sub(x, 2.1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseSubScalarFloatModule())
|
||||
def ElementwiseSubScalarFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseAddScalarInt64Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -1084,10 +810,12 @@ class ElementwiseAddScalarInt64Module(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.add(x, 3.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAddScalarInt64Module())
|
||||
def ElementwiseAddScalarInt64Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseAddScalarIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -1101,10 +829,12 @@ class ElementwiseAddScalarIntModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.add(x, 3.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAddScalarIntModule())
|
||||
def ElementwiseAddScalarIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3), dtype=torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseAddScalarFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -1123,6 +853,7 @@ class ElementwiseAddScalarFloatModule(torch.nn.Module):
|
|||
def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseCloneModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
@ -0,0 +1,459 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseGtFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.gt(x, 0.6)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGtFloatScalarModule())
|
||||
def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseGtIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.gt(x, 10)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGtIntScalarModule())
|
||||
def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.gt(x, 7)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGtMixed2ScalarModule())
|
||||
def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseGeFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ge(x, 0.6)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGeFloatScalarModule())
|
||||
def ElementwiseGeFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseGeIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ge(x, 10)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGeIntScalarModule())
|
||||
def ElementwiseGeIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseGeMixedIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ge(x, 7)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGeMixedIntScalarModule())
|
||||
def ElementwiseGeMixedIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseGeFloatIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ge(x, 7)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGeFloatIntScalarModule())
|
||||
def ElementwiseGeFloatIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseGtFloatTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.gt(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGtFloatTensorModule())
|
||||
def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5), tu.rand(5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseGtIntTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.gt(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGtIntTensorModule())
|
||||
def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, )))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLtFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.lt(x, 0.6)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLtFloatScalarModule())
|
||||
def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLtIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.lt(x, 0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLtIntScalarModule())
|
||||
def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.lt(x, 2)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseLtDiffWidthScalarModule())
|
||||
def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLeFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.le(x, 0.6)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLeFloatScalarModule())
|
||||
def ElementwiseLeFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLeIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.le(x, 10)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLeIntScalarModule())
|
||||
def ElementwiseLeIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLeMixedIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.le(x, 7)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLeMixedIntScalarModule())
|
||||
def ElementwiseLeMixedIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLeFloatIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.le(x, 7)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLeFloatIntScalarModule())
|
||||
def ElementwiseLeFloatIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLtFloatTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.lt(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLtFloatTensorModule())
|
||||
def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5), tu.rand(5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseLtIntTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.lt(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLtIntTensorModule())
|
||||
def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, )))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseEqFloatScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.eq(x, 6.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseEqFloatScalarModule())
|
||||
def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseEqIntScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.eq(x, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseEqIntScalarModule())
|
||||
def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(2, 4, (5, 8)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.eq(x, 2)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseEqDiffWidthScalarModule())
|
||||
def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(2, 4, (5, 8)).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseEqFloatTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.eq(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseEqFloatTensorModule())
|
||||
def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||
torch.tensor([1.0, 2.4, 6.0]).to(torch.float32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseEqIntTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.eq(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule())
|
||||
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, )))
|
||||
|
|
@ -42,6 +42,7 @@ from . import matmul
|
|||
from . import reshape_like
|
||||
from . import scalar
|
||||
from . import scalar_comparison
|
||||
from . import elementwise_comparison
|
||||
from . import squeeze
|
||||
from . import slice_like
|
||||
from . import nll_loss
|
||||
|
|
|
@ -904,6 +904,36 @@ def Torch_AtenLt_ScalarOp : Torch_Op<"aten.lt_.Scalar", [
|
|||
let assemblyFormat = "$self `,` $other attr-dict `:` qualified(type($self)) `,` qualified(type($other)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenLeScalarOp : Torch_Op<"aten.le.Scalar", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::le.Scalar : (Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $other attr-dict `:` qualified(type($self)) `,` qualified(type($other)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenLe_ScalarOp : Torch_Op<"aten.le_.Scalar", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::le_.Scalar : (Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $other attr-dict `:` qualified(type($self)) `,` qualified(type($other)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenFmodScalarOp : Torch_Op<"aten.fmod.Scalar", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -1990,6 +1990,36 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
|
||||
Type dtype = geScalar.self().getType().cast<BaseTensorType>().getDtype();
|
||||
|
||||
// TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code that
|
||||
// can be refactored.
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
if (dtype.isa<mlir::FloatType>())
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE,
|
||||
payloadArgs[0], otherPromoted);
|
||||
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
|
||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||
// TODO: Promote tensor args from integer to float.
|
||||
geScalar.emitError(
|
||||
"unimplemented: type promotion from tensor to scalar.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (intType.isUnsigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
|
||||
payloadArgs[0], otherPromoted);
|
||||
if (intType.isSigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
|
||||
payloadArgs[0], otherPromoted);
|
||||
}
|
||||
geScalar.emitError("unimplemented: dtype isn't supported.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
|
||||
Type dtype = eqScalar.self().getType().cast<BaseTensorType>().getDtype();
|
||||
Value otherPromoted =
|
||||
|
@ -2040,6 +2070,34 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
|
||||
Type dtype = leScalar.self().getType().cast<BaseTensorType>().getDtype();
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
// TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code that
|
||||
// can be refactored.
|
||||
if (dtype.isa<mlir::FloatType>())
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
|
||||
payloadArgs[0], otherPromoted);
|
||||
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
|
||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||
// TODO: Promote tensor operand from integer to float.
|
||||
leScalar.emitError(
|
||||
"unimplemented: type promotion from tensor to scalar");
|
||||
return nullptr;
|
||||
}
|
||||
if (intType.isUnsigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
|
||||
payloadArgs[0], otherPromoted);
|
||||
if (intType.isSigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
|
||||
payloadArgs[0], otherPromoted);
|
||||
}
|
||||
leScalar.emitError("unimplemented: dtype isn't supported.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
|
||||
Type dtype = converter->convertType(whereSelf.getType())
|
||||
.cast<RankedTensorType>()
|
||||
|
@ -2455,10 +2513,11 @@ struct ConvertElementwiseOp : ConversionPattern {
|
|||
AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
|
||||
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
|
||||
AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
|
||||
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp,
|
||||
AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
|
||||
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
|
||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp>(op))
|
||||
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||
AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
|
||||
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||
AtenThresholdBackwardOp, AtenCloneOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
|
@ -4575,9 +4634,9 @@ public:
|
|||
AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
|
||||
AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
|
||||
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
|
||||
AtenEqScalarOp, AtenLtScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
|
||||
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
|
||||
AtenThresholdBackwardOp, AtenCloneOp>();
|
||||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
||||
AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
|
||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSqueezeOp>();
|
||||
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
|
||||
|
|
|
@ -235,7 +235,7 @@ public:
|
|||
|
||||
// These comparison ops return a tensor with 1-bit integer dtype.
|
||||
if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
|
||||
AtenNeScalarOp>(op)) {
|
||||
AtenLeScalarOp, AtenNeScalarOp>(op)) {
|
||||
auto operand = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
|
|
|
@ -475,6 +475,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
"aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::le.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
|
||||
"aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
|
||||
|
|
Loading…
Reference in New Issue