[MLIR][TORCH] Fix indentation and spacing for E2E tests

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
pull/1563/merge snapshot-20221124.667
Vivek Khandelwal 2022-11-23 12:26:06 +05:30
parent e45ad313d4
commit 3790a4270e
32 changed files with 2496 additions and 1081 deletions

View File

@ -10,6 +10,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"QuantizedMLP_basic",
}
def register_all_tests():
"""Registers all the built-in E2E tests that Torch-MLIR provides."""
# Side-effecting import statements.

View File

@ -13,6 +13,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
class ArangeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -20,16 +21,17 @@ class ArangeIntModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(5)
@register_test_case(module_factory=lambda: ArangeIntModule())
def ArangeIntModule_basic(module, tu: TestUtils):
module.forward()
class ArangeFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -37,16 +39,17 @@ class ArangeFloatModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(5.0)
@register_test_case(module_factory=lambda: ArangeFloatModule())
def ArangeFloatModule_basic(module, tu: TestUtils):
module.forward()
class ArangeZeroElementOutputModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -54,16 +57,17 @@ class ArangeZeroElementOutputModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(0)
@register_test_case(module_factory=lambda: ArangeZeroElementOutputModule())
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
module.forward()
class ArangeStartIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -71,16 +75,17 @@ class ArangeStartIntModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(0, 5)
@register_test_case(module_factory=lambda: ArangeStartIntModule())
def ArangeStartIntModule_basic(module, tu: TestUtils):
module.forward()
class ArangeStartFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -88,16 +93,17 @@ class ArangeStartFloatModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(0.0, 5.0)
@register_test_case(module_factory=lambda: ArangeStartFloatModule())
def ArangeStartFloatModule_basic(module, tu: TestUtils):
module.forward()
class ArangeNegativeStartIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -105,16 +111,17 @@ class ArangeNegativeStartIntModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(-10, 5)
@register_test_case(module_factory=lambda: ArangeNegativeStartIntModule())
def ArangeNegativeStartIntModule_basic(module, tu: TestUtils):
module.forward()
class ArangeNegativeStartFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -122,16 +129,17 @@ class ArangeNegativeStartFloatModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(-1.4, 5.7)
@register_test_case(module_factory=lambda: ArangeNegativeStartFloatModule())
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
module.forward()
class ArangeStartStepIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -139,16 +147,17 @@ class ArangeStartStepIntModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(0, 5, 1)
@register_test_case(module_factory=lambda: ArangeStartStepIntModule())
def ArangeStartStepIntModule_basic(module, tu: TestUtils):
module.forward()
class ArangeStartStepFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -156,16 +165,17 @@ class ArangeStartStepFloatModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(-1, 5, 1.3)
@register_test_case(module_factory=lambda: ArangeStartStepFloatModule())
def ArangeStartStepFloatModule_basic(module, tu: TestUtils):
module.forward()
class ArangeStartNegativeStepIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -173,16 +183,17 @@ class ArangeStartNegativeStepIntModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(10, 1, -2)
@register_test_case(module_factory=lambda: ArangeStartNegativeStepIntModule())
def ArangeStartNegativeStepIntModule_basic(module, tu: TestUtils):
module.forward()
class ArangeStartNegativeStepFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -190,16 +201,18 @@ class ArangeStartNegativeStepFloatModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(-1, -15, -3.4)
@register_test_case(module_factory=lambda: ArangeStartNegativeStepFloatModule())
@register_test_case(
module_factory=lambda: ArangeStartNegativeStepFloatModule())
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
module.forward()
class ArangeDtypeFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -207,16 +220,17 @@ class ArangeDtypeFloatModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(-1, 15, dtype=torch.float32)
@register_test_case(module_factory=lambda: ArangeDtypeFloatModule())
def ArangeDtypeFloatModule_basic(module, tu: TestUtils):
module.forward()
class ArangeDtypeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -224,16 +238,17 @@ class ArangeDtypeIntModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(0.2, 5.0, dtype=torch.int64)
@register_test_case(module_factory=lambda: ArangeDtypeIntModule())
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
module.forward()
class ArangeFalsePinMemoryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -241,10 +256,10 @@ class ArangeFalsePinMemoryModule(torch.nn.Module):
@annotate_args([
None,
])
def forward(self):
return torch.arange(5.0, dtype=torch.int64, pin_memory=False)
@register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
module.forward()

View File

@ -10,7 +10,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class ArgmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -19,7 +21,6 @@ class ArgmaxModule(torch.nn.Module):
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.argmax(a)
@ -28,27 +29,35 @@ class ArgmaxModule(torch.nn.Module):
def ArgmaxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ArgmaxWithDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.argmax(a, dim=1)
@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
def ArgmaxModule_with_dim(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class ArgmaxKeepDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -60,6 +69,7 @@ class ArgmaxKeepDimsModule(torch.nn.Module):
def forward(self, a):
return torch.argmax(a, 0, True)
@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
def ArgmaxModule_keepDim(module, tu: TestUtils):
module.forward(tu.rand(4, 6))

View File

@ -20,8 +20,10 @@ class SoftmaxBackwardModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, grad_output, output):
return torch.ops.aten._softmax_backward_data(grad_output,
@ -58,6 +60,7 @@ def TanhBackward_basic(module, tu: TestUtils):
# ==============================================================================
class ConvolutionBackwardModule2D(torch.nn.Module):
def __init__(self):
@ -66,9 +69,18 @@ class ConvolutionBackwardModule2D(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, grad_out, input_vec, weight):
return torch.ops.aten.convolution_backward(
@ -100,9 +112,18 @@ class ConvolutionBackwardModule2DPadded(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, grad_out, input_vec, weight):
return torch.ops.aten.convolution_backward(
@ -157,8 +178,10 @@ class LogSoftmaxBackwardModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, grad_output, output):
return torch.ops.aten._log_softmax_backward_data(grad_output,

View File

@ -49,8 +49,10 @@ class BmmModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, lhs, rhs):
return torch.bmm(lhs, rhs)
@ -109,14 +111,16 @@ def IsFloatingPointFloat_True(module, tu: TestUtils):
class ContainsIntList(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None
])
@annotate_args([None])
def forward(self):
return torch.ops.aten.__contains__([1, 2, 3], 3)
@register_test_case(module_factory=lambda: ContainsIntList())
def ContainsIntList_True(module, tu: TestUtils):
module.forward()
@ -126,14 +130,16 @@ def ContainsIntList_True(module, tu: TestUtils):
class ContainsIntListFalse(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None
])
@annotate_args([None])
def forward(self):
return torch.ops.aten.__contains__([1, 2, 3], 4)
@register_test_case(module_factory=lambda: ContainsIntListFalse())
def ContainsIntList_False(module, tu: TestUtils):
module.forward()
@ -320,7 +326,10 @@ class FlattenDynamicModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, 9, 3, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
9, 3, -9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return self.flat(x)
@ -365,7 +374,10 @@ class PadModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
pad = [0, 1, 2, 3]
@ -389,7 +401,10 @@ class PadWithNoneValModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
pad = [0, 1, 2, 3]
@ -413,7 +428,10 @@ class ConstantPadNdModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808, -9223372036854775808, -9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.constant_pad_nd(x, (0, 1), -float('inf'))
@ -457,7 +475,8 @@ class ConstantPadNdPartialStaticModule(torch.nn.Module):
@export
@annotate_args([
None,
([1, 1, 20, 20, -9223372036854775808, -9223372036854775808], torch.float32, True),
([1, 1, 20, 20, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.constant_pad_nd(x, (0, 1, 2, 3), -float('inf'))
@ -580,9 +599,12 @@ class TensorsConcatModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x, y, z):
return torch.cat([x, y, z], 1)
@ -604,9 +626,12 @@ class TensorsConcatNegativeDimModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x, y, z):
return torch.cat([x, y, z], dim=-2)
@ -628,8 +653,10 @@ class GatherModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, tensor, indices):
return torch.gather(tensor, 2, indices)
@ -644,18 +671,22 @@ def GatherModule_basic(module, tu: TestUtils):
class GatherRandomIndexModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, tensor, indices):
return torch.gather(tensor, 1, indices)
@register_test_case(module_factory=lambda: GatherRandomIndexModule())
def GatherRandomIndexModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4), tu.randint(2, 3, 4, high=3))
@ -665,6 +696,7 @@ def GatherRandomIndexModule_basic(module, tu: TestUtils):
class Gather2DInputModdule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -677,6 +709,7 @@ class Gather2DInputModdule(torch.nn.Module):
def forward(self, tensor, indices):
return torch.gather(tensor, 1, indices)
@register_test_case(module_factory=lambda: Gather2DInputModdule())
def Gather2DInputModdule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5), torch.tensor([[1, 2, 3], [4, 3, 2]]))
@ -809,6 +842,7 @@ def EmbeddingModuleI32_basic(module, tu: TestUtils):
# ==============================================================================
class EmbeddingModuleI32Static(torch.nn.Module):
def __init__(self):
@ -870,7 +904,8 @@ class SoftmaxIntModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, tensor):
return self.softmax.forward(tensor)
@ -892,7 +927,8 @@ class _SoftmaxModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, tensor):
return torch.ops.aten._softmax(tensor, 0, False)
@ -916,7 +952,8 @@ class SoftmaxIntNegDimModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, tensor):
return self.softmax.forward(tensor)
@ -940,7 +977,8 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, tensor):
return self.softmax.forward(tensor)
@ -962,7 +1000,8 @@ class _LogSoftmaxModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, tensor):
return torch.ops.aten._log_softmax(tensor, dim=0, half_to_float=False)
@ -1129,10 +1168,12 @@ class BroadcastZeroRankInputStaticModule(torch.nn.Module):
return torch.ops.aten.sub(x, y)
@register_test_case(module_factory=lambda: BroadcastZeroRankInputStaticModule())
@register_test_case(
module_factory=lambda: BroadcastZeroRankInputStaticModule())
def BroadcastZeroRankInputStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 8), tu.rand())
# ==============================================================================
@ -1154,6 +1195,7 @@ class RollModule(torch.nn.Module):
def RollModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 2))
# ==============================================================================
@ -1175,6 +1217,7 @@ class RepeatModule(torch.nn.Module):
def RepeatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 2))
# ==============================================================================
@ -1231,7 +1274,8 @@ class LogSoftmaxIntModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, tensor):
return self.log_softmax.forward(tensor)
@ -1479,7 +1523,8 @@ class NumelModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, input):
return torch.ops.aten.numel(input)
@ -1772,7 +1817,8 @@ class IndexTensorModule3dInput(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808], torch.int64, True),
])
def forward(self, x, index):
@ -1795,7 +1841,8 @@ class IndexTensorSelectDimModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808], torch.int64, True),
])
def forward(self, a, ind):
@ -1806,6 +1853,7 @@ class IndexTensorSelectDimModule(torch.nn.Module):
def IndexTensorSelectDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 6), tu.randint(2, 3, high=3))
# ==============================================================================
@ -1817,17 +1865,22 @@ class IndexTensorMultiInput(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([3, 3], torch.int64, True),
([3], torch.int64, True),
])
def forward(self, x, index1, index2):
return torch.ops.aten.index(x, (index1, index2,))
return torch.ops.aten.index(x, (
index1,
index2,
))
@register_test_case(module_factory=lambda: IndexTensorMultiInput())
def IndexTensorMultiInput_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3), tu.randint(3, 3, high=3), tu.randint(3, high=3))
module.forward(tu.rand(5, 4, 3), tu.randint(3, 3, high=3),
tu.randint(3, high=3))
# ==============================================================================
@ -1841,17 +1894,22 @@ class IndexTensorMultiInputOneDim(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([6, 1], torch.int64, True),
([3], torch.int64, True),
])
def forward(self, x, index1, index2):
return torch.ops.aten.index(x, (index1, index2,))
return torch.ops.aten.index(x, (
index1,
index2,
))
@register_test_case(module_factory=lambda: IndexTensorMultiInputOneDim())
def IndexTensorMultiInputOneDim_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4), tu.randint(3, high=3))
module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4),
tu.randint(3, high=3))
# ==============================================================================
@ -1865,7 +1923,8 @@ class IndexTensorMultiInputContiguousOneDimDynamic(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, 1], torch.int64, True),
([-9223372036854775808], torch.int64, True),
])
@ -1895,7 +1954,8 @@ class IndexTensorMultiInputNonContiguousOneDimDynamic(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, 1], torch.int64, True),
([-9223372036854775808], torch.int64, True),
])
@ -1926,7 +1986,8 @@ class IndexTensorMultiInputNonContiguousDynamic(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, 2], torch.int64, True),
([-9223372036854775808], torch.int64, True),
])
@ -1956,7 +2017,10 @@ class IndexTensorMultiInputNonContiguousMultipleStaticDims(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([4, 1], torch.int64, True),
([1, 3], torch.int64, True),
([-9223372036854775808, 3], torch.int64, True),
@ -1984,7 +2048,10 @@ class IndexTensorMultiInputNonContiguous(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([4, 2], torch.int64, True),
([4, 2], torch.int64, True),
])
@ -1992,9 +2059,11 @@ class IndexTensorMultiInputNonContiguous(torch.nn.Module):
return torch.ops.aten.index(x, (index1, None, index2))
@register_test_case(module_factory=lambda: IndexTensorMultiInputNonContiguous())
@register_test_case(
module_factory=lambda: IndexTensorMultiInputNonContiguous())
def IndexTensorMultiInputNonContiguous_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 2, high=3), tu.randint(4, 2, high=1))
module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 2, high=3),
tu.randint(4, 2, high=1))
# ==============================================================================
@ -2008,21 +2077,24 @@ class IndexTensorMultiInputThreeIndexers(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808, -9223372036854775808, -9223372036854775808
], torch.float32, True),
([8, 4, 2], torch.int64, True),
([8, 1, 1], torch.int64, True),
([4, 2], torch.int64, True),
])
def forward(self, x, index1, index2, index3):
return torch.ops.aten.index(x, (None, None, index1, None, index2, index3))
return torch.ops.aten.index(x,
(None, None, index1, None, index2, index3))
@register_test_case(module_factory=lambda: IndexTensorMultiInputThreeIndexers())
@register_test_case(
module_factory=lambda: IndexTensorMultiInputThreeIndexers())
def IndexTensorMultiInputThreeIndexers_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 2, 4, 4, 5, 3),
tu.randint(8, 4, 2, high=3),
tu.randint(8, 1, 1, high=4),
tu.randint(4, 2, high=2))
module.forward(tu.rand(1, 2, 4, 4, 5, 3), tu.randint(8, 4, 2, high=3),
tu.randint(8, 1, 1, high=4), tu.randint(4, 2, high=2))
# ==============================================================================
@ -2036,7 +2108,10 @@ class IndexTensorMultiInputContiguousCenter(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([2, 2], torch.int64, True),
([2], torch.int64, True),
])
@ -2044,9 +2119,11 @@ class IndexTensorMultiInputContiguousCenter(torch.nn.Module):
return torch.ops.aten.index(x, (None, index1, index2, None))
@register_test_case(module_factory=lambda: IndexTensorMultiInputContiguousCenter())
@register_test_case(
module_factory=lambda: IndexTensorMultiInputContiguousCenter())
def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3, 2), tu.randint(2, 2, high=3), tu.randint(2, high=2))
module.forward(tu.rand(5, 4, 3, 2), tu.randint(2, 2, high=3),
tu.randint(2, high=2))
# ==============================================================================
@ -2083,7 +2160,8 @@ class IndexTensorHackedTwinModule3dInput(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808], torch.int64, True),
])
def forward(self, x, index):
@ -2108,7 +2186,10 @@ class IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims(
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([4, 1], torch.int64, True),
([1, 3], torch.int64, True),
([-9223372036854775808, 3], torch.int64, True),
@ -2137,7 +2218,8 @@ class SquareModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.square(x)
@ -2336,7 +2418,8 @@ class ExpandAsFloatModule(torch.nn.Module):
@annotate_args([
None,
([-9223372036854775808, 1, 1], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x, y):
return torch.ops.aten.expand_as(x, y)
@ -2356,7 +2439,8 @@ class ExpandAsIntModule(torch.nn.Module):
@annotate_args([
None,
([1, 1, 1], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, x, y):
return torch.ops.aten.expand_as(x, y)
@ -2364,8 +2448,8 @@ class ExpandAsIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ExpandAsIntModule())
def ExpandAsIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(1, 1, 1, high=100),
tu.randint(4, 5, 6, high=200))
module.forward(tu.randint(1, 1, 1, high=100), tu.randint(4, 5, 6,
high=200))
# ==============================================================================
@ -2379,8 +2463,10 @@ class CopyModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x, y):
return torch.ops.aten.copy_(x, y)
@ -2419,8 +2505,10 @@ class CopyWithDifferentDTypesModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x, y):
return torch.ops.aten.copy_(x, y)
@ -2463,7 +2551,8 @@ class ToCopyModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten._to_copy(x)
@ -2482,7 +2571,8 @@ class ToCopyWithDTypeModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten._to_copy(x, dtype=torch.int64)
@ -2501,7 +2591,8 @@ class ToCopyWithDTypeFalsePinMemoryModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten._to_copy(x, dtype=torch.int64, pin_memory=False)
@ -2543,7 +2634,8 @@ class FlipModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.flip(x, [1, 2])
@ -2565,7 +2657,8 @@ class DetachModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.detach(x)
@ -2650,9 +2743,12 @@ class BaddbmmDynamicModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, input, batch1, batch2):
return torch.ops.aten.baddbmm(input, batch1, batch2)
@ -2692,9 +2788,12 @@ class BaddbmmDifferentDtypesModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, input, batch1, batch2):
return torch.ops.aten.baddbmm(input, batch1, batch2)
@ -2714,9 +2813,12 @@ class BaddbmmWithAlphaModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, input, batch1, batch2):
return torch.ops.aten.baddbmm(input, batch1, batch2, alpha=5)
@ -2735,9 +2837,12 @@ class BaddbmmWithBetaModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, input, batch1, batch2):
return torch.ops.aten.baddbmm(input, batch1, batch2, beta=0.5)
@ -2756,9 +2861,12 @@ class BaddbmmWithAlphaBetaModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, input, batch1, batch2):
return torch.ops.aten.baddbmm(input, batch1, batch2, beta=6, alpha=2.4)
@ -2841,7 +2949,10 @@ class NumpyTRankNDynamicModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808, -9223372036854775808
], torch.float32, True),
])
def forward(self, lhs):
return torch.ops.aten.numpy_T(lhs)
@ -2908,6 +3019,7 @@ class NumpyTRank0Module(torch.nn.Module):
def NumpyTRank0Module_basic(module, tu: TestUtils):
module.forward(torch.tensor(7, dtype=torch.float32))
class AtenEmbeddingBagSumExample(torch.nn.Module):
def __init__(self):
@ -2921,15 +3033,26 @@ class AtenEmbeddingBagSumExample(torch.nn.Module):
([-9223372036854775808], torch.int64, True),
])
def forward(self, weight, indices, offsets):
return torch.ops.aten.embedding_bag(weight, indices, offsets, scale_grad_by_freq=False, mode=0, sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None)
return torch.ops.aten.embedding_bag(weight,
indices,
offsets,
scale_grad_by_freq=False,
mode=0,
sparse=False,
per_sample_weights=None,
include_last_offset=False,
padding_idx=None)
@register_test_case(module_factory=lambda: AtenEmbeddingBagSumExample())
def AtenEmbeddingBagSumExample_basic(module, tu: TestUtils):
weight = torch.rand(100, 10)
indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
indices = torch.LongTensor(
[0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15])
module.forward(weight, indices, offsets)
class Aten_EmbeddingBagExample(torch.nn.Module):
def __init__(self):
@ -2945,15 +3068,19 @@ class Aten_EmbeddingBagExample(torch.nn.Module):
def forward(self, weight, indices, offsets):
return torch.ops.aten._embedding_bag(weight, indices, offsets)
@register_test_case(module_factory=lambda: Aten_EmbeddingBagExample())
def Aten_EmbeddingBagExample_basic(module, tu: TestUtils):
weight = torch.rand(100, 10)
indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
indices = torch.LongTensor(
[0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15])
module.forward(weight, indices, offsets)
# ==============================================================================
class CumsumModule(torch.nn.Module):
def __init__(self):
@ -2962,15 +3089,18 @@ class CumsumModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, val):
return torch.ops.aten.cumsum(val, 1)
@register_test_case(module_factory=lambda: CumsumModule())
def CumsumModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 7, 4))
class CumsumStaticModule(torch.nn.Module):
def __init__(self):
@ -2984,13 +3114,17 @@ class CumsumStaticModule(torch.nn.Module):
def forward(self, val):
return torch.ops.aten.cumsum(val, 1)
@register_test_case(module_factory=lambda: CumsumStaticModule())
def CumsumStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 7, 4))
# ==============================================================================
class AtenToDeviceModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2999,14 +3133,18 @@ class AtenToDeviceModule(torch.nn.Module):
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, val):
return torch.ops.aten.to(val, device='cpu', dtype=torch.float, non_blocking=False)
return torch.ops.aten.to(val,
device='cpu',
dtype=torch.float,
non_blocking=False)
@register_test_case(module_factory=lambda: AtenToDeviceModule())
def AtenToDeviceModule_basic(module, tu: TestUtils):
module.forward(torch.randn(2, 4))
# ==============================================================================
@ -3018,10 +3156,14 @@ class UpSampleNearest2dBackward(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float64, True),
])
def forward(self, input):
return torch.ops.aten.upsample_nearest2d_backward(input,
return torch.ops.aten.upsample_nearest2d_backward(
input,
output_size=[6, 12],
input_size=[1, 1, 2, 3],
scales_h=3.0,
@ -3041,16 +3183,22 @@ class UpSampleNearest2dBackwardScalesNone(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, input):
return torch.ops.aten.upsample_nearest2d_backward(input,
return torch.ops.aten.upsample_nearest2d_backward(
input,
output_size=[4, 8],
input_size=[1, 1, 2, 3],
scales_h=None,
scales_w=None)
@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardScalesNone())
@register_test_case(
module_factory=lambda: UpSampleNearest2dBackwardScalesNone())
def UpSampleNearest2dBackwardScalesNone_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 4, 8))

View File

@ -11,7 +11,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class TensorToIntZeroRank(torch.nn.Module):
def __init__(self):
super().__init__()
@ -28,9 +30,12 @@ class TensorToIntZeroRank(torch.nn.Module):
def TensorToIntZeroRank_basic(module, tu: TestUtils):
module.forward(tu.randint(high=10))
# ==============================================================================
class TensorToInt(torch.nn.Module):
def __init__(self):
super().__init__()
@ -47,9 +52,12 @@ class TensorToInt(torch.nn.Module):
def TensorToInt_basic(module, tu: TestUtils):
module.forward(tu.randint(1, 1, high=10))
# ==============================================================================
class TensorToFloatZeroRank(torch.nn.Module):
def __init__(self):
super().__init__()
@ -66,9 +74,12 @@ class TensorToFloatZeroRank(torch.nn.Module):
def TensorToFloatZeroRank_basic(module, tu: TestUtils):
module.forward(torch.rand((), dtype=torch.float64))
# ==============================================================================
class TensorToFloat(torch.nn.Module):
def __init__(self):
super().__init__()
@ -85,9 +96,12 @@ class TensorToFloat(torch.nn.Module):
def TensorToFloat_basic(module, tu: TestUtils):
module.forward(torch.rand((1, 1), dtype=torch.float64))
# ==============================================================================
class TensorToBoolZeroRank(torch.nn.Module):
def __init__(self):
super().__init__()
@ -104,9 +118,12 @@ class TensorToBoolZeroRank(torch.nn.Module):
def TensorToBoolZeroRank_basic(module, tu: TestUtils):
module.forward(torch.tensor(1, dtype=torch.bool))
# ==============================================================================
class TensorToBool(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -357,7 +357,10 @@ class EmptyLikeMemoryFormatModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, a):
return torch.empty_like(a,
@ -396,7 +399,8 @@ class EmptyLikeFalsePinMemoryModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.empty_like(a, dtype=torch.float64,
@ -476,7 +480,8 @@ class ZerosLikeFalsePinMemoryModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.zeros_like(a, dtype=torch.float64, pin_memory=False)
@ -555,7 +560,8 @@ class OnesLikeFalsePinMemoryModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ones_like(a, dtype=torch.float64, pin_memory=False)
@ -596,7 +602,8 @@ class NewZerosModuleInt2D(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.int64)
@ -634,7 +641,8 @@ class NewZerosModuleFloat2D(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.float32)
@ -715,7 +723,8 @@ class NewOnesModuleInt2D(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.int64)
@ -753,7 +762,8 @@ class NewOnesModuleFloat2D(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.float32)
@ -967,7 +977,8 @@ class FullLikeModuleInt3D(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int32, True),
])
def forward(self, a):
return torch.ops.aten.full_like(a, 5.0, dtype=torch.int64)
@ -1024,7 +1035,8 @@ class FullLikeModuleFloat3D(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.full_like(a, 15, dtype=torch.float32)
@ -1166,7 +1178,8 @@ class NewEmptyModuleInt2D(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.int64).fill_(0)
@ -1205,7 +1218,8 @@ class NewEmptyModuleFloat2D(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.new_empty(a, [3, 4],

View File

@ -13,14 +13,15 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class TorchPrimLoopForLikeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808], torch.int64, True)
None, ([-9223372036854775808, -9223372036854775808], torch.int64, True)
])
def forward(self, x):
x_val = x.size(0)
@ -34,15 +35,16 @@ class TorchPrimLoopForLikeModule(torch.nn.Module):
def TorchPrimLoopForLikeModule_basic(module, tu: TestUtils):
module.forward(tu.randint(6, 8, high=10))
# ==============================================================================
class TorchPrimLoopWhileLikeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808], torch.int64, True)
None, ([-9223372036854775808, -9223372036854775808], torch.int64, True)
])
def forward(self, x):
x_val = x.size(0)

View File

@ -22,7 +22,10 @@ class Conv2dNoPaddingModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return self.conv(x)
@ -45,7 +48,10 @@ class Conv2dBiasNoPaddingModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return self.conv(x)
@ -68,7 +74,10 @@ class Conv2dWithPaddingModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return self.conv(x)
@ -97,7 +106,10 @@ class Conv2dWithPaddingDilationStrideModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return self.conv(x)
@ -142,15 +154,23 @@ def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
# ==============================================================================
class Convolution2DModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
@ -163,12 +183,14 @@ class Convolution2DModule(torch.nn.Module):
output_padding=[0, 0],
groups=1)
@register_test_case(module_factory=lambda: Convolution2DModule())
def Convolution2DModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class Convolution2DStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -189,19 +211,28 @@ class Convolution2DStaticModule(torch.nn.Module):
output_padding=[0, 0],
groups=1)
@register_test_case(module_factory=lambda: Convolution2DStaticModule())
def Convolution2DStaticModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class Convolution2DStridedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
@ -214,19 +245,28 @@ class Convolution2DStridedModule(torch.nn.Module):
output_padding=[0, 0],
groups=1)
@register_test_case(module_factory=lambda: Convolution2DStridedModule())
def Convolution2DStridedModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class _Convolution2DAllFalseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
@ -243,19 +283,28 @@ class _Convolution2DAllFalseModule(torch.nn.Module):
cudnn_enabled=False,
allow_tf32=False)
@register_test_case(module_factory=lambda: _Convolution2DAllFalseModule())
def _Convolution2DAllFalseModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class _Convolution2DBenchmarkModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
@ -272,19 +321,28 @@ class _Convolution2DBenchmarkModule(torch.nn.Module):
cudnn_enabled=False,
allow_tf32=False)
@register_test_case(module_factory=lambda: _Convolution2DBenchmarkModule())
def _Convolution2DBenchmarkModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class _Convolution2DDeterministicModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
@ -301,19 +359,28 @@ class _Convolution2DDeterministicModule(torch.nn.Module):
cudnn_enabled=False,
allow_tf32=False)
@register_test_case(module_factory=lambda: _Convolution2DDeterministicModule())
def _Convolution2DDeterministicModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class _Convolution2DCudnnModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
@ -330,19 +397,28 @@ class _Convolution2DCudnnModule(torch.nn.Module):
cudnn_enabled=True,
allow_tf32=False)
@register_test_case(module_factory=lambda: _Convolution2DCudnnModule())
def _Convolution2DCudnnModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class _Convolution2DTF32Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
@ -359,19 +435,28 @@ class _Convolution2DTF32Module(torch.nn.Module):
cudnn_enabled=False,
allow_tf32=True)
@register_test_case(module_factory=lambda: _Convolution2DTF32Module())
def _Convolution2DTF32Module_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
@ -387,19 +472,29 @@ class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module):
deterministic=False,
cudnn_enabled=False)
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule())
@register_test_case(
module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule())
def _ConvolutionDeprecated2DAllFalseModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
@ -415,19 +510,29 @@ class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module):
deterministic=False,
cudnn_enabled=False)
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule())
@register_test_case(
module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule())
def _ConvolutionDeprecated2DBenchmarkModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
@ -443,19 +548,29 @@ class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module):
deterministic=True,
cudnn_enabled=False)
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule())
@register_test_case(
module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule())
def _ConvolutionDeprecated2DDeterministicModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten._convolution(inputVec,
@ -471,19 +586,29 @@ class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
deterministic=False,
cudnn_enabled=True)
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
@register_test_case(
module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
def _ConvolutionDeprecated2DCudnnModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
class ConvolutionModule2DGroups(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
@ -496,12 +621,15 @@ class ConvolutionModule2DGroups(torch.nn.Module):
output_padding=[0, 0],
groups=4)
@register_test_case(module_factory=lambda: ConvolutionModule2DGroups())
def ConvolutionModule2DGroups_basic(module, tu: TestUtils):
module.forward(torch.randn(1, 32, 4, 4), torch.randn(32, 8, 3, 3))
# ==============================================================================
class ConvolutionModule2DTranspose(torch.nn.Module):
def __init__(self):
@ -510,8 +638,14 @@ class ConvolutionModule2DTranspose(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
@ -529,6 +663,7 @@ class ConvolutionModule2DTranspose(torch.nn.Module):
def ConvolutionModule2DTranspose_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3, 4, 4), torch.randn(3, 3, 2, 2))
class ConvolutionModule2DTransposeStrided(torch.nn.Module):
def __init__(self):
@ -537,8 +672,14 @@ class ConvolutionModule2DTransposeStrided(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
@ -552,10 +693,12 @@ class ConvolutionModule2DTransposeStrided(torch.nn.Module):
groups=1)
@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStrided())
@register_test_case(
module_factory=lambda: ConvolutionModule2DTransposeStrided())
def ConvolutionModule2DTransposeStrided_basic(module, tu: TestUtils):
module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2))
class ConvolutionModule2DTransposeStridedStatic(torch.nn.Module):
def __init__(self):
@ -579,7 +722,8 @@ class ConvolutionModule2DTransposeStridedStatic(torch.nn.Module):
groups=1)
@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStridedStatic())
@register_test_case(
module_factory=lambda: ConvolutionModule2DTransposeStridedStatic())
def ConvolutionModule2DTransposeStridedStatic_basic(module, tu: TestUtils):
module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2))
@ -592,8 +736,14 @@ class Conv_Transpose2dModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec, weight):
return torch.ops.aten.conv_transpose2d(inputVec,
@ -619,7 +769,10 @@ class UpSampleNearest2d(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float64, True),
])
def forward(self, input):
return torch.ops.aten.upsample_nearest2d(input,
@ -632,6 +785,7 @@ class UpSampleNearest2d(torch.nn.Module):
def UpSampleNearest2d_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))
class UpSampleNearest2dSameSize(torch.nn.Module):
def __init__(self):
@ -640,7 +794,10 @@ class UpSampleNearest2dSameSize(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec):
return torch._C._nn.upsample_nearest2d(inputVec,
@ -660,7 +817,13 @@ class UpSampleNearest2dDiffSize(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True)])
@annotate_args([
None,
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True)
])
def forward(self, inputVec):
return torch._C._nn.upsample_nearest2d(inputVec,
output_size=[8, 11],
@ -679,7 +842,13 @@ class UpSampleNearest2dDiffFactor(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True)])
@annotate_args([
None,
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True)
])
def forward(self, inputVec):
return torch._C._nn.upsample_nearest2d(inputVec,
output_size=[6, 10],
@ -700,7 +869,10 @@ class UpSampleNearest2dSameFactor(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, inputVec):
return torch._C._nn.upsample_nearest2d(inputVec,

View File

@ -17,7 +17,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# the PyTorch op registry permanently.
import torch_mlir._torch_mlir_custom_op_example
class CustomOpExampleModule(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -118,7 +118,8 @@ class ElementwiseTernaryModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
])
@ -152,7 +153,8 @@ class ElementwiseAtenWhereSelfModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseAtenWhereSelfModule())
def ElementwiseAtenWhereSelfModule_basic(module, tu: TestUtils):
module.forward(torch.zeros(1, 1, 5, 5, dtype=torch.bool), torch.rand(1, 12, 5, 5), torch.rand(()))
module.forward(torch.zeros(1, 1, 5, 5, dtype=torch.bool),
torch.rand(1, 12, 5, 5), torch.rand(()))
# ==============================================================================
@ -166,7 +168,8 @@ class ElementwiseWhereSelfModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
])
@ -190,7 +193,8 @@ class ElementwiseWhereScalarModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.where(a > 0.5, 4.0, 8.0)
@ -212,7 +216,8 @@ class ElementwiseWhereScalarOtherModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808], torch.float64, True),
])
def forward(self, a, b):
@ -235,7 +240,8 @@ class ElementwiseWhereScalarSelfModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808], torch.float64, True),
])
def forward(self, a, b):
@ -913,7 +919,8 @@ class ElementwiseAtan2TensorIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntModule())
def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(4, low=1, high=10).type(torch.int32), tu.randint(4, low=1, high=10))
tu.randint(4, low=1, high=10).type(torch.int32),
tu.randint(4, low=1, high=10))
# ==============================================================================
@ -936,7 +943,8 @@ class ElementwiseAtan2FloatIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntModule())
def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(4, 4, low=1, high=10).to(torch.int32),
module.forward(
tu.randint(4, 4, low=1, high=10).to(torch.int32),
tu.rand(4, 4).double())
@ -983,6 +991,7 @@ class ElementwiseLogIntModule(torch.nn.Module):
def ElementwiseLogIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
# ==============================================================================
@ -1200,7 +1209,8 @@ class ElementwisePowTensorBroadcastModule(torch.nn.Module):
return torch.pow(a, b)
@register_test_case(module_factory=lambda: ElementwisePowTensorBroadcastModule())
@register_test_case(
module_factory=lambda: ElementwisePowTensorBroadcastModule())
def ElementwisePowTensorBroadcastModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1), tu.rand(3, 4))
@ -1214,7 +1224,10 @@ class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True)
])
def forward(self, x):
return x.to(torch.int64)
@ -1233,7 +1246,10 @@ class ElementwiseToDtypeIdentityModule(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True)
])
def forward(self, x):
return x.to(torch.float32, False, False)
@ -1342,7 +1358,8 @@ class ElementwiseAbsModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.abs(a)
@ -1418,6 +1435,7 @@ class ElementwiseDivScalarModule(torch.nn.Module):
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
@ -1435,7 +1453,8 @@ class ElementwiseRemainderScalarModule_Int_Float(torch.nn.Module):
return torch.remainder(x, 2.0)
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float())
@register_test_case(
module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float())
def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils):
module.forward(tu.randint(3, high=10).to(torch.int32))
@ -1457,13 +1476,15 @@ class ElementwiseRemainderScalarModule_Float(torch.nn.Module):
return torch.remainder(x, 2.0)
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Float())
@register_test_case(
module_factory=lambda: ElementwiseRemainderScalarModule_Float())
def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils):
module.forward(torch.rand(10, 3))
# ==============================================================================
class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
def __init__(self):
@ -1478,12 +1499,15 @@ class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
return torch.remainder(x, 2)
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int())
@register_test_case(
module_factory=lambda: ElementwiseRemainderScalarModule_Int())
def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 2, high=10).to(torch.int32))
# ==============================================================================
class ElementwiseRemainderScalarModule_Bool(torch.nn.Module):
def __init__(self):
@ -1498,7 +1522,8 @@ class ElementwiseRemainderScalarModule_Bool(torch.nn.Module):
return torch.remainder(x, 2)
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Bool())
@register_test_case(
module_factory=lambda: ElementwiseRemainderScalarModule_Bool())
def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils):
module.forward(torch.tensor([True, False, True, True, True]))
@ -1786,7 +1811,8 @@ class ElementwiseCloneModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.clone(x)
@ -1808,7 +1834,8 @@ class ElementwiseCloneContiguousModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.clone(x, memory_format=torch.contiguous_format)
@ -1830,7 +1857,8 @@ class LiftFreshCopyModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.lift_fresh_copy(x)
@ -2038,9 +2066,12 @@ class ElementwiseNegModule(torch.nn.Module):
def ElementwiseNegModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseAtenLogicalOrOpModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2053,11 +2084,14 @@ class ElementwiseAtenLogicalOrOpModule(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.logical_or(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpModule())
def ElementwiseAtenLogicalOrOpModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([False, True]), torch.tensor([False, False]))
class ElementwiseAtenLogicalOrOpDiffArgs1Module(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2070,13 +2104,18 @@ class ElementwiseAtenLogicalOrOpDiffArgs1Module(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.logical_or(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs1Module())
@register_test_case(
module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs1Module())
def ElementwiseAtenLogicalOrOpDiffArgs1Module_basic(module, tu: TestUtils):
module.forward(torch.tensor([0.2, 0.1]), torch.tensor([0, 1]))
# ==============================================================================
class ElementwiseAtenLogicalOrOpDiffArgs2Module(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2089,13 +2128,18 @@ class ElementwiseAtenLogicalOrOpDiffArgs2Module(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.logical_or(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs2Module())
@register_test_case(
module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs2Module())
def ElementwiseAtenLogicalOrOpDiffArgs2Module_basic(module, tu: TestUtils):
module.forward(torch.tensor([True, False]), torch.tensor([0, 1]))
# ==============================================================================
class ElementwiseAtenLogicalOrOpDiffArgs3Module(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2108,70 +2152,110 @@ class ElementwiseAtenLogicalOrOpDiffArgs3Module(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.logical_or(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs3Module())
@register_test_case(
module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs3Module())
def ElementwiseAtenLogicalOrOpDiffArgs3Module_basic(module, tu: TestUtils):
module.forward(torch.tensor([1, 2]), torch.tensor([False, True]))
# ==============================================================================
class ElementwiseAtenLogicalOrOpRandomModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.int64, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.int64, True),
])
def forward(self, x, y):
return torch.ops.aten.logical_or(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomModule())
@register_test_case(
module_factory=lambda: ElementwiseAtenLogicalOrOpRandomModule())
def ElementwiseAtenLogicalOrOpRandomModule_basic(module, tu: TestUtils):
module.forward(tu.randint(2, 3, 4, 5, low=3, high=10), tu.randint(2, 3, 4, 5, low=10, high=100))
module.forward(tu.randint(2, 3, 4, 5, low=3, high=10),
tu.randint(2, 3, 4, 5, low=10, high=100))
# ==============================================================================
class ElementwiseAtenLogicalOrOpRandomFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x, y):
return torch.ops.aten.logical_or(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomFloatModule())
@register_test_case(
module_factory=lambda: ElementwiseAtenLogicalOrOpRandomFloatModule())
def ElementwiseAtenLogicalOrOpRandomFloatModule_basic(module, tu: TestUtils):
module.forward(torch.rand(2, 3, 3, 5), torch.rand(2, 3, 3, 5))
# ==============================================================================
class ElementwiseAtenLogicalOrOpNegativeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.int64, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.int64, True),
])
def forward(self, x, y):
return torch.ops.aten.logical_or(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpNegativeModule())
@register_test_case(
module_factory=lambda: ElementwiseAtenLogicalOrOpNegativeModule())
def ElementwiseAtenLogicalOrOpNegativeModule_basic(module, tu: TestUtils):
module.forward(torch.neg(tu.randint(2, 3, 4, 5, low=3, high=10)), torch.neg(tu.randint(2, 3, 4, 5, low=10, high=100)))
module.forward(torch.neg(tu.randint(2, 3, 4, 5, low=3, high=10)),
torch.neg(tu.randint(2, 3, 4, 5, low=10, high=100)))
# ==============================================================================
class ElementwiseAtenLogicalOrOpBrodcastModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2184,7 +2268,9 @@ class ElementwiseAtenLogicalOrOpBrodcastModule(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.logical_or(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpBrodcastModule())
@register_test_case(
module_factory=lambda: ElementwiseAtenLogicalOrOpBrodcastModule())
def ElementwiseAtenLogicalOrOpBrodcastModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, high=3), tu.randint(4, 3, high=3))
@ -2244,7 +2330,10 @@ class AtenTriuModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808, -9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.triu(x)
@ -2266,7 +2355,8 @@ class AtenTriuWithPosDiagonalModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.triu(x, diagonal=2)
@ -2288,7 +2378,10 @@ class AtenTriuWithNegDiagonalModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.triu(x, diagonal=-4)
@ -2298,6 +2391,7 @@ class AtenTriuWithNegDiagonalModule(torch.nn.Module):
def AtenTriuWithNegDiagonalModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 5, 9))
# ==============================================================================
@ -2319,6 +2413,7 @@ class AtenRoundFloatModule(torch.nn.Module):
def AtenRoundFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 5, low=-3.0, high=3.0))
class AtenRoundIntModule(torch.nn.Module):
def __init__(self):
@ -2349,7 +2444,8 @@ class Fill_TensorFloat64WithFloat32(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, tensor):
return torch.ops.aten.fill_(tensor, 3.0)
@ -2368,7 +2464,8 @@ class Fill_TensorFloat64WithFloat64(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, tensor):
return torch.ops.aten.fill_(tensor, 3.0)
@ -2387,7 +2484,8 @@ class Fill_TensorFloat64WithInt64(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, tensor):
return torch.ops.aten.fill_(tensor, 3)
@ -2409,12 +2507,14 @@ class Fill_TensorFloat32WithFloat32(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([], torch.float32, True),
])
def forward(self, tensor, value):
return torch.ops.aten.fill_(tensor, value)
@register_test_case(module_factory=lambda: Fill_TensorFloat32WithFloat32())
def Fill_TensorFloat32WithFloat32_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4), tu.rand())
@ -2428,12 +2528,14 @@ class Fill_TensorFloat32WithFloat64(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([], torch.float64, True),
])
def forward(self, tensor, value):
return torch.ops.aten.fill_(tensor, value)
@register_test_case(module_factory=lambda: Fill_TensorFloat32WithFloat64())
def Fill_TensorFloat32WithFloat64_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4), tu.rand().to(torch.float64))
@ -2447,12 +2549,14 @@ class Fill_TensorFloat32WithInt64(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([], torch.int64, True),
])
def forward(self, tensor, value):
return torch.ops.aten.fill_(tensor, value)
@register_test_case(module_factory=lambda: Fill_TensorFloat32WithInt64())
def Fill_TensorFloat32WithInt64_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4), tu.randint())

View File

@ -11,7 +11,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class ElementwiseGtFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -28,9 +30,12 @@ class ElementwiseGtFloatScalarModule(torch.nn.Module):
def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseGtIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -47,9 +52,12 @@ class ElementwiseGtIntScalarModule(torch.nn.Module):
def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=15))
# ==============================================================================
class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -66,9 +74,12 @@ class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
# ==============================================================================
class ElementwiseGeFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -85,9 +96,12 @@ class ElementwiseGeFloatScalarModule(torch.nn.Module):
def ElementwiseGeFloatScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseGeIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -104,9 +118,12 @@ class ElementwiseGeIntScalarModule(torch.nn.Module):
def ElementwiseGeIntScalarModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=15))
# ==============================================================================
class ElementwiseGeMixedIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -123,9 +140,12 @@ class ElementwiseGeMixedIntScalarModule(torch.nn.Module):
def ElementwiseGeMixedIntScalarModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
# ==============================================================================
class ElementwiseGeFloatIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -142,9 +162,12 @@ class ElementwiseGeFloatIntScalarModule(torch.nn.Module):
def ElementwiseGeFloatIntScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseGtFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -162,9 +185,12 @@ class ElementwiseGtFloatTensorModule(torch.nn.Module):
def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5), tu.rand(5))
# ==============================================================================
class ElementwiseGtIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -182,9 +208,12 @@ class ElementwiseGtIntTensorModule(torch.nn.Module):
def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
# ==============================================================================
class ElementwiseLtFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -201,9 +230,12 @@ class ElementwiseLtFloatScalarModule(torch.nn.Module):
def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseLtIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -220,9 +252,12 @@ class ElementwiseLtIntScalarModule(torch.nn.Module):
def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=15))
# ==============================================================================
class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -240,9 +275,12 @@ class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
# ==============================================================================
class ElementwiseLeFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -259,9 +297,12 @@ class ElementwiseLeFloatScalarModule(torch.nn.Module):
def ElementwiseLeFloatScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseLeIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -278,9 +319,12 @@ class ElementwiseLeIntScalarModule(torch.nn.Module):
def ElementwiseLeIntScalarModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=15))
# ==============================================================================
class ElementwiseLeMixedIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -297,9 +341,12 @@ class ElementwiseLeMixedIntScalarModule(torch.nn.Module):
def ElementwiseLeMixedIntScalarModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
# ==============================================================================
class ElementwiseLeFloatIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -316,9 +363,12 @@ class ElementwiseLeFloatIntScalarModule(torch.nn.Module):
def ElementwiseLeFloatIntScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
# ==============================================================================
class ElementwiseLtFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -336,9 +386,12 @@ class ElementwiseLtFloatTensorModule(torch.nn.Module):
def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5), tu.rand(5))
# ==============================================================================
class ElementwiseLtIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -356,9 +409,12 @@ class ElementwiseLtIntTensorModule(torch.nn.Module):
def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
# ==============================================================================
class ElementwiseEqFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -376,9 +432,12 @@ def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32))
# ==============================================================================
class ElementwiseEqIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -395,9 +454,12 @@ class ElementwiseEqIntScalarModule(torch.nn.Module):
def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils):
module.forward(tu.randint(5, 8, low=2, high=4))
# ==============================================================================
class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -415,9 +477,12 @@ class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils):
module.forward(tu.randint(5, 8, low=2, high=4).to(torch.int32))
# ==============================================================================
class ElementwiseEqFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -437,9 +502,12 @@ def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
torch.tensor([1.0, 2.4, 6.0]).to(torch.float32))
# ==============================================================================
class ElementwiseEqIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -455,11 +523,16 @@ class ElementwiseEqIntTensorModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule())
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))
module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5,
low=2,
high=4))
# ==============================================================================
class ElementwiseNeFloatScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -477,9 +550,12 @@ def ElementwiseNeFloatTensorModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor([[1.0, 2.2, 2.0], [6.0, 2.0, 3.1]]).to(torch.float32))
# ==============================================================================
class ElementwiseNeIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -496,9 +572,12 @@ class ElementwiseNeIntScalarModule(torch.nn.Module):
def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils):
module.forward(tu.randint(8, 5, low=2, high=4))
# ==============================================================================
class AnyBoolTrueModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -517,6 +596,7 @@ def AnyBoolTrueModule_basic(module, tu: TestUtils):
class AnyBoolFalseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -533,8 +613,10 @@ class AnyBoolFalseModule(torch.nn.Module):
def AnyBoolFalseModule_basic(module, tu: TestUtils):
module.forward()
# =================================================================================
class AllBoolTrueModule(torch.nn.Module):
def __init__(self):
@ -553,8 +635,10 @@ class AllBoolTrueModule(torch.nn.Module):
def AllBoolTrueModule_basic(module, tu: TestUtils):
module.forward()
# =================================================================================
class AllBoolFalseModule(torch.nn.Module):
def __init__(self):
@ -568,6 +652,7 @@ class AllBoolFalseModule(torch.nn.Module):
input = [True, False, True, True, False]
return torch.ops.aten.all(input)
@register_test_case(module_factory=lambda: AllBoolFalseModule())
def AllBoolFalseModule_basic(module, tu: TestUtils):
module.forward()

View File

@ -70,9 +70,11 @@ class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten._index_put_impl_(input, (index, ),
@ -84,8 +86,7 @@ class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module):
@register_test_case(
module_factory=lambda: IndexPutImpl3DFloatNonAccumulateModule())
def IndexPutImpl3DFloatNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
tu.rand(5, 8, 6))
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6))
# ==============================================================================
@ -178,9 +179,11 @@ class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten._index_put_impl_(input.clone(), (index, ),
@ -192,8 +195,7 @@ class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module):
@register_test_case(
module_factory=lambda: IndexPutImpl3DFloatAccumulateModule())
def IndexPutImpl3DFloatAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
tu.rand(5, 8, 6))
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6))
# ==============================================================================
@ -283,9 +285,11 @@ class IndexPut3DFloatNonAccumulateModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten.index_put(input, (index, ),
@ -296,8 +300,7 @@ class IndexPut3DFloatNonAccumulateModule(torch.nn.Module):
@register_test_case(
module_factory=lambda: IndexPut3DFloatNonAccumulateModule())
def IndexPut3DFloatNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
tu.rand(5, 8, 6))
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6))
# ==============================================================================
@ -359,9 +362,11 @@ class IndexPut3DIntNonAccumulateModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
([-9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, input, index, value):
return torch.ops.aten.index_put(input, (index, ),
@ -432,9 +437,11 @@ class IndexPut3DFloatAccumulateModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten.index_put(input, (index, ),
@ -444,8 +451,7 @@ class IndexPut3DFloatAccumulateModule(torch.nn.Module):
@register_test_case(module_factory=lambda: IndexPut3DFloatAccumulateModule())
def IndexPut3DFloatAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
tu.rand(5, 8, 6))
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6))
# ==============================================================================
@ -507,9 +513,11 @@ class IndexPut3DIntAccumulateModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
([-9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, input, index, value):
return torch.ops.aten.index_put(input, (index, ),
@ -583,9 +591,11 @@ class IndexPutHackedTwin3DFloatNonAccumulateModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten.index_put(input, [index],
@ -596,8 +606,7 @@ class IndexPutHackedTwin3DFloatNonAccumulateModule(torch.nn.Module):
@register_test_case(
module_factory=lambda: IndexPutHackedTwin3DFloatNonAccumulateModule())
def IndexPutHackedTwin3DFloatNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
tu.rand(5, 8, 6))
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6))
# ==============================================================================
@ -661,9 +670,11 @@ class IndexPutHackedTwin3DIntNonAccumulateModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
([-9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, input, index, value):
return torch.ops.aten.index_put(input, [index],
@ -733,9 +744,11 @@ class IndexPutHackedTwin3DFloatAccumulateModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten.index_put(input, [index], value, accumulate=True)
@ -744,8 +757,7 @@ class IndexPutHackedTwin3DFloatAccumulateModule(torch.nn.Module):
@register_test_case(
module_factory=lambda: IndexPutHackedTwin3DFloatAccumulateModule())
def IndexPutHackedTwin3DFloatAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
tu.rand(5, 8, 6))
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6))
# ==============================================================================
@ -805,9 +817,11 @@ class IndexPutHackedTwin3DIntAccumulateModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
([-9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, input, index, value):
return torch.ops.aten.index_put(input, [index], value, accumulate=True)

View File

@ -13,6 +13,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
class IndexSelectSingleIdxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -22,16 +23,17 @@ class IndexSelectSingleIdxModule(torch.nn.Module):
([4, 5, 6], torch.float32, True),
([1], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 1, indices)
@register_test_case(module_factory=lambda: IndexSelectSingleIdxModule())
def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), torch.tensor([2]))
class IndexSelectTwoIdxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -41,16 +43,17 @@ class IndexSelectTwoIdxModule(torch.nn.Module):
([4, 5, 6], torch.float32, True),
([2], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 2, indices)
@register_test_case(module_factory=lambda: IndexSelectTwoIdxModule())
def IndexSelectTwoIdxModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), torch.tensor([2, 4]))
class IndexSelectWholeDimensionModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -60,16 +63,17 @@ class IndexSelectWholeDimensionModule(torch.nn.Module):
([4, 5, 6], torch.float32, True),
([4], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 0, indices)
@register_test_case(module_factory=lambda: IndexSelectWholeDimensionModule())
def IndexSelectWholeDimensionModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), torch.tensor([0, 1, 2, 3]))
class IndexSelectWholeTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -79,54 +83,59 @@ class IndexSelectWholeTensorModule(torch.nn.Module):
([3], torch.float32, True),
([3], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 0, indices)
@register_test_case(module_factory=lambda: IndexSelectWholeTensorModule())
def IndexSelectWholeTensorModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3), torch.tensor([0, 1, 2]))
class IndexSelectDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 2, indices)
@register_test_case(module_factory=lambda: IndexSelectDynamicModule())
def IndexSelectDynamicModulebasic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), torch.tensor([0, 4]))
class IndexSelectDynamicInputSizeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([2], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 2, indices)
@register_test_case(module_factory=lambda: IndexSelectDynamicInputSizeModule())
def IndexSelectDynamicInputSizeModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), torch.tensor([0, 2]))
class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -136,10 +145,10 @@ class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
([4, 5, 6], torch.float32, True),
([-9223372036854775808], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 1, indices)
@register_test_case(module_factory=lambda: IndexSelectDynamicIndexSizeModule())
def IndexSelectDynamicIndexSizeModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), torch.tensor([1, 2]))

View File

@ -11,7 +11,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class MatmulDot(torch.nn.Module):
def __init__(self):
super().__init__()
@ -29,9 +31,12 @@ class MatmulDot(torch.nn.Module):
def Matmul_dot(module, tu: TestUtils):
module.forward(tu.rand(3), tu.rand(3))
# ==============================================================================
class Matmul2D(torch.nn.Module):
def __init__(self):
super().__init__()
@ -49,9 +54,12 @@ class Matmul2D(torch.nn.Module):
def Matmul_2d(module, tu: TestUtils):
module.forward(tu.rand(3, 4), tu.rand(4, 5))
# ==============================================================================
class MatmulVecMat(torch.nn.Module):
def __init__(self):
super().__init__()
@ -69,9 +77,12 @@ class MatmulVecMat(torch.nn.Module):
def Matmul_vecmat(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand(4, 5))
# ==============================================================================
class MatmulMatVec(torch.nn.Module):
def __init__(self):
super().__init__()
@ -89,17 +100,22 @@ class MatmulMatVec(torch.nn.Module):
def Matmul_matvec(module, tu: TestUtils):
module.forward(tu.rand(4, 5), tu.rand(5))
# ==============================================================================
class Matmul3D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, lhs, rhs):
return torch.matmul(lhs, rhs)
@ -109,17 +125,26 @@ class Matmul3D(torch.nn.Module):
def Matmul_3d(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
# ==============================================================================
class Matmul4d(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, lhs, rhs):
return torch.matmul(lhs, rhs)
@ -129,9 +154,12 @@ class Matmul4d(torch.nn.Module):
def Matmul_4d(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
# ==============================================================================
class Matmul4dStatic(torch.nn.Module):
def __init__(self):
super().__init__()
@ -149,9 +177,12 @@ class Matmul4dStatic(torch.nn.Module):
def Matmul4dStatic_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
# ==============================================================================
class MatmulStaticBroadcast(torch.nn.Module):
def __init__(self):
super().__init__()
@ -169,17 +200,22 @@ class MatmulStaticBroadcast(torch.nn.Module):
def MatmulStaticBroadcast_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 1, 6, 7), tu.rand(8, 1, 5, 7, 6))
# ==============================================================================
class MatmulSingleDynamicBatchDim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([4, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([4, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([4, -9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([4, -9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, lhs, rhs):
return torch.matmul(lhs, rhs)
@ -189,17 +225,22 @@ class MatmulSingleDynamicBatchDim(torch.nn.Module):
def MatmulSingleDynamicBatchDim_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
# ==============================================================================
class MatmulBroadcastBatchDim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([4, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([4, -9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, lhs, rhs):
return torch.matmul(lhs, rhs)
@ -209,8 +250,10 @@ class MatmulBroadcastBatchDim(torch.nn.Module):
def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
# ==============================================================================
class Mv(torch.nn.Module):
@export

View File

@ -14,13 +14,16 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# Multi-layer perceptron (MLP) models.
class Mlp1LayerModule(torch.nn.Module):
def __init__(self):
super().__init__()
# Reset seed to make model deterministic.
torch.manual_seed(0)
self.fc0 = nn.Linear(3, 5)
self.tanh0 = nn.Tanh()
@export
@annotate_args([
None,
@ -29,11 +32,14 @@ class Mlp1LayerModule(torch.nn.Module):
def forward(self, x):
return self.tanh0(self.fc0(x))
@register_test_case(module_factory=lambda: Mlp1LayerModule())
def Mlp1LayerModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3))
class Mlp2LayerModule(torch.nn.Module):
def __init__(self):
super().__init__()
# Reset seed to make model deterministic.
@ -43,6 +49,7 @@ class Mlp2LayerModule(torch.nn.Module):
self.tanh0 = nn.Tanh()
self.fc1 = nn.Linear(N_HIDDEN, 2)
self.tanh1 = nn.Tanh()
@export
@annotate_args([
None,
@ -53,11 +60,14 @@ class Mlp2LayerModule(torch.nn.Module):
x = self.tanh1(self.fc1(x))
return x
@register_test_case(module_factory=lambda: Mlp2LayerModule())
def Mlp2LayerModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3))
class Mlp2LayerModuleNoBias(torch.nn.Module):
def __init__(self):
super().__init__()
# Reset seed to make model deterministic.
@ -67,6 +77,7 @@ class Mlp2LayerModuleNoBias(torch.nn.Module):
self.tanh0 = nn.Tanh()
self.fc1 = nn.Linear(N_HIDDEN, 2, bias=False)
self.tanh1 = nn.Tanh()
@export
@annotate_args([
None,
@ -77,25 +88,31 @@ class Mlp2LayerModuleNoBias(torch.nn.Module):
x = self.tanh1(self.fc1(x))
return x
@register_test_case(module_factory=lambda: Mlp2LayerModuleNoBias())
def Mlp2LayerModuleNoBias_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3))
class BatchMlpLayerModule(torch.nn.Module):
def __init__(self):
super().__init__()
# Reset seed to make model deterministic.
torch.manual_seed(0)
self.fc0 = nn.Linear(3, 5)
self.tanh0 = nn.Tanh()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return self.tanh0(self.fc0(x))
@register_test_case(module_factory=lambda: BatchMlpLayerModule())
def BatchMlpLayerModule_basic(module, tu: TestUtils):
module.forward(tu.rand(7, 5, 3))

View File

@ -38,6 +38,7 @@ def NllLossModule_basic(module, tu: TestUtils):
class NllLossModule_mean(torch.nn.Module):
def __init__(self):
super().__init__()
@ -62,6 +63,7 @@ def NllLossModule_mean_basic(module, tu: TestUtils):
class NllLossModule_sum(torch.nn.Module):
def __init__(self):
super().__init__()
@ -86,6 +88,7 @@ def NllLossModule_sum_basic(module, tu: TestUtils):
class NllLossModule_1D(torch.nn.Module):
def __init__(self):
super().__init__()
@ -129,10 +132,12 @@ class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
ignore_index=10)[0]
@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
@register_test_case(
module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
def NllLossModule_ignore_index_out_of_bounds_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
class NllLossModule_backward(torch.nn.Module):
def __init__(self):
@ -192,7 +197,6 @@ def NllLossModuleBackwardWeight_basic(module, tu: TestUtils):
torch.rand(4), torch.tensor(3.))
class NllLossModule_backward_ignore_index(torch.nn.Module):
def __init__(self):
@ -453,7 +457,8 @@ class NllLossModule_backward1DMeanWeight(torch.nn.Module):
total_weight=total_weight)
@register_test_case(module_factory=lambda: NllLossModule_backward1DMeanWeight())
@register_test_case(
module_factory=lambda: NllLossModule_backward1DMeanWeight())
def NllLossModuleBackward1DMeanWeight_basic(module, tu: TestUtils):
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
torch.rand(3), torch.tensor(3.))

View File

@ -11,7 +11,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class BatchNorm1DModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.bn1d = torch.nn.BatchNorm1d(4)
@ -35,9 +37,12 @@ class BatchNorm1DModule(torch.nn.Module):
def BatchNorm1DModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 4, 3))
# ==============================================================================
class BatchNorm1DWith2DInputModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.bn1d = torch.nn.BatchNorm1d(4)
@ -61,9 +66,12 @@ class BatchNorm1DWith2DInputModule(torch.nn.Module):
def BatchNorm1DWith2DInputModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 4))
# ==============================================================================
class BatchNorm2DModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.bn2d = torch.nn.BatchNorm2d(2)
@ -86,9 +94,12 @@ class BatchNorm2DModule(torch.nn.Module):
def BatchNorm2DModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 2, 3, 3))
# ==============================================================================
class BatchNorm3DModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.bn3d = torch.nn.BatchNorm3d(5)
@ -113,111 +124,156 @@ class BatchNorm3DModule(torch.nn.Module):
def BatchNorm3DModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 3, 6, 4))
# ==============================================================================
class NativeBatchNorm1DModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
])
def forward(self, x, weight, bias, running_mean, running_var):
return torch.ops.aten.native_batch_norm(
x, weight, bias, running_mean, running_var, training=False,
momentum=0.1, eps=0.00001)
return torch.ops.aten.native_batch_norm(x,
weight,
bias,
running_mean,
running_var,
training=False,
momentum=0.1,
eps=0.00001)
@register_test_case(module_factory=lambda: NativeBatchNorm1DModule())
def NativeBatchNorm1DModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 5, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
module.forward(tu.rand(2, 5, 3), tu.rand(5), tu.rand(5), tu.rand(5),
tu.rand(5))
# ==============================================================================
class NativeBatchNorm2DModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
])
def forward(self, x, weight, bias, running_mean, running_var):
return torch.ops.aten.native_batch_norm(
x, weight, bias, running_mean, running_var, training=False,
momentum=0.1, eps=0.00001)
return torch.ops.aten.native_batch_norm(x,
weight,
bias,
running_mean,
running_var,
training=False,
momentum=0.1,
eps=0.00001)
@register_test_case(module_factory=lambda: NativeBatchNorm2DModule())
def NativeBatchNorm2DModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 5, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
module.forward(tu.rand(2, 5, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5),
tu.rand(5))
# ==============================================================================
class NativeBatchNorm3DModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808, -9223372036854775808
], torch.float32, True),
([-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
])
def forward(self, x, weight, bias, running_mean, running_var):
return torch.ops.aten.native_batch_norm(
x, weight, bias, running_mean, running_var, training=False,
momentum=0.1, eps=0.00001)
return torch.ops.aten.native_batch_norm(x,
weight,
bias,
running_mean,
running_var,
training=False,
momentum=0.1,
eps=0.00001)
@register_test_case(module_factory=lambda: NativeBatchNorm3DModule())
def NativeBatchNorm3DModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5),
tu.rand(5))
# ==============================================================================
class NativeBatchNormNoneWeightModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808, -9223372036854775808
], torch.float32, True),
([-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
])
def forward(self, x, bias, running_mean, running_var):
return torch.ops.aten.native_batch_norm(
x, None, bias, running_mean, running_var, training=False,
momentum=0.1, eps=0.00001)
return torch.ops.aten.native_batch_norm(x,
None,
bias,
running_mean,
running_var,
training=False,
momentum=0.1,
eps=0.00001)
@register_test_case(module_factory=lambda: NativeBatchNormNoneWeightModule())
def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5))
# ==============================================================================
class NativeLayerNormModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -230,38 +286,46 @@ class NativeLayerNormModule(torch.nn.Module):
])
def forward(self, x, weight, bias):
list = [2, 2, 3]
return torch.ops.aten.native_layer_norm(
x, list, weight, bias, eps=0.5)
return torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5)
@register_test_case(module_factory=lambda: NativeLayerNormModule())
def NativeLayerNormModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
class NativeLayerNormDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808, -9223372036854775808
], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x, weight, bias):
list = [2, 2, 3]
return torch.ops.aten.native_layer_norm(
x, list, weight, bias, eps=0.5)
return torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5)
@register_test_case(module_factory=lambda: NativeLayerNormDynamicModule())
def NativeLayerNormDynamicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
# ==============================================================================
class NativeLayerNormModule4D(torch.nn.Module):
def __init__(self):
super().__init__()
@ -274,17 +338,20 @@ class NativeLayerNormModule4D(torch.nn.Module):
])
def forward(self, x, weight, bias):
list = [2, 2, 3]
return torch.ops.aten.native_layer_norm(
x, list, weight, bias, eps=0.5)[0]
return torch.ops.aten.native_layer_norm(x, list, weight, bias,
eps=0.5)[0]
@register_test_case(module_factory=lambda: NativeLayerNormModule4D())
def NativeLayerNormModule4D_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
# ==============================================================================
class LayerNormModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ly = torch.nn.LayerNorm([2, 2, 3])
@ -309,9 +376,12 @@ class LayerNormModule(torch.nn.Module):
def LayerNormModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3))
# ==============================================================================
class LayerNormLastDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ly = torch.nn.LayerNorm([3])
@ -332,9 +402,12 @@ class LayerNormLastDimModule(torch.nn.Module):
def LayerNormLastDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3))
# ==============================================================================
class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ly = torch.nn.LayerNorm([2, 2, 3])
@ -355,6 +428,7 @@ class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
return self.ly(x)
@register_test_case(module_factory=lambda: LayerNormNormalizeOverAllDimsModule())
@register_test_case(
module_factory=lambda: LayerNormNormalizeOverAllDimsModule())
def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2, 3))

View File

@ -43,7 +43,10 @@ class AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return self.aap2d(x)
@ -86,7 +89,10 @@ class AdaptiveAvgPool2dUnitOutputSizeDynamicModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return self.aap2d(x)
@ -113,7 +119,10 @@ class MaxPool2dModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return self.mp2d(x)
@ -160,7 +169,10 @@ class MaxPool2dCeilModeTrueModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return self.mp2d(x)
@ -182,7 +194,10 @@ class MaxPool2dWithIndicesModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.max_pool2d_with_indices(x,
@ -205,7 +220,10 @@ class MaxPool2dWithIndicesFullSizeKernelModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.max_pool2d_with_indices(x,
@ -229,7 +247,10 @@ class MaxPool2dWithIndicesNonDefaultPaddingModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.max_pool2d_with_indices(x,
@ -253,7 +274,10 @@ class MaxPool2dWithIndicesNonDefaultStrideModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.max_pool2d_with_indices(x,
@ -277,7 +301,10 @@ class MaxPool2dWithIndicesNonDefaultDilationModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.max_pool2d_with_indices(x,
@ -301,7 +328,10 @@ class MaxPool2dWithIndicesNonDefaultParamsModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.max_pool2d_with_indices(x,
@ -325,7 +355,10 @@ class MaxPool2dWithIndicesAllNegativeValuesModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.max_pool2d_with_indices(x,
@ -372,7 +405,10 @@ class MaxPool2dWithIndicesAllOnesModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.max_pool2d_with_indices(x,
@ -395,7 +431,10 @@ class MaxPool2dWithIndicesCeilModeTrueModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.max_pool2d_with_indices(x,
@ -483,9 +522,18 @@ class MaxPool2dWithIndicesBackwardDynamic4DModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.int64, True),
])
def forward(self, output, input, indices):
kernel_size = [2, 2]
@ -513,9 +561,12 @@ class MaxPool2dWithIndicesBackwardDynamic3DModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, output, input, indices):
kernel_size = [2, 2]
@ -552,15 +603,20 @@ class AvgPool2dFloatModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return self.ap2d(x)
@register_test_case(module_factory=lambda: AvgPool2dFloatModule())
def AvgPool2dFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 20, 20, low=-1))
class AvgPool2dIntModule(torch.nn.Module):
def __init__(self):
@ -575,7 +631,10 @@ class AvgPool2dIntModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.int64, True),
])
def forward(self, x):
return self.ap2d(x)
@ -650,11 +709,15 @@ class AvgPool2dCeilModeTrueModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return self.ap2d(x)
@register_test_case(module_factory=lambda: AvgPool2dCeilModeTrueModule())
def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))

View File

@ -14,6 +14,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
class QuantizedMLP(nn.Module):
def __init__(self):
super().__init__()
torch.random.manual_seed(0)

View File

@ -11,14 +11,17 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class ReduceSumFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.sum(a)
@ -28,16 +31,20 @@ class ReduceSumFloatModule(torch.nn.Module):
def ReduceSumFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class ReduceSumDtypeFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, a):
return torch.sum(a, dtype=torch.float32)
@ -47,16 +54,20 @@ class ReduceSumDtypeFloatModule(torch.nn.Module):
def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class ReduceSumElementTypeBoolModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.bool, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.bool, True),
])
def forward(self, a):
return torch.sum(a)
@ -66,16 +77,20 @@ class ReduceSumElementTypeBoolModule(torch.nn.Module):
def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool))
# ==============================================================================
class ReduceSumDimIntListFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.sum(a, (0, 1))
@ -85,47 +100,60 @@ class ReduceSumDimIntListFloatModule(torch.nn.Module):
def ReduceSumDimIntListFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class ReduceSumDimIntListDtypeFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, a):
return torch.sum(a, (0, 1), dtype=torch.float32)
@register_test_case(module_factory=lambda: ReduceSumDimIntListDtypeFloatModule())
@register_test_case(
module_factory=lambda: ReduceSumDimIntListDtypeFloatModule())
def ReduceSumDimIntListDtypeFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class ReduceSumDimIntListKeepDimFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.sum(a, (1, 2), keepdim=True)
@register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimFloatModule())
@register_test_case(
module_factory=lambda: ReduceSumDimIntListKeepDimFloatModule())
def ReduceSumDimIntListKeepDimFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class ReduceSumDimIntListKeepDimNegativeDimStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -138,20 +166,26 @@ class ReduceSumDimIntListKeepDimNegativeDimStaticModule(torch.nn.Module):
return torch.sum(a, dim=(-1), keepdim=True)
@register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimNegativeDimStaticModule())
def ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic(module, tu: TestUtils):
@register_test_case(
module_factory=lambda: ReduceSumDimIntListKeepDimNegativeDimStaticModule())
def ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic(
module, tu: TestUtils):
module.forward(tu.rand(1, 12, 7, 7))
# ==============================================================================
class ReduceSumDimIntListEmptyDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.sum(a, dim=[])
@ -161,9 +195,12 @@ class ReduceSumDimIntListEmptyDimModule(torch.nn.Module):
def ReduceSumDimIntListEmptyDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class ReduceSumDimIntListElementTypeBoolModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -176,20 +213,25 @@ class ReduceSumDimIntListElementTypeBoolModule(torch.nn.Module):
return torch.sum(a, dim=(-1), keepdim=False)
@register_test_case(module_factory=lambda: ReduceSumDimIntListElementTypeBoolModule())
@register_test_case(
module_factory=lambda: ReduceSumDimIntListElementTypeBoolModule())
def ReduceSumDimIntListElementTypeBoolModule_basic(module, tu: TestUtils):
module.forward(tu.randint(1, 128, high=2).to(dtype=torch.bool))
# ==============================================================================
class ReduceSumUnsignedIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, a):
return torch.sum(a)
@ -199,16 +241,20 @@ class ReduceSumUnsignedIntModule(torch.nn.Module):
def ReduceSumUnsignedIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, low=0, high=100))
# ==============================================================================
class ReduceSumSignedIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, a):
return torch.sum(a)
@ -218,16 +264,20 @@ class ReduceSumSignedIntModule(torch.nn.Module):
def ReduceSumSignedIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, low=-100, high=100))
# ==============================================================================
class ReduceSumDtypeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int32, True),
])
def forward(self, a):
return torch.sum(a, dtype=torch.int64)
@ -237,16 +287,20 @@ class ReduceSumDtypeIntModule(torch.nn.Module):
def ReduceSumDtypeIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32))
# ==============================================================================
class ReduceSumDimIntListIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, a):
return torch.sum(a, (0, 1))
@ -256,16 +310,20 @@ class ReduceSumDimIntListIntModule(torch.nn.Module):
def ReduceSumDimIntListIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, high=100))
# ==============================================================================
class ReduceSumDimIntListDtypeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int32, True),
])
def forward(self, a):
return torch.sum(a, (0, 1), dtype=torch.int64)
@ -275,35 +333,44 @@ class ReduceSumDimIntListDtypeIntModule(torch.nn.Module):
def ReduceSumDimIntListDtypeIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32))
# ==============================================================================
class ReduceSumDimIntListKeepDimIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, a):
return torch.sum(a, (1, 2), keepdim=True)
@register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimIntModule())
@register_test_case(
module_factory=lambda: ReduceSumDimIntListKeepDimIntModule())
def ReduceSumDimIntListKeepDimIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, high=100))
# ==============================================================================
class ReduceMaxAlongDim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.max(a, 1)[0]
@ -313,16 +380,20 @@ class ReduceMaxAlongDim(torch.nn.Module):
def ReduceMaxAlongDim_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class ReduceMaxAlongDimNegative(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.max(a, 1)[0]
@ -332,16 +403,20 @@ class ReduceMaxAlongDimNegative(torch.nn.Module):
def ReduceMaxAlongDimNegative_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5, low=-10, high=10).to(torch.float64))
# ==============================================================================
class ReduceMaxKeepDim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, a):
return torch.ops.aten.max(a, 1, keepdim=True)[1]
@ -351,26 +426,33 @@ class ReduceMaxKeepDim(torch.nn.Module):
def ReduceMaxKeepDim_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class ReduceMaxKeepDimReturnBoth(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.max(a, 1, keepdim=True)
@register_test_case(module_factory=lambda: ReduceMaxKeepDimReturnBoth())
def ReduceMaxKeepDimReturnBoth_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5, low=-10, high=-5))
# ==============================================================================
class ReduceMaxAllDims(torch.nn.Module):
def __init__(self):
@ -379,232 +461,295 @@ class ReduceMaxAllDims(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.max(a)
@register_test_case(module_factory=lambda: ReduceMaxAllDims())
def ReduceMaxAllDims_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5, low=-10, high=-5))
# ==============================================================================
class ReduceMaxNegativeDim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.max(a, -1, keepdim=True)
@register_test_case(module_factory=lambda: ReduceMaxNegativeDim())
def ReduceMaxNegativeDim_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class ReduceMaxFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.max(a)
@register_test_case(module_factory=lambda: ReduceMaxFloatModule())
def ReduceMaxFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class ReduceMaxSignedIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.max(a)
@register_test_case(module_factory=lambda: ReduceMaxSignedIntModule())
def ReduceMaxSignedIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, low=-100, high=100))
# ==============================================================================
class ReduceMaxUnsignedIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, a):
return torch.ops.aten.max(a)
@register_test_case(module_factory=lambda: ReduceMaxUnsignedIntModule())
def ReduceMaxUnsignedIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, high=100))
# ==============================================================================
class ReduceL1NormModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.linalg.vector_norm(a, dim=0, ord=1)
@register_test_case(module_factory=lambda: ReduceL1NormModule())
def ReduceL1NormModule_basic(module, tu: TestUtils):
module.forward(torch.rand(3, 4, 5))
# ==============================================================================
class ReduceL1NormWithDTypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.linalg.vector_norm(a, dim=0, ord=1, dtype=torch.float64)
@register_test_case(module_factory=lambda: ReduceL1NormWithDTypeModule())
def ReduceL1NormWithDTypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float32))
# ==============================================================================
class ReduceL2NormModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.linalg.vector_norm(a, dim=0)
@register_test_case(module_factory=lambda: ReduceL2NormModule())
def ReduceL2NormModule_basic(module, tu: TestUtils):
module.forward(torch.rand(3, 4, 5))
# ==============================================================================
class ReduceLN3NormModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.linalg.vector_norm(a, dim=0, ord=-3)
@register_test_case(module_factory=lambda: ReduceLN3NormModule())
def ReduceLN3NormModule_basic(module, tu: TestUtils):
module.forward(torch.rand(3, 4, 5))
# ==============================================================================
class ReduceL3NormAllDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.linalg.vector_norm(a, dim=None, ord=3)
@register_test_case(module_factory=lambda: ReduceL3NormAllDimsModule())
def ReduceL3NormAllDimsModule_basic(module, tu: TestUtils):
module.forward(torch.rand(3, 4, 5))
# ==============================================================================
class ReduceL3NormKeepDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.linalg.vector_norm(a, keepdim=True, ord=3)
@register_test_case(module_factory=lambda: ReduceL3NormKeepDimModule())
def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils):
module.forward(torch.rand(3, 4, 5))
# ==============================================================================
class ReduceFrobeniusNormModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.frobenius_norm(a, dim=[0, 1], keepdim=False)
@register_test_case(module_factory=lambda: ReduceFrobeniusNormModule())
def ReduceFrobeniusNormModule_basic(module, tu: TestUtils):
module.forward(torch.rand(3, 4, 5))
# ==============================================================================
class ReduceFrobeniusNormKeepDimModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.frobenius_norm(a, dim=[0, 1], keepdim=True)
@register_test_case(module_factory=lambda: ReduceFrobeniusNormKeepDimModule())
def ReduceFrobeniusNormKeepDimModule_basic(module, tu: TestUtils):
module.forward(torch.rand(3, 4, 5))
# ==============================================================================
class MseLossNoReductionModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -614,16 +759,17 @@ class MseLossNoReductionModule(torch.nn.Module):
([-9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, x, y):
return torch.ops.aten.mse_loss(x, y, reduction=0)
@register_test_case(module_factory=lambda: MseLossNoReductionModule())
def MseLossNoReductionModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4), tu.rand(2, 4))
class MseLossMeanReductionModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -633,16 +779,17 @@ class MseLossMeanReductionModule(torch.nn.Module):
([-9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, x, y):
return torch.ops.aten.mse_loss(x, y, reduction=1)
@register_test_case(module_factory=lambda: MseLossMeanReductionModule())
def MseLossMeanReductionModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4), tu.rand(2, 4))
class MseLossSumReductionWithDifferentElemTypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -652,10 +799,12 @@ class MseLossSumReductionWithDifferentElemTypeModule(torch.nn.Module):
([-9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808], torch.float64, True),
])
def forward(self, x, y):
return torch.ops.aten.mse_loss(x, y, reduction=2)
@register_test_case(module_factory=lambda: MseLossSumReductionWithDifferentElemTypeModule())
def MseLossSumReductionWithDifferentElemTypeModule_basic(module, tu: TestUtils):
@register_test_case(
module_factory=lambda: MseLossSumReductionWithDifferentElemTypeModule())
def MseLossSumReductionWithDifferentElemTypeModule_basic(
module, tu: TestUtils):
module.forward(tu.rand(2, 4), tu.rand(2, 4).to(torch.float64))

View File

@ -10,7 +10,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class ViewExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -19,17 +21,20 @@ class ViewExpandModule(torch.nn.Module):
None,
([6, 4], torch.float32, True),
])
def forward(self, a):
return a.view(2, 3, 4)
@register_test_case(module_factory=lambda: ViewExpandModule())
def ViewExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4))
# ==============================================================================
class ViewExpandOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -38,17 +43,20 @@ class ViewExpandOnesModule(torch.nn.Module):
None,
([1], torch.float32, True),
])
def forward(self, a):
return a.view(1, 1, 1, 1, 1)
@register_test_case(module_factory=lambda: ViewExpandOnesModule())
def ViewExpandOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1))
# ==============================================================================
class ViewExpandOnesBeforeAndAfterModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -57,17 +65,21 @@ class ViewExpandOnesBeforeAndAfterModule(torch.nn.Module):
None,
([2, 1, 16, 1, 1], torch.float32, True),
])
def forward(self, a):
return a.view(1, 2, 1, 16, 1, 1, 1, 1)
@register_test_case(module_factory=lambda: ViewExpandOnesBeforeAndAfterModule())
@register_test_case(
module_factory=lambda: ViewExpandOnesBeforeAndAfterModule())
def ViewExpandOnesBeforeAndAfterModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 1, 16, 1, 1))
# ==============================================================================
class ViewExpandOnesMiddleModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -76,17 +88,19 @@ class ViewExpandOnesMiddleModule(torch.nn.Module):
None,
([3, 1, 2], torch.float32, True),
])
def forward(self, a):
return a.view(3, 1, 1, 1, 1, 2)
@register_test_case(module_factory=lambda: ViewExpandOnesMiddleModule())
def ViewExpandOnesMiddleModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 2))
# ==============================================================================
class ViewCollapseOnesMiddleModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -95,55 +109,67 @@ class ViewCollapseOnesMiddleModule(torch.nn.Module):
None,
([3, 1, 1, 1, 1, 2], torch.float32, True),
])
def forward(self, a):
return a.view(3, 1, 2)
@register_test_case(module_factory=lambda: ViewCollapseOnesMiddleModule())
def ViewCollapseOnesMiddleModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 1, 1, 1, 1, 2))
# ==============================================================================
class ViewDynamicExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, 30, 384], torch.float32, True),
([-9223372036854775808, -9223372036854775808, 30,
384], torch.float32, True),
])
def forward(self, a):
return a.view(2, 4, 5, 6, 12, 32)
@register_test_case(module_factory=lambda: ViewDynamicExpandModule())
def ViewDynamicExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 30, 384))
# ==============================================================================
class ViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return a.view(a.size(0), a.size(1), 12, 32)
@register_test_case(module_factory=lambda: ViewDynamicExpandWithAtenSizeIntModule())
@register_test_case(
module_factory=lambda: ViewDynamicExpandWithAtenSizeIntModule())
def ViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 384))
# ==============================================================================
class ViewCollapseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -152,38 +178,49 @@ class ViewCollapseModule(torch.nn.Module):
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, a):
return a.view(8)
@register_test_case(module_factory=lambda: ViewCollapseModule())
def ViewCollapseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4))
# ==============================================================================
class ViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808, -9223372036854775808, -9223372036854775808
], torch.float32, True),
([], torch.int64, True),
([], torch.int64, True),
])
def forward(self, a, b, c):
return a.view(a.size(0), int(b), int(c), a.size(3), 384)
@register_test_case(module_factory=lambda: ViewCollapseDynamicWithAtenSizeIntModule())
@register_test_case(
module_factory=lambda: ViewCollapseDynamicWithAtenSizeIntModule())
def ViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5))
module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3),
torch.tensor(5))
# ==============================================================================
class ViewExpandCollapseWithOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -192,17 +229,20 @@ class ViewExpandCollapseWithOnesModule(torch.nn.Module):
None,
([2, 4, 8, 8], torch.float32, True),
])
def forward(self, a):
return a.view(2, 1, 1, 4, 64)
@register_test_case(module_factory=lambda: ViewExpandCollapseWithOnesModule())
def ViewExpandCollapseWithOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 8, 8))
# ==============================================================================
class ViewExpandCollapseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -211,55 +251,69 @@ class ViewExpandCollapseModule(torch.nn.Module):
None,
([2, 4, 8, 16, 4], torch.float32, True),
])
def forward(self, a):
return a.view(8, 2, 4, 16, 2, 2)
@register_test_case(module_factory=lambda: ViewExpandCollapseModule())
def ViewExpandCollapseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 8, 16, 4))
# ==============================================================================
class ViewDynamicExpandCollapseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, 4, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, 4, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return a.view(2, 1, 4, 64)
@register_test_case(module_factory=lambda: ViewDynamicExpandCollapseModule())
def ViewDynamicExpandCollapseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 8, 8))
# ==============================================================================
class ViewDynamicExpandCollapseWithAtenIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, a):
return a.view(2, 1, a.size(1), 64)
@register_test_case(module_factory=lambda: ViewDynamicExpandCollapseWithAtenIntModule())
@register_test_case(
module_factory=lambda: ViewDynamicExpandCollapseWithAtenIntModule())
def ViewDynamicExpandCollapseWithAtenIntModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 8, 8))
# ==============================================================================
class ViewTwoToThreeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -268,17 +322,20 @@ class ViewTwoToThreeStaticModule(torch.nn.Module):
None,
([3, 2], torch.float32, True),
])
def forward(self, a):
return a.view(2, 3)
@register_test_case(module_factory=lambda: ViewTwoToThreeStaticModule())
def ViewTwoToThreeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2))
# ==============================================================================
class ViewTwoFiveThreeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -287,17 +344,20 @@ class ViewTwoFiveThreeStaticModule(torch.nn.Module):
None,
([3, 5, 2], torch.float32, True),
])
def forward(self, a):
return a.view(2, 5, 3)
@register_test_case(module_factory=lambda: ViewTwoFiveThreeStaticModule())
def ViewTwoFiveThreeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5, 2))
# ==============================================================================
class ViewOffsetTestStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -306,17 +366,20 @@ class ViewOffsetTestStaticModule(torch.nn.Module):
None,
([2, 3, 2, 2, 5, 6], torch.float32, True),
])
def forward(self, a):
return a.view(2, 3, 4, 6, 5)
@register_test_case(module_factory=lambda: ViewOffsetTestStaticModule())
def ViewOffsetTestStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 2, 2, 5, 6))
# ==============================================================================
class ViewOffsetBackwardTestStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -325,17 +388,21 @@ class ViewOffsetBackwardTestStaticModule(torch.nn.Module):
None,
([2, 3, 4, 5, 6], torch.float32, True),
])
def forward(self, a):
return a.view(2, 3, 2, 2, 6, 5)
@register_test_case(module_factory=lambda: ViewOffsetBackwardTestStaticModule())
@register_test_case(
module_factory=lambda: ViewOffsetBackwardTestStaticModule())
def ViewOffsetBackwardTestStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4, 5, 6))
# ==============================================================================
class View1DFoldModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -344,17 +411,20 @@ class View1DFoldModule(torch.nn.Module):
None,
([-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return a.view(-9223372036854775808)
@register_test_case(module_factory=lambda: View1DFoldModule())
def View1DFoldModule_basic(module, tu: TestUtils):
module.forward(tu.rand(32))
# ==============================================================================
class ViewCollapseInferredDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -363,17 +433,20 @@ class ViewCollapseInferredDimModule(torch.nn.Module):
None,
([2, 3, 4], torch.float32, True),
])
def forward(self, a):
return a.view(-9223372036854775808, 4)
@register_test_case(module_factory=lambda: ViewCollapseInferredDimModule())
def ViewCollapseInferredDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class ViewExpandInferredDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -382,17 +455,20 @@ class ViewExpandInferredDimModule(torch.nn.Module):
None,
([2, 6], torch.float32, True),
])
def forward(self, a):
return a.view(3, -9223372036854775808, 2)
@register_test_case(module_factory=lambda: ViewExpandInferredDimModule())
def ViewExpandInferredDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 6))
# ==============================================================================
class ViewExpandDynamicDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -401,17 +477,20 @@ class ViewExpandDynamicDimModule(torch.nn.Module):
None,
([1, -9223372036854775808, 128], torch.float32, True),
])
def forward(self, a):
return a.view(16, 1, 128)
@register_test_case(module_factory=lambda: ViewExpandDynamicDimModule())
def ViewExpandDynamicDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 16, 128))
# ==============================================================================
class ViewFlattenAndExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -420,17 +499,20 @@ class ViewFlattenAndExpandModule(torch.nn.Module):
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, a):
return a.view(a.size(0), a.size(1))
@register_test_case(module_factory=lambda: ViewFlattenAndExpandModule())
def ViewFlattenAndExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(64, 128))
# ==============================================================================
class UnsafeViewExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -439,55 +521,67 @@ class UnsafeViewExpandModule(torch.nn.Module):
None,
([6, 4], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten._unsafe_view(a, [2, 3, 4])
@register_test_case(module_factory=lambda: UnsafeViewExpandModule())
def UnsafeViewExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4))
# ==============================================================================
class UnsafeViewDynamicExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, 30, 384], torch.float32, True),
([-9223372036854775808, -9223372036854775808, 30,
384], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten._unsafe_view(a, [2, 4, 5, 6, 12, 32])
@register_test_case(module_factory=lambda: UnsafeViewDynamicExpandModule())
def UnsafeViewDynamicExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 30, 384))
# ==============================================================================
class UnsafeViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten._unsafe_view(a, [a.size(0), a.size(1), 12, 32])
@register_test_case(module_factory=lambda: UnsafeViewDynamicExpandWithAtenSizeIntModule())
@register_test_case(
module_factory=lambda: UnsafeViewDynamicExpandWithAtenSizeIntModule())
def UnsafeViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 384))
# ==============================================================================
class UnsafeViewCollapseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -496,38 +590,52 @@ class UnsafeViewCollapseModule(torch.nn.Module):
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten._unsafe_view(a, [8])
@register_test_case(module_factory=lambda: UnsafeViewCollapseModule())
def UnsafeViewCollapseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4))
# ==============================================================================
class UnsafeViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808, -9223372036854775808, -9223372036854775808
], torch.float32, True),
([], torch.int64, True),
([], torch.int64, True),
])
def forward(self, a, b, c):
return torch.ops.aten._unsafe_view(a, [a.size(0), int(b), int(c), a.size(3), 384])
return torch.ops.aten._unsafe_view(
a, [a.size(0), int(b), int(c),
a.size(3), 384])
@register_test_case(
module_factory=lambda: UnsafeViewCollapseDynamicWithAtenSizeIntModule())
def UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic(
module, tu: TestUtils):
module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3),
torch.tensor(5))
@register_test_case(module_factory=lambda: UnsafeViewCollapseDynamicWithAtenSizeIntModule())
def UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5))
# ==============================================================================
class UnsafeView1DFoldModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -536,17 +644,20 @@ class UnsafeView1DFoldModule(torch.nn.Module):
None,
([-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten._unsafe_view(a, [-9223372036854775808])
@register_test_case(module_factory=lambda: UnsafeView1DFoldModule())
def UnsafeView1DFoldModule_basic(module, tu: TestUtils):
module.forward(tu.rand(32))
# ==============================================================================
class ReshapeExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -555,17 +666,20 @@ class ReshapeExpandModule(torch.nn.Module):
None,
([-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return a.reshape(12, 32)
@register_test_case(module_factory=lambda: ReshapeExpandModule())
def ReshapeExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(384))
# ==============================================================================
class ReshapeCollapseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -574,17 +688,20 @@ class ReshapeCollapseModule(torch.nn.Module):
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.reshape(a, (-1, ))
@register_test_case(module_factory=lambda: ReshapeCollapseModule())
def ReshapeCollapseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4))
# ==============================================================================
class ViewNoChange1dModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -593,16 +710,17 @@ class ViewNoChange1dModule(torch.nn.Module):
None,
([-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return a.view(6)
@register_test_case(module_factory=lambda: ViewNoChange1dModule())
def ViewNoChange1dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6))
class ViewNoChange2dModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -611,34 +729,37 @@ class ViewNoChange2dModule(torch.nn.Module):
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, a):
return a.view(5, 6)
@register_test_case(module_factory=lambda: ViewNoChange2dModule())
def ViewNoChange2dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 6))
class ViewNoChange3dModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return a.view(4, 5, 6)
@register_test_case(module_factory=lambda: ViewNoChange3dModule())
def ViewNoChange3dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6))
class ViewNoChangeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -647,17 +768,20 @@ class ViewNoChangeStaticModule(torch.nn.Module):
None,
([4, 5, 6], torch.float32, True),
])
def forward(self, a):
return a.view(4, 5, 6)
@register_test_case(module_factory=lambda: ViewNoChangeStaticModule())
def ViewNoChangeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6))
# ==============================================================================
class ReshapeAliasExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -668,17 +792,20 @@ class ReshapeAliasExpandModule(torch.nn.Module):
None,
([-9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten._reshape_alias(a, size=(12, 32), stride=(32, 1))
@register_test_case(module_factory=lambda: ReshapeAliasExpandModule())
def ReshapeAliasExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(384))
# ==============================================================================
class ReshapeAliasCollapseModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -687,10 +814,10 @@ class ReshapeAliasCollapseModule(torch.nn.Module):
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten._reshape_alias(a, (8, ), (1, ))
@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4))

View File

@ -6,6 +6,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class UniformModule(torch.nn.Module):
def __init__(self):
@ -14,9 +15,12 @@ class UniformModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, x, y, z):
a = torch.ops.aten.uniform_(x, 1.0, 10.0)
@ -42,8 +46,10 @@ def UniformModule_basic(module, tu: TestUtils):
tu.rand(512, 1024, 4).double(),
tu.rand(512, 256, 4).double())
# ==============================================================================
class UniformStaticModule(torch.nn.Module):
def __init__(self):
@ -80,18 +86,24 @@ def UniformStaticModule_basic(module, tu: TestUtils):
tu.rand(512, 1024, 4).double(),
tu.rand(512, 256, 4).double())
# ==============================================================================
class BernoulliModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, x, y, z):
a = torch.bernoulli(x)
@ -117,9 +129,12 @@ def BernoulliModule_basic(module, tu: TestUtils):
tu.rand(1024, 2048, 4).double(),
tu.rand(1024, 256, 4).double())
# ==============================================================================
class BernoulliZerosModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -136,9 +151,12 @@ class BernoulliZerosModule(torch.nn.Module):
def BernoulliZerosModule_basic(module, tu: TestUtils):
module.forward(torch.zeros(4, 8).double())
# ==============================================================================
class BernoulliOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -155,18 +173,24 @@ class BernoulliOnesModule(torch.nn.Module):
def BernoulliOnesModule_basic(module, tu: TestUtils):
module.forward(torch.ones(4, 8).double())
# ==============================================================================
class BernoulliFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, x, y, z):
a = torch.ops.aten.bernoulli_(x, 0.4)
@ -192,21 +216,30 @@ def BernoulliFloatModule_basic(module, tu: TestUtils):
tu.rand(1024, 2048, 4).double(),
tu.rand(1024, 512, 4).double())
# ==============================================================================
class BernoulliTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, x, px, y, py, z, pz):
a = torch.ops.aten.bernoulli_(x, px)
@ -235,9 +268,12 @@ def BernoulliTensorModule_basic(module, tu: TestUtils):
tu.rand(1024, 512, 8).double(),
tu.rand(1024, 512, 8).double())
# ==============================================================================
class RandLikeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -256,9 +292,12 @@ class RandLikeModule(torch.nn.Module):
def RandLikeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1024, 1024).double())
# ==============================================================================
class RandLikeDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -277,9 +316,12 @@ class RandLikeDtypeModule(torch.nn.Module):
def RandLikeDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1024, 1024).double())
# ==============================================================================
class RandIntLowModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -297,9 +339,12 @@ class RandIntLowModule(torch.nn.Module):
def RandIntLowModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class RandIntLowDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -308,7 +353,10 @@ class RandIntLowDtypeModule(torch.nn.Module):
None,
])
def forward(self):
a = torch.ops.aten.randint(low=1, high=1000, size=[128, 256, 512], dtype=torch.float64)
a = torch.ops.aten.randint(low=1,
high=1000,
size=[128, 256, 512],
dtype=torch.float64)
mean = torch.mean(a)
return mean

View File

@ -29,7 +29,8 @@ class AddIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: AddIntModule())
def AddIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
module.forward(tu.randint(low=-100, high=100),
tu.randint(low=-100, high=100))
# ==============================================================================
@ -52,7 +53,8 @@ class SubIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: SubIntModule())
def SubIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
module.forward(tu.randint(low=-100, high=100),
tu.randint(low=-100, high=100))
# ==============================================================================
@ -98,7 +100,8 @@ class MulIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: MulIntModule())
def MulIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
module.forward(tu.randint(low=-100, high=100),
tu.randint(low=-100, high=100))
# ==============================================================================
@ -337,9 +340,12 @@ class BoolIntConstantModule(torch.nn.Module):
def BoolIntConstantModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class AtenIntTensorByteDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -348,10 +354,10 @@ class AtenIntTensorByteDtypeModule(torch.nn.Module):
None,
([], torch.uint8, True),
])
def forward(self, val):
return int(val)
@register_test_case(module_factory=lambda: AtenIntTensorByteDtypeModule())
def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.uint8))
@ -359,7 +365,9 @@ def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils):
# ==============================================================================
class AtenIntTensorCharDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -368,10 +376,10 @@ class AtenIntTensorCharDtypeModule(torch.nn.Module):
None,
([], torch.int8, True),
])
def forward(self, val):
return int(val)
@register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule())
def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))

View File

@ -29,7 +29,8 @@ class NeIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: NeIntModule())
def NeIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
module.forward(tu.randint(low=-100, high=100),
tu.randint(low=-100, high=100))
# ==============================================================================
@ -52,7 +53,8 @@ class EqIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: EqIntModule())
def EqIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
module.forward(tu.randint(low=-100, high=100),
tu.randint(low=-100, high=100))
# ==============================================================================
@ -75,7 +77,8 @@ class GtIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: GtIntModule())
def GtIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
module.forward(tu.randint(low=-100, high=100),
tu.randint(low=-100, high=100))
# ==============================================================================
@ -98,7 +101,8 @@ class GeIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: GeIntModule())
def GeIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
module.forward(tu.randint(low=-100, high=100),
tu.randint(low=-100, high=100))
# ==============================================================================

View File

@ -11,14 +11,17 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class SliceModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return x[0:5:1, 1:3:1, 2:4:1]
@ -31,7 +34,9 @@ def SliceModule_basic(module, tu: TestUtils):
# ==============================================================================
class SliceStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -51,14 +56,17 @@ def SliceStaticModule_basic(module, tu: TestUtils):
# ==============================================================================
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
@ -71,55 +79,68 @@ class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
def SliceOutOfUpperBoundIndexModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4, 7))
# ==============================================================================
class SliceOutOfLowerBoundEndIndexModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return x[:-8, -7:, :]
@register_test_case(module_factory=lambda: SliceOutOfLowerBoundEndIndexModule())
@register_test_case(
module_factory=lambda: SliceOutOfLowerBoundEndIndexModule())
def SliceOutOfLowerBoundEndIndexModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4, 7))
# ==============================================================================
class SliceOutOfLowerBoundStartIndexModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return x[-8:3:1, 1:3:1, 2:4:1]
@register_test_case(module_factory=lambda: SliceOutOfLowerBoundStartIndexModule())
@register_test_case(
module_factory=lambda: SliceOutOfLowerBoundStartIndexModule())
def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4, 7))
# ==============================================================================
class SliceEndSleStartModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
@ -132,17 +153,20 @@ class SliceEndSleStartModule(torch.nn.Module):
def SliceEndSleStartModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4, 7))
# ==============================================================================
class SliceStartEqEndModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
@ -155,16 +179,20 @@ class SliceStartEqEndModule(torch.nn.Module):
def SliceStartEqEndModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4, 7))
# ==============================================================================
class SliceSizeTwoStepModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return x[0:5:2, 0:3:2, 0:4:2]
@ -174,9 +202,12 @@ class SliceSizeTwoStepModule(torch.nn.Module):
def SliceSizeTwoStepModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 5, 17))
# ==============================================================================
class SliceNegIdxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -193,9 +224,12 @@ class SliceNegIdxModule(torch.nn.Module):
def SliceNegIdxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 9))
# ==============================================================================
class SliceSingleIdxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -212,9 +246,12 @@ class SliceSingleIdxModule(torch.nn.Module):
def SliceSingleIdxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 8))
# ==============================================================================
class SliceWholeTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -231,9 +268,12 @@ class SliceWholeTensorModule(torch.nn.Module):
def SliceWholeTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 8))
# ==============================================================================
class SelectIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -250,11 +290,14 @@ class SelectIntModule(torch.nn.Module):
def SelectIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(5, 5, high=10))
# ==============================================================================
# For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
# For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
class SliceScatterModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -265,13 +308,21 @@ class SliceScatterModule(torch.nn.Module):
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, x, src):
return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 1)
return torch.ops.aten.slice_scatter(x,
src,
dim=1,
start=0,
end=1,
step=1)
@register_test_case(module_factory=lambda: SliceScatterModule())
def SliceScatterModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 8), tu.rand(6, 1))
class SliceScatterZeroDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -282,7 +333,12 @@ class SliceScatterZeroDimModule(torch.nn.Module):
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, x, src):
return torch.ops.aten.slice_scatter(x, src, dim = 0, start = 0, end = 1, step = 1)
return torch.ops.aten.slice_scatter(x,
src,
dim=0,
start=0,
end=1,
step=1)
@register_test_case(module_factory=lambda: SliceScatterZeroDimModule())
@ -314,7 +370,9 @@ class SliceScatterNegativeDimModule(torch.nn.Module):
def SliceScatterNegativeDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 8), tu.rand(1, 8))
class SliceScatterStepVariationModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -325,14 +383,21 @@ class SliceScatterStepVariationModule(torch.nn.Module):
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, x, src):
return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 2)
return torch.ops.aten.slice_scatter(x,
src,
dim=1,
start=0,
end=1,
step=2)
@register_test_case(module_factory=lambda: SliceScatterStepVariationModule())
def SliceScatterStepVariationModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 8), tu.rand(6, 1))
class SliceScatterStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -343,21 +408,29 @@ class SliceScatterStaticModule(torch.nn.Module):
([6, 1], torch.float32, True),
])
def forward(self, x, src):
return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 1)
return torch.ops.aten.slice_scatter(x,
src,
dim=1,
start=0,
end=1,
step=1)
@register_test_case(module_factory=lambda: SliceScatterStaticModule())
def SliceScatterStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 8), tu.rand(6, 1))
class SelectScatterModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, x, src):
@ -368,7 +441,9 @@ class SelectScatterModule(torch.nn.Module):
def SelectScattertModule_basic(module, tu: TestUtils):
module.forward(torch.rand(6, 8, 5), torch.rand(8, 5))
class SelectScatterStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -386,16 +461,20 @@ class SelectScatterStaticModule(torch.nn.Module):
def SelectScattertStaticModule_basic(module, tu: TestUtils):
module.forward(torch.rand(6, 8, 5), torch.rand(6, 5))
# ==============================================================================
class NarrowHorizontalTest(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.narrow(x, dim=0, start=0, length=2)
@ -405,17 +484,20 @@ class NarrowHorizontalTest(torch.nn.Module):
def NarrowHorizontalTest_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4, 3))
# ==============================================================================
class NarrowVerticalTest(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.narrow(x, dim=1, start=0, length=2)
@ -425,9 +507,12 @@ class NarrowVerticalTest(torch.nn.Module):
def NarrowVerticalTest_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4, 3))
# ==============================================================================
class NarrowHorizontalTest2(torch.nn.Module):
def __init__(self):
super().__init__()
@ -444,10 +529,12 @@ class NarrowHorizontalTest2(torch.nn.Module):
def NarrowHorizontalTest2_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4))
# ==============================================================================
class NarrowVerticalTest2(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -13,6 +13,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
class SqueezeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -25,8 +26,7 @@ class SqueezeStaticModule(torch.nn.Module):
return torch.squeeze(a)
@register_test_case(
module_factory=lambda: SqueezeStaticModule())
@register_test_case(module_factory=lambda: SqueezeStaticModule())
def SqueezeModule_static(module, tu: TestUtils):
module.forward(tu.rand(1, 7, 1, 3, 1))
@ -35,6 +35,7 @@ def SqueezeModule_static(module, tu: TestUtils):
class SqueezeAllUnitDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -47,8 +48,7 @@ class SqueezeAllUnitDimModule(torch.nn.Module):
return torch.squeeze(a)
@register_test_case(
module_factory=lambda: SqueezeAllUnitDimModule())
@register_test_case(module_factory=lambda: SqueezeAllUnitDimModule())
def SqueezeModule_allUnitDim(module, tu: TestUtils):
module.forward(tu.rand(1, 1))
@ -57,6 +57,7 @@ def SqueezeModule_allUnitDim(module, tu: TestUtils):
class SqueezeBroadcastModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -70,8 +71,7 @@ class SqueezeBroadcastModule(torch.nn.Module):
return a * b.squeeze()
@register_test_case(
module_factory=lambda: SqueezeBroadcastModule())
@register_test_case(module_factory=lambda: SqueezeBroadcastModule())
def SqueezeModule_broadcast(module, tu: TestUtils):
module.forward(tu.rand(4, 3), tu.rand())
@ -80,6 +80,7 @@ def SqueezeModule_broadcast(module, tu: TestUtils):
class SqueezeDimStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -92,8 +93,7 @@ class SqueezeDimStaticModule(torch.nn.Module):
return torch.squeeze(a, 0)
@register_test_case(
module_factory=lambda: SqueezeDimStaticModule())
@register_test_case(module_factory=lambda: SqueezeDimStaticModule())
def SqueezeDimModule_static(module, tu: TestUtils):
module.forward(tu.rand(1, 7))
@ -102,20 +102,21 @@ def SqueezeDimModule_static(module, tu: TestUtils):
class SqueezeDimDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, 1, 384, -9223372036854775808, 1], torch.float32, True),
([-9223372036854775808, 1, 384, -9223372036854775808,
1], torch.float32, True),
])
def forward(self, a):
return torch.squeeze(a, 4)
@register_test_case(
module_factory=lambda: SqueezeDimDynamicModule())
@register_test_case(module_factory=lambda: SqueezeDimDynamicModule())
def SqueezeDimModule_dynamic(module, tu: TestUtils):
module.forward(tu.rand(8, 1, 384, 12, 1))
@ -124,20 +125,21 @@ def SqueezeDimModule_dynamic(module, tu: TestUtils):
class SqueezeDimNegDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([1, -9223372036854775808, 1, 384, -9223372036854775808, 1], torch.float32, True),
([1, -9223372036854775808, 1, 384, -9223372036854775808,
1], torch.float32, True),
])
def forward(self, a):
return torch.squeeze(a, -6)
@register_test_case(
module_factory=lambda: SqueezeDimNegDimModule())
@register_test_case(module_factory=lambda: SqueezeDimNegDimModule())
def SqueezeDimModule_negDim(module, tu: TestUtils):
module.forward(tu.rand(1, 8, 1, 384, 12, 1))
@ -146,6 +148,7 @@ def SqueezeDimModule_negDim(module, tu: TestUtils):
class SqueezeDimIdentityModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -158,8 +161,7 @@ class SqueezeDimIdentityModule(torch.nn.Module):
return torch.squeeze(a, 0)
@register_test_case(
module_factory=lambda: SqueezeDimIdentityModule())
@register_test_case(module_factory=lambda: SqueezeDimIdentityModule())
def SqueezeDimModule_identity(module, tu: TestUtils):
module.forward(tu.rand(4, 1, 3))
@ -168,6 +170,7 @@ def SqueezeDimModule_identity(module, tu: TestUtils):
class SqueezeDimUnitDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -180,7 +183,6 @@ class SqueezeDimUnitDimModule(torch.nn.Module):
return torch.squeeze(a, 0)
@register_test_case(
module_factory=lambda: SqueezeDimUnitDimModule())
@register_test_case(module_factory=lambda: SqueezeDimUnitDimModule())
def SqueezeDimModule_unitDim(module, tu: TestUtils):
module.forward(tu.rand(1))

View File

@ -11,7 +11,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class MeanModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -28,9 +30,12 @@ class MeanModule(torch.nn.Module):
def MeanModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class MeanDynamicSizesModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -47,16 +52,20 @@ class MeanDynamicSizesModule(torch.nn.Module):
def MeanDynamicSizesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class MeanDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, x):
return torch.ops.aten.mean(x, dtype=torch.float32)
@ -66,16 +75,22 @@ class MeanDtypeModule(torch.nn.Module):
def MeanDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class MeanLargeInputModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.mean(x)
@ -85,16 +100,20 @@ class MeanLargeInputModule(torch.nn.Module):
def MeanLargeInputModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200))
# ==============================================================================
class MeanDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.mean(x, (0, 2))
@ -104,16 +123,22 @@ class MeanDimModule(torch.nn.Module):
def MeanDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))
# ==============================================================================
class MeanDimLargeInputModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.mean(x, (0, 2))
@ -126,14 +151,17 @@ def MeanDimLargeInputModule_basic(module, tu: TestUtils):
# ==============================================================================
class MeanDimDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, x):
return torch.ops.aten.mean(x, (0, ), dtype=torch.float32)
@ -143,16 +171,20 @@ class MeanDimDtypeModule(torch.nn.Module):
def MeanDimDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class MeanDimKeepdimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.mean(x, (1, 2), keepdim=True)
@ -162,16 +194,20 @@ class MeanDimKeepdimModule(torch.nn.Module):
def MeanDimKeepdimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class MeanDimAllReduceModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.mean(x, (0, 1, 2))
@ -181,16 +217,20 @@ class MeanDimAllReduceModule(torch.nn.Module):
def MeanDimAllReduceModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class MeanDimAllReduceKeepdimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.mean(x, (0, 1, 2), keepdim=True)
@ -200,16 +240,20 @@ class MeanDimAllReduceKeepdimModule(torch.nn.Module):
def MeanDimAllReduceKeepdimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class MeanDimNegativeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.mean(x, (-1, 1))
@ -222,14 +266,17 @@ def MeanDimNegativeModule_basic(module, tu: TestUtils):
# ==============================================================================
class MeanDimEmptyDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.mean(x, dim=[])
@ -239,16 +286,20 @@ class MeanDimEmptyDimModule(torch.nn.Module):
def MeanDimEmptyDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class MeanDimNoneDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.mean(x, dim=None)
@ -258,74 +309,94 @@ class MeanDimNoneDimModule(torch.nn.Module):
def MeanDimNoneDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class VarUnbiasedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, unbiased=True)
@register_test_case(module_factory=lambda: VarUnbiasedModule())
def VarUnbiasedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class VarBiasedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, unbiased=False)
@register_test_case(module_factory=lambda: VarBiasedModule())
def VarBiasedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class StdUnbiasedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, unbiased=True)
@register_test_case(module_factory=lambda: StdUnbiasedModule())
def StdUnbiasedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class StdBiasedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, unbiased=False)
@register_test_case(module_factory=lambda: StdBiasedModule())
def StdBiasedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
@ -342,7 +413,8 @@ class StdDimKeepDimFalseModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, dim=(1, 2), keepdim=False)
@ -364,7 +436,8 @@ class StdDimKeepDimTrueModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, dim=(0, 1, 2), keepdim=True)
@ -386,7 +459,8 @@ class StdDimBiasedModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, dim=(0, 2), unbiased=False)
@ -408,7 +482,8 @@ class StdDimEmptyDimModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, dim=[], keepdim=False)
@ -430,7 +505,8 @@ class StdDimNoneDimModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, dim=None, keepdim=False)
@ -452,7 +528,8 @@ class VarDimModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=(0, 2), keepdim=True)
@ -474,7 +551,8 @@ class VarDimUnbiasedModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=(0, 2), unbiased=True, keepdim=True)
@ -496,7 +574,8 @@ class VarDimBiasedModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=(0, 1), unbiased=False, keepdim=True)
@ -518,7 +597,8 @@ class VarDimSingleDimModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=(0, ), keepdim=True)
@ -540,7 +620,8 @@ class VarDimMultiDimModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float64, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=[0, 2], keepdim=False)
@ -562,7 +643,8 @@ class VarDimAllDimReduceModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=(0, 1, 2), keepdim=True)
@ -584,7 +666,8 @@ class VarDimNegativeModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=(-1, 1), keepdim=True)
@ -606,7 +689,8 @@ class VarDimEmptyDimModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=[], keepdim=False)
@ -628,7 +712,8 @@ class VarDimNoneDimModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=None, keepdim=False)
@ -650,7 +735,8 @@ class VarCorrectionModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=None, correction=2)
@ -672,13 +758,15 @@ class VarCorrectionSingleDimReduceModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=[1], correction=1)
@register_test_case(module_factory=lambda: VarCorrectionSingleDimReduceModule())
@register_test_case(
module_factory=lambda: VarCorrectionSingleDimReduceModule())
def VarCorrectionSingleDimReduceModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 7))
@ -694,7 +782,8 @@ class VarCorrectionAllDimReduceModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x,
@ -719,7 +808,8 @@ class VarCorrectionKeepDimModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=[0, 1], correction=None, keepdim=True)
@ -741,7 +831,8 @@ class VarCorrectionNoneModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=None, correction=None)
@ -763,7 +854,8 @@ class VarCorrectionEmptyDimModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=[], correction=2)
@ -785,7 +877,10 @@ class VarCorrectionLargeInputModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=[2, 3], correction=2)
@ -807,10 +902,16 @@ class VarMeanCorrectionModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([
-9223372036854775808, -9223372036854775808, -9223372036854775808,
-9223372036854775808
], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var_mean(x, dim=[1, 2], correction=2, keepdim=True)
return torch.ops.aten.var_mean(x,
dim=[1, 2],
correction=2,
keepdim=True)
@register_test_case(module_factory=lambda: VarMeanCorrectionModule())
@ -829,10 +930,14 @@ class VarMeanCorrectionNoneModule(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var_mean(x, dim=None, correction=None, keepdim=False)
return torch.ops.aten.var_mean(x,
dim=None,
correction=None,
keepdim=False)
@register_test_case(module_factory=lambda: VarMeanCorrectionNoneModule())

View File

@ -13,6 +13,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
class Threshold1dIntI32Module(torch.nn.Module):
def __init__(self):
super().__init__()
@ -21,16 +22,17 @@ class Threshold1dIntI32Module(torch.nn.Module):
None,
([-9223372036854775808], torch.int32, True),
])
def forward(self, input):
return torch.ops.aten.threshold(input, 1, 2)
@register_test_case(module_factory=lambda: Threshold1dIntI32Module())
def Threshold1dIntI32Module_basic(module, tu: TestUtils):
module.forward(tu.randint(4, high=10).to(torch.int32))
class Threshold1dIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -39,16 +41,17 @@ class Threshold1dIntModule(torch.nn.Module):
None,
([-9223372036854775808], torch.int64, True),
])
def forward(self, input):
return torch.ops.aten.threshold(input, 1, 2)
@register_test_case(module_factory=lambda: Threshold1dIntModule())
def Threshold1dIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(4, high=10))
class Threshold2dIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -57,34 +60,37 @@ class Threshold2dIntModule(torch.nn.Module):
None,
([-9223372036854775808, -9223372036854775808], torch.int64, True),
])
def forward(self, input):
return torch.ops.aten.threshold(input, 0.5, 2)
@register_test_case(module_factory=lambda: Threshold2dIntModule())
def Threshold2dIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(4, 5, high=10))
class Threshold3dIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, input):
return torch.ops.aten.threshold(input, 1, 2.2)
@register_test_case(module_factory=lambda: Threshold3dIntModule())
def Threshold3dIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(4, 5, 6, high=10))
class Threshold1dFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -93,16 +99,17 @@ class Threshold1dFloatModule(torch.nn.Module):
None,
([-9223372036854775808], torch.float32, True),
])
def forward(self, input):
return torch.ops.aten.threshold(input, 1, 2)
@register_test_case(module_factory=lambda: Threshold1dFloatModule())
def Threshold1dFloatModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4))
class Threshold2dFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -111,34 +118,37 @@ class Threshold2dFloatModule(torch.nn.Module):
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, input):
return torch.ops.aten.threshold(input, 0.5, 2)
@register_test_case(module_factory=lambda: Threshold2dFloatModule())
def Threshold2dFloatModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5))
class Threshold3dFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, input):
return torch.ops.aten.threshold(input, 1.4, 2.0)
@register_test_case(module_factory=lambda: Threshold3dFloatModule())
def Threshold3dFloatModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6))
class ThresholdBackward1dIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -148,16 +158,17 @@ class ThresholdBackward1dIntModule(torch.nn.Module):
([-9223372036854775808], torch.int64, True),
([-9223372036854775808], torch.int64, True),
])
def forward(self, grad, input):
return torch.ops.aten.threshold_backward(grad, input, 1)
@register_test_case(module_factory=lambda: ThresholdBackward1dIntModule())
def ThresholdBackward1dIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(4, high=10), tu.randint(4, high=8))
class ThresholdBackward2dIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -167,35 +178,39 @@ class ThresholdBackward2dIntModule(torch.nn.Module):
([-9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808], torch.int64, True),
])
def forward(self, grad, input):
return torch.ops.aten.threshold_backward(grad, input, 0.5)
@register_test_case(module_factory=lambda: ThresholdBackward2dIntModule())
def ThresholdBackward2dIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(4, 5, high=10), tu.randint(4, 5, high=8))
class ThresholdBackward3dIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, grad, input):
return torch.ops.aten.threshold_backward(grad, input, 1)
@register_test_case(module_factory=lambda: ThresholdBackward3dIntModule())
def ThresholdBackward3dIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(4, 5, 6, high=10), tu.randint(4, 5, 6, high=8))
class ThresholdBackward1dFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -205,16 +220,17 @@ class ThresholdBackward1dFloatModule(torch.nn.Module):
([-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.float32, True),
])
def forward(self, grad, input):
return torch.ops.aten.threshold_backward(grad, input, 1)
@register_test_case(module_factory=lambda: ThresholdBackward1dFloatModule())
def ThresholdBackward1dFloatModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4), torch.randn(4))
class ThresholdBackward2dFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -224,35 +240,39 @@ class ThresholdBackward2dFloatModule(torch.nn.Module):
([-9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, grad, input):
return torch.ops.aten.threshold_backward(grad, input, 0.5)
@register_test_case(module_factory=lambda: ThresholdBackward2dFloatModule())
def ThresholdBackward2dFloatModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5), torch.randn(4, 5))
class ThresholdBackward3dFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, grad, input):
return torch.ops.aten.threshold_backward(grad, input, 1.4)
@register_test_case(module_factory=lambda: ThresholdBackward3dFloatModule())
def ThresholdBackward3dFloatModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), torch.randn(4, 5, 6))
class ThresholdBackward1dMixedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -262,16 +282,17 @@ class ThresholdBackward1dMixedModule(torch.nn.Module):
([-9223372036854775808], torch.float32, True),
([-9223372036854775808], torch.int64, True),
])
def forward(self, grad, input):
return torch.ops.aten.threshold_backward(grad, input, 1)
@register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule())
def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4), tu.randint(4, high=10))
class ThresholdBackward2dMixedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -281,29 +302,32 @@ class ThresholdBackward2dMixedModule(torch.nn.Module):
([-9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808], torch.float32, True),
])
def forward(self, grad, input):
return torch.ops.aten.threshold_backward(grad, input, 0.5)
@register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule())
def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils):
module.forward(tu.randint(4, 5, high=20), torch.randn(4, 5))
class ThresholdBackward3dMixedModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.float32, True),
([-9223372036854775808, -9223372036854775808,
-9223372036854775808], torch.int64, True),
])
def forward(self, grad, input):
return torch.ops.aten.threshold_backward(grad, input, 1.4)
@register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule())
def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), tu.randint(4, 5, 6, high=10))

View File

@ -18,7 +18,10 @@ class TypeConversionF32ToF64Module(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True)
])
def forward(self, x):
return x.to(torch.float64)
@ -34,7 +37,10 @@ class TypeConversionF64ToF32Module(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float64, True)])
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808], torch.float64, True)
])
def forward(self, x):
return x.to(torch.float32)
@ -50,7 +56,9 @@ class TypeConversionI32ToI64Module(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.int32, True)])
@annotate_args([
None, ([-9223372036854775808, -9223372036854775808], torch.int32, True)
])
def forward(self, x):
return x.to(torch.int64)
@ -66,7 +74,9 @@ class TypeConversionI64ToI32Module(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.int64, True)])
@annotate_args([
None, ([-9223372036854775808, -9223372036854775808], torch.int64, True)
])
def forward(self, x):
return x.to(torch.int32)
@ -82,7 +92,9 @@ class TypeConversionI1ToI32Module(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)])
@annotate_args([
None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)
])
def forward(self, x):
return x.to(torch.int32)
@ -99,7 +111,9 @@ class TypeConversionI1ToI64Module(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)])
@annotate_args([
None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)
])
def forward(self, x):
return x.to(torch.int64)
@ -116,7 +130,9 @@ class TypeConversionI1ToF32Module(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)])
@annotate_args([
None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)
])
def forward(self, x):
return x.to(torch.float32)
@ -133,7 +149,9 @@ class TypeConversionI1ToF64Module(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)])
@annotate_args([
None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)
])
def forward(self, x):
return x.to(torch.float64)
@ -153,7 +171,10 @@ class ToDtypeLayoutNoneModule(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True)
])
def forward(self, x):
return torch.ops.aten.to(x,
dtype=torch.float64,
@ -176,7 +197,10 @@ class ToDtypeLayoutStridedModule(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True)
])
def forward(self, x):
return torch.ops.aten.to(x,
dtype=torch.float64,
@ -245,7 +269,10 @@ class PrimsConvertElementTypeModule(torch.nn.Module):
super().__init__()
@export
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
@annotate_args([
None,
([-9223372036854775808, -9223372036854775808], torch.float32, True)
])
def forward(self, x):
return torch.ops.prims.convert_element_type(x, dtype=torch.int64)

View File

@ -13,6 +13,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -30,11 +31,11 @@ class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module):
module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule())
def TypePromotionSameCategoryDifferentWidthModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(4, high=10).type(torch.int32),
tu.randint(4, high=10))
tu.randint(4, high=10).type(torch.int32), tu.randint(4, high=10))
class TypePromotionDifferentCategoryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -55,6 +56,7 @@ def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils):
class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -75,6 +77,7 @@ def TypePromotionSameCategoryZeroRankWider_basic(module, tu: TestUtils):
class TypePromotionZeroRankHigherCategoryModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -95,6 +98,7 @@ def TypePromotionZeroRankHigherCategoryModule_basic(module, tu: TestUtils):
class TypePromotionAlphaWiderModule(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -14,6 +14,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
class ResNet18Module(torch.nn.Module):
def __init__(self):
super().__init__()
# Reset seed to make model deterministic.
@ -24,7 +25,8 @@ class ResNet18Module(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, 3, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, 3, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, img):
return self.resnet.forward(img)
@ -36,6 +38,7 @@ def ResNet18Module_basic(module, tu: TestUtils):
class ResNet18StaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
# Reset seed to make model deterministic.
@ -58,6 +61,7 @@ def ResNet18StaticModule_basic(module, tu: TestUtils):
class IouOfModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -84,7 +88,9 @@ class IouOfModule(torch.nn.Module):
def IouOfModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1024, 4), tu.rand(1024, 4))
class MobilenetV2Module(torch.nn.Module):
def __init__(self):
super().__init__()
# Reset seed to make model deterministic.
@ -95,11 +101,13 @@ class MobilenetV2Module(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, 3, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, 3, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, img):
return self.mobilenetv2.forward(img)
# TODO (cathyzhyi) The runtime assertion for conv2d with group != 1 is exposed
# after aten.hardtanh is implemented. Reenable once the the runtime assertion
# is fixed.
@ -107,7 +115,9 @@ class MobilenetV2Module(torch.nn.Module):
def MobilenetV2Module_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 3, 224, 224))
class MobilenetV3Module(torch.nn.Module):
def __init__(self):
super().__init__()
# Reset seed to make model deterministic.
@ -118,7 +128,8 @@ class MobilenetV3Module(torch.nn.Module):
@export
@annotate_args([
None,
([-9223372036854775808, 3, -9223372036854775808, -9223372036854775808], torch.float32, True),
([-9223372036854775808, 3, -9223372036854775808,
-9223372036854775808], torch.float32, True),
])
def forward(self, img):
return self.mobilenetv3.forward(img)