mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Fix indentation and spacing for E2E tests
Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>pull/1563/merge snapshot-20221124.667
parent
e45ad313d4
commit
3790a4270e
|
@ -10,6 +10,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
|||
"QuantizedMLP_basic",
|
||||
}
|
||||
|
||||
|
||||
def register_all_tests():
|
||||
"""Registers all the built-in E2E tests that Torch-MLIR provides."""
|
||||
# Side-effecting import statements.
|
||||
|
|
|
@ -13,6 +13,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class ArangeIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -20,16 +21,17 @@ class ArangeIntModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeIntModule())
|
||||
def ArangeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -37,16 +39,17 @@ class ArangeFloatModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(5.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeFloatModule())
|
||||
def ArangeFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeZeroElementOutputModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -54,16 +57,17 @@ class ArangeZeroElementOutputModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeZeroElementOutputModule())
|
||||
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeStartIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -71,16 +75,17 @@ class ArangeStartIntModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(0, 5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartIntModule())
|
||||
def ArangeStartIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeStartFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -88,16 +93,17 @@ class ArangeStartFloatModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(0.0, 5.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartFloatModule())
|
||||
def ArangeStartFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeNegativeStartIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -105,16 +111,17 @@ class ArangeNegativeStartIntModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(-10, 5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeNegativeStartIntModule())
|
||||
def ArangeNegativeStartIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeNegativeStartFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -122,16 +129,17 @@ class ArangeNegativeStartFloatModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(-1.4, 5.7)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeNegativeStartFloatModule())
|
||||
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeStartStepIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -139,16 +147,17 @@ class ArangeStartStepIntModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(0, 5, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartStepIntModule())
|
||||
def ArangeStartStepIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeStartStepFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -156,16 +165,17 @@ class ArangeStartStepFloatModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(-1, 5, 1.3)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartStepFloatModule())
|
||||
def ArangeStartStepFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeStartNegativeStepIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -173,16 +183,17 @@ class ArangeStartNegativeStepIntModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(10, 1, -2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartNegativeStepIntModule())
|
||||
def ArangeStartNegativeStepIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeStartNegativeStepFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -190,16 +201,18 @@ class ArangeStartNegativeStepFloatModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(-1, -15, -3.4)
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeStartNegativeStepFloatModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ArangeStartNegativeStepFloatModule())
|
||||
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeDtypeFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -207,16 +220,17 @@ class ArangeDtypeFloatModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(-1, 15, dtype=torch.float32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeDtypeFloatModule())
|
||||
def ArangeDtypeFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeDtypeIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -224,16 +238,17 @@ class ArangeDtypeIntModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(0.2, 5.0, dtype=torch.int64)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeDtypeIntModule())
|
||||
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ArangeFalsePinMemoryModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -241,10 +256,10 @@ class ArangeFalsePinMemoryModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.arange(5.0, dtype=torch.int64, pin_memory=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArangeFalsePinMemoryModule())
|
||||
def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
|
|
@ -10,7 +10,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ArgmaxModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -19,7 +21,6 @@ class ArgmaxModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.argmax(a)
|
||||
|
||||
|
@ -28,38 +29,47 @@ class ArgmaxModule(torch.nn.Module):
|
|||
def ArgmaxModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ArgmaxWithDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.argmax(a, dim=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
|
||||
def ArgmaxModule_with_dim(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ArgmaxKeepDimsModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.argmax(a, 0, True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
|
||||
def ArgmaxModule_keepDim(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 6))
|
||||
|
|
|
@ -20,8 +20,10 @@ class SoftmaxBackwardModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, output):
|
||||
return torch.ops.aten._softmax_backward_data(grad_output,
|
||||
|
@ -58,6 +60,7 @@ def TanhBackward_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ConvolutionBackwardModule2D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -66,9 +69,18 @@ class ConvolutionBackwardModule2D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_out, input_vec, weight):
|
||||
return torch.ops.aten.convolution_backward(
|
||||
|
@ -100,9 +112,18 @@ class ConvolutionBackwardModule2DPadded(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_out, input_vec, weight):
|
||||
return torch.ops.aten.convolution_backward(
|
||||
|
@ -157,8 +178,10 @@ class LogSoftmaxBackwardModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, output):
|
||||
return torch.ops.aten._log_softmax_backward_data(grad_output,
|
||||
|
|
|
@ -49,8 +49,10 @@ class BmmModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.bmm(lhs, rhs)
|
||||
|
@ -109,14 +111,16 @@ def IsFloatingPointFloat_True(module, tu: TestUtils):
|
|||
|
||||
|
||||
class ContainsIntList(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None
|
||||
])
|
||||
@annotate_args([None])
|
||||
def forward(self):
|
||||
return torch.ops.aten.__contains__([1,2,3], 3)
|
||||
return torch.ops.aten.__contains__([1, 2, 3], 3)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ContainsIntList())
|
||||
def ContainsIntList_True(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
@ -126,14 +130,16 @@ def ContainsIntList_True(module, tu: TestUtils):
|
|||
|
||||
|
||||
class ContainsIntListFalse(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None
|
||||
])
|
||||
@annotate_args([None])
|
||||
def forward(self):
|
||||
return torch.ops.aten.__contains__([1,2,3], 4)
|
||||
return torch.ops.aten.__contains__([1, 2, 3], 4)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ContainsIntListFalse())
|
||||
def ContainsIntList_False(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
@ -320,7 +326,10 @@ class FlattenDynamicModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, 9, 3, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
9, 3, -9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.flat(x)
|
||||
|
@ -365,7 +374,10 @@ class PadModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
pad = [0, 1, 2, 3]
|
||||
|
@ -389,7 +401,10 @@ class PadWithNoneValModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
pad = [0, 1, 2, 3]
|
||||
|
@ -413,7 +428,10 @@ class ConstantPadNdModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.constant_pad_nd(x, (0, 1), -float('inf'))
|
||||
|
@ -457,7 +475,8 @@ class ConstantPadNdPartialStaticModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, 1, 20, 20, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([1, 1, 20, 20, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.constant_pad_nd(x, (0, 1, 2, 3), -float('inf'))
|
||||
|
@ -580,9 +599,12 @@ class TensorsConcatModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
return torch.cat([x, y, z], 1)
|
||||
|
@ -604,9 +626,12 @@ class TensorsConcatNegativeDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
return torch.cat([x, y, z], dim=-2)
|
||||
|
@ -628,8 +653,10 @@ class GatherModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, tensor, indices):
|
||||
return torch.gather(tensor, 2, indices)
|
||||
|
@ -644,18 +671,22 @@ def GatherModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class GatherRandomIndexModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, tensor, indices):
|
||||
return torch.gather(tensor, 1, indices)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: GatherRandomIndexModule())
|
||||
def GatherRandomIndexModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4), tu.randint(2, 3, 4, high=3))
|
||||
|
@ -665,9 +696,10 @@ def GatherRandomIndexModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class Gather2DInputModdule(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
|
@ -677,6 +709,7 @@ class Gather2DInputModdule(torch.nn.Module):
|
|||
def forward(self, tensor, indices):
|
||||
return torch.gather(tensor, 1, indices)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Gather2DInputModdule())
|
||||
def Gather2DInputModdule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5), torch.tensor([[1, 2, 3], [4, 3, 2]]))
|
||||
|
@ -809,6 +842,7 @@ def EmbeddingModuleI32_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class EmbeddingModuleI32Static(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -870,7 +904,8 @@ class SoftmaxIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return self.softmax.forward(tensor)
|
||||
|
@ -892,7 +927,8 @@ class _SoftmaxModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten._softmax(tensor, 0, False)
|
||||
|
@ -916,7 +952,8 @@ class SoftmaxIntNegDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return self.softmax.forward(tensor)
|
||||
|
@ -940,7 +977,8 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return self.softmax.forward(tensor)
|
||||
|
@ -962,7 +1000,8 @@ class _LogSoftmaxModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten._log_softmax(tensor, dim=0, half_to_float=False)
|
||||
|
@ -1129,10 +1168,12 @@ class BroadcastZeroRankInputStaticModule(torch.nn.Module):
|
|||
return torch.ops.aten.sub(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BroadcastZeroRankInputStaticModule())
|
||||
@register_test_case(
|
||||
module_factory=lambda: BroadcastZeroRankInputStaticModule())
|
||||
def BroadcastZeroRankInputStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 8), tu.rand())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
@ -1154,6 +1195,7 @@ class RollModule(torch.nn.Module):
|
|||
def RollModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
@ -1175,6 +1217,7 @@ class RepeatModule(torch.nn.Module):
|
|||
def RepeatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
@ -1231,7 +1274,8 @@ class LogSoftmaxIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return self.log_softmax.forward(tensor)
|
||||
|
@ -1479,7 +1523,8 @@ class NumelModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.numel(input)
|
||||
|
@ -1772,11 +1817,12 @@ class IndexTensorModule3dInput(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index):
|
||||
return torch.ops.aten.index(x, (index,))
|
||||
return torch.ops.aten.index(x, (index, ))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexTensorModule3dInput())
|
||||
|
@ -1795,7 +1841,8 @@ class IndexTensorSelectDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, a, ind):
|
||||
|
@ -1806,6 +1853,7 @@ class IndexTensorSelectDimModule(torch.nn.Module):
|
|||
def IndexTensorSelectDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 6), tu.randint(2, 3, high=3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
@ -1817,17 +1865,22 @@ class IndexTensorMultiInput(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([3, 3], torch.int64, True),
|
||||
([3], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index1, index2):
|
||||
return torch.ops.aten.index(x, (index1, index2,))
|
||||
return torch.ops.aten.index(x, (
|
||||
index1,
|
||||
index2,
|
||||
))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexTensorMultiInput())
|
||||
def IndexTensorMultiInput_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3), tu.randint(3, 3, high=3), tu.randint(3, high=3))
|
||||
module.forward(tu.rand(5, 4, 3), tu.randint(3, 3, high=3),
|
||||
tu.randint(3, high=3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1841,17 +1894,22 @@ class IndexTensorMultiInputOneDim(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([6, 1], torch.int64, True),
|
||||
([3], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index1, index2):
|
||||
return torch.ops.aten.index(x, (index1, index2,))
|
||||
return torch.ops.aten.index(x, (
|
||||
index1,
|
||||
index2,
|
||||
))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexTensorMultiInputOneDim())
|
||||
def IndexTensorMultiInputOneDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4), tu.randint(3, high=3))
|
||||
module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4),
|
||||
tu.randint(3, high=3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1865,7 +1923,8 @@ class IndexTensorMultiInputContiguousOneDimDynamic(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, 1], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
|
@ -1895,7 +1954,8 @@ class IndexTensorMultiInputNonContiguousOneDimDynamic(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, 1], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
|
@ -1926,7 +1986,8 @@ class IndexTensorMultiInputNonContiguousDynamic(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, 2], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
|
@ -1956,7 +2017,10 @@ class IndexTensorMultiInputNonContiguousMultipleStaticDims(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([4, 1], torch.int64, True),
|
||||
([1, 3], torch.int64, True),
|
||||
([-9223372036854775808, 3], torch.int64, True),
|
||||
|
@ -1984,7 +2048,10 @@ class IndexTensorMultiInputNonContiguous(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([4, 2], torch.int64, True),
|
||||
([4, 2], torch.int64, True),
|
||||
])
|
||||
|
@ -1992,9 +2059,11 @@ class IndexTensorMultiInputNonContiguous(torch.nn.Module):
|
|||
return torch.ops.aten.index(x, (index1, None, index2))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexTensorMultiInputNonContiguous())
|
||||
@register_test_case(
|
||||
module_factory=lambda: IndexTensorMultiInputNonContiguous())
|
||||
def IndexTensorMultiInputNonContiguous_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 2, high=3), tu.randint(4, 2, high=1))
|
||||
module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 2, high=3),
|
||||
tu.randint(4, 2, high=1))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2008,21 +2077,24 @@ class IndexTensorMultiInputThreeIndexers(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808
|
||||
], torch.float32, True),
|
||||
([8, 4, 2], torch.int64, True),
|
||||
([8, 1, 1], torch.int64, True),
|
||||
([4, 2], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index1, index2, index3):
|
||||
return torch.ops.aten.index(x, (None, None, index1, None, index2, index3))
|
||||
return torch.ops.aten.index(x,
|
||||
(None, None, index1, None, index2, index3))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexTensorMultiInputThreeIndexers())
|
||||
@register_test_case(
|
||||
module_factory=lambda: IndexTensorMultiInputThreeIndexers())
|
||||
def IndexTensorMultiInputThreeIndexers_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 2, 4, 4, 5, 3),
|
||||
tu.randint(8, 4, 2, high=3),
|
||||
tu.randint(8, 1, 1, high=4),
|
||||
tu.randint(4, 2, high=2))
|
||||
module.forward(tu.rand(1, 2, 4, 4, 5, 3), tu.randint(8, 4, 2, high=3),
|
||||
tu.randint(8, 1, 1, high=4), tu.randint(4, 2, high=2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2036,7 +2108,10 @@ class IndexTensorMultiInputContiguousCenter(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([2, 2], torch.int64, True),
|
||||
([2], torch.int64, True),
|
||||
])
|
||||
|
@ -2044,9 +2119,11 @@ class IndexTensorMultiInputContiguousCenter(torch.nn.Module):
|
|||
return torch.ops.aten.index(x, (None, index1, index2, None))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexTensorMultiInputContiguousCenter())
|
||||
@register_test_case(
|
||||
module_factory=lambda: IndexTensorMultiInputContiguousCenter())
|
||||
def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3, 2), tu.randint(2, 2, high=3), tu.randint(2, high=2))
|
||||
module.forward(tu.rand(5, 4, 3, 2), tu.randint(2, 2, high=3),
|
||||
tu.randint(2, high=2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2083,7 +2160,8 @@ class IndexTensorHackedTwinModule3dInput(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index):
|
||||
|
@ -2108,7 +2186,10 @@ class IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims(
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([4, 1], torch.int64, True),
|
||||
([1, 3], torch.int64, True),
|
||||
([-9223372036854775808, 3], torch.int64, True),
|
||||
|
@ -2137,7 +2218,8 @@ class SquareModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.square(x)
|
||||
|
@ -2336,7 +2418,8 @@ class ExpandAsFloatModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, 1, 1], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.expand_as(x, y)
|
||||
|
@ -2356,7 +2439,8 @@ class ExpandAsIntModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
([1, 1, 1], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.expand_as(x, y)
|
||||
|
@ -2364,8 +2448,8 @@ class ExpandAsIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ExpandAsIntModule())
|
||||
def ExpandAsIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(1, 1, 1, high=100),
|
||||
tu.randint(4, 5, 6, high=200))
|
||||
module.forward(tu.randint(1, 1, 1, high=100), tu.randint(4, 5, 6,
|
||||
high=200))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2379,8 +2463,10 @@ class CopyModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.copy_(x, y)
|
||||
|
@ -2419,8 +2505,10 @@ class CopyWithDifferentDTypesModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.copy_(x, y)
|
||||
|
@ -2463,7 +2551,8 @@ class ToCopyModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten._to_copy(x)
|
||||
|
@ -2482,7 +2571,8 @@ class ToCopyWithDTypeModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten._to_copy(x, dtype=torch.int64)
|
||||
|
@ -2501,7 +2591,8 @@ class ToCopyWithDTypeFalsePinMemoryModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten._to_copy(x, dtype=torch.int64, pin_memory=False)
|
||||
|
@ -2543,7 +2634,8 @@ class FlipModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.flip(x, [1, 2])
|
||||
|
@ -2565,7 +2657,8 @@ class DetachModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.detach(x)
|
||||
|
@ -2650,9 +2743,12 @@ class BaddbmmDynamicModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, batch1, batch2):
|
||||
return torch.ops.aten.baddbmm(input, batch1, batch2)
|
||||
|
@ -2692,9 +2788,12 @@ class BaddbmmDifferentDtypesModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, batch1, batch2):
|
||||
return torch.ops.aten.baddbmm(input, batch1, batch2)
|
||||
|
@ -2714,9 +2813,12 @@ class BaddbmmWithAlphaModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, batch1, batch2):
|
||||
return torch.ops.aten.baddbmm(input, batch1, batch2, alpha=5)
|
||||
|
@ -2735,9 +2837,12 @@ class BaddbmmWithBetaModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, batch1, batch2):
|
||||
return torch.ops.aten.baddbmm(input, batch1, batch2, beta=0.5)
|
||||
|
@ -2756,9 +2861,12 @@ class BaddbmmWithAlphaBetaModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, batch1, batch2):
|
||||
return torch.ops.aten.baddbmm(input, batch1, batch2, beta=6, alpha=2.4)
|
||||
|
@ -2787,7 +2895,7 @@ class BaddbmmBroadcast1DInputModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: BaddbmmBroadcast1DInputModule())
|
||||
def BaddbmmBroadcast1DInputModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1,), tu.rand(5, 2, 9), tu.rand(5, 9, 7))
|
||||
module.forward(tu.rand(1, ), tu.rand(5, 2, 9), tu.rand(5, 9, 7))
|
||||
|
||||
|
||||
class BaddbmmBroadcast2DInputModule(torch.nn.Module):
|
||||
|
@ -2841,7 +2949,10 @@ class NumpyTRankNDynamicModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808, -9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs):
|
||||
return torch.ops.aten.numpy_T(lhs)
|
||||
|
@ -2908,6 +3019,7 @@ class NumpyTRank0Module(torch.nn.Module):
|
|||
def NumpyTRank0Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor(7, dtype=torch.float32))
|
||||
|
||||
|
||||
class AtenEmbeddingBagSumExample(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -2921,15 +3033,26 @@ class AtenEmbeddingBagSumExample(torch.nn.Module):
|
|||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, weight, indices, offsets):
|
||||
return torch.ops.aten.embedding_bag(weight, indices, offsets, scale_grad_by_freq=False, mode=0, sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None)
|
||||
return torch.ops.aten.embedding_bag(weight,
|
||||
indices,
|
||||
offsets,
|
||||
scale_grad_by_freq=False,
|
||||
mode=0,
|
||||
sparse=False,
|
||||
per_sample_weights=None,
|
||||
include_last_offset=False,
|
||||
padding_idx=None)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenEmbeddingBagSumExample())
|
||||
def AtenEmbeddingBagSumExample_basic(module, tu: TestUtils):
|
||||
weight = torch.rand(100, 10)
|
||||
indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
|
||||
weight = torch.rand(100, 10)
|
||||
indices = torch.LongTensor(
|
||||
[0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
|
||||
offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15])
|
||||
module.forward(weight, indices, offsets)
|
||||
|
||||
|
||||
class Aten_EmbeddingBagExample(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -2945,15 +3068,19 @@ class Aten_EmbeddingBagExample(torch.nn.Module):
|
|||
def forward(self, weight, indices, offsets):
|
||||
return torch.ops.aten._embedding_bag(weight, indices, offsets)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Aten_EmbeddingBagExample())
|
||||
def Aten_EmbeddingBagExample_basic(module, tu: TestUtils):
|
||||
weight = torch.rand(100, 10)
|
||||
indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
|
||||
weight = torch.rand(100, 10)
|
||||
indices = torch.LongTensor(
|
||||
[0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
|
||||
offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15])
|
||||
module.forward(weight, indices, offsets)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class CumsumModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -2962,15 +3089,18 @@ class CumsumModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, val):
|
||||
return torch.ops.aten.cumsum(val, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: CumsumModule())
|
||||
def CumsumModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 7, 4))
|
||||
|
||||
|
||||
class CumsumStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -2984,29 +3114,37 @@ class CumsumStaticModule(torch.nn.Module):
|
|||
def forward(self, val):
|
||||
return torch.ops.aten.cumsum(val, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: CumsumStaticModule())
|
||||
def CumsumStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 7, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenToDeviceModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808 , -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, val):
|
||||
return torch.ops.aten.to(val, device='cpu', dtype=torch.float, non_blocking=False)
|
||||
return torch.ops.aten.to(val,
|
||||
device='cpu',
|
||||
dtype=torch.float,
|
||||
non_blocking=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenToDeviceModule())
|
||||
def AtenToDeviceModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
@ -3018,14 +3156,18 @@ class UpSampleNearest2dBackward(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float64, True),
|
||||
])
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.upsample_nearest2d_backward(input,
|
||||
output_size=[6, 12],
|
||||
input_size=[1, 1, 2, 3],
|
||||
scales_h=3.0,
|
||||
scales_w=4.0)
|
||||
return torch.ops.aten.upsample_nearest2d_backward(
|
||||
input,
|
||||
output_size=[6, 12],
|
||||
input_size=[1, 1, 2, 3],
|
||||
scales_h=3.0,
|
||||
scales_w=4.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: UpSampleNearest2dBackward())
|
||||
|
@ -3041,16 +3183,22 @@ class UpSampleNearest2dBackwardScalesNone(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.upsample_nearest2d_backward(input,
|
||||
output_size=[4, 8],
|
||||
input_size=[1, 1, 2, 3],
|
||||
scales_h=None,
|
||||
scales_w=None)
|
||||
return torch.ops.aten.upsample_nearest2d_backward(
|
||||
input,
|
||||
output_size=[4, 8],
|
||||
input_size=[1, 1, 2, 3],
|
||||
scales_h=None,
|
||||
scales_w=None)
|
||||
|
||||
@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardScalesNone())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: UpSampleNearest2dBackwardScalesNone())
|
||||
def UpSampleNearest2dBackwardScalesNone_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 4, 8))
|
||||
|
||||
|
|
|
@ -11,7 +11,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorToIntZeroRank(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -28,9 +30,12 @@ class TensorToIntZeroRank(torch.nn.Module):
|
|||
def TensorToIntZeroRank_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorToInt(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -47,9 +52,12 @@ class TensorToInt(torch.nn.Module):
|
|||
def TensorToInt_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(1, 1, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorToFloatZeroRank(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -66,9 +74,12 @@ class TensorToFloatZeroRank(torch.nn.Module):
|
|||
def TensorToFloatZeroRank_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand((), dtype=torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorToFloat(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -85,9 +96,12 @@ class TensorToFloat(torch.nn.Module):
|
|||
def TensorToFloat_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand((1, 1), dtype=torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorToBoolZeroRank(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -104,9 +118,12 @@ class TensorToBoolZeroRank(torch.nn.Module):
|
|||
def TensorToBoolZeroRank_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor(1, dtype=torch.bool))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TensorToBool(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
|
|
@ -357,7 +357,10 @@ class EmptyLikeMemoryFormatModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.empty_like(a,
|
||||
|
@ -396,7 +399,8 @@ class EmptyLikeFalsePinMemoryModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.empty_like(a, dtype=torch.float64,
|
||||
|
@ -476,7 +480,8 @@ class ZerosLikeFalsePinMemoryModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.zeros_like(a, dtype=torch.float64, pin_memory=False)
|
||||
|
@ -555,7 +560,8 @@ class OnesLikeFalsePinMemoryModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ones_like(a, dtype=torch.float64, pin_memory=False)
|
||||
|
@ -596,7 +602,8 @@ class NewZerosModuleInt2D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.int64)
|
||||
|
@ -634,7 +641,8 @@ class NewZerosModuleFloat2D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.float32)
|
||||
|
@ -715,7 +723,8 @@ class NewOnesModuleInt2D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.int64)
|
||||
|
@ -753,7 +762,8 @@ class NewOnesModuleFloat2D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.float32)
|
||||
|
@ -967,7 +977,8 @@ class FullLikeModuleInt3D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.full_like(a, 5.0, dtype=torch.int64)
|
||||
|
@ -1024,7 +1035,8 @@ class FullLikeModuleFloat3D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.full_like(a, 15, dtype=torch.float32)
|
||||
|
@ -1166,7 +1178,8 @@ class NewEmptyModuleInt2D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.int64).fill_(0)
|
||||
|
@ -1205,7 +1218,8 @@ class NewEmptyModuleFloat2D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4],
|
||||
|
|
|
@ -13,14 +13,15 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TorchPrimLoopForLikeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True)
|
||||
None, ([-9223372036854775808, -9223372036854775808], torch.int64, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
x_val = x.size(0)
|
||||
|
@ -28,29 +29,30 @@ class TorchPrimLoopForLikeModule(torch.nn.Module):
|
|||
for i in range(x_val):
|
||||
sum += i
|
||||
return sum
|
||||
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TorchPrimLoopForLikeModule())
|
||||
def TorchPrimLoopForLikeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(6, 8, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
class TorchPrimLoopWhileLikeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True)
|
||||
None, ([-9223372036854775808, -9223372036854775808], torch.int64, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
x_val = x.size(0)
|
||||
sum = 0
|
||||
while(x_val > sum):
|
||||
while (x_val > sum):
|
||||
sum += 1
|
||||
return sum
|
||||
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeModule())
|
||||
def TorchPrimLoopWhileLikeModule_basic(module, tu: TestUtils):
|
||||
|
|
|
@ -22,7 +22,10 @@ class Conv2dNoPaddingModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
@ -45,7 +48,10 @@ class Conv2dBiasNoPaddingModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
@ -68,7 +74,10 @@ class Conv2dWithPaddingModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
@ -97,7 +106,10 @@ class Conv2dWithPaddingDilationStrideModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
@ -142,15 +154,23 @@ def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Convolution2DModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.convolution(inputVec,
|
||||
|
@ -163,12 +183,14 @@ class Convolution2DModule(torch.nn.Module):
|
|||
output_padding=[0, 0],
|
||||
groups=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Convolution2DModule())
|
||||
def Convolution2DModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
|
||||
class Convolution2DStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -189,19 +211,28 @@ class Convolution2DStaticModule(torch.nn.Module):
|
|||
output_padding=[0, 0],
|
||||
groups=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Convolution2DStaticModule())
|
||||
def Convolution2DStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
|
||||
class Convolution2DStridedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.convolution(inputVec,
|
||||
|
@ -214,164 +245,218 @@ class Convolution2DStridedModule(torch.nn.Module):
|
|||
output_padding=[0, 0],
|
||||
groups=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Convolution2DStridedModule())
|
||||
def Convolution2DStridedModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
|
||||
class _Convolution2DAllFalseModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=False,
|
||||
deterministic=False,
|
||||
cudnn_enabled=False,
|
||||
allow_tf32=False)
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=False,
|
||||
deterministic=False,
|
||||
cudnn_enabled=False,
|
||||
allow_tf32=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: _Convolution2DAllFalseModule())
|
||||
def _Convolution2DAllFalseModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
|
||||
class _Convolution2DBenchmarkModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=True,
|
||||
deterministic=False,
|
||||
cudnn_enabled=False,
|
||||
allow_tf32=False)
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=True,
|
||||
deterministic=False,
|
||||
cudnn_enabled=False,
|
||||
allow_tf32=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: _Convolution2DBenchmarkModule())
|
||||
def _Convolution2DBenchmarkModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
|
||||
class _Convolution2DDeterministicModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=False,
|
||||
deterministic=True,
|
||||
cudnn_enabled=False,
|
||||
allow_tf32=False)
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=False,
|
||||
deterministic=True,
|
||||
cudnn_enabled=False,
|
||||
allow_tf32=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: _Convolution2DDeterministicModule())
|
||||
def _Convolution2DDeterministicModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
|
||||
class _Convolution2DCudnnModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=False,
|
||||
deterministic=False,
|
||||
cudnn_enabled=True,
|
||||
allow_tf32=False)
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=False,
|
||||
deterministic=False,
|
||||
cudnn_enabled=True,
|
||||
allow_tf32=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: _Convolution2DCudnnModule())
|
||||
def _Convolution2DCudnnModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
|
||||
class _Convolution2DTF32Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=False,
|
||||
deterministic=False,
|
||||
cudnn_enabled=False,
|
||||
allow_tf32=True)
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[3, 3],
|
||||
padding=[2, 2],
|
||||
dilation=[1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
benchmark=False,
|
||||
deterministic=False,
|
||||
cudnn_enabled=False,
|
||||
allow_tf32=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: _Convolution2DTF32Module())
|
||||
def _Convolution2DTF32Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
|
||||
class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
|
@ -387,19 +472,29 @@ class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module):
|
|||
deterministic=False,
|
||||
cudnn_enabled=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule())
|
||||
def _ConvolutionDeprecated2DAllFalseModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
|
||||
class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
|
@ -415,19 +510,29 @@ class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module):
|
|||
deterministic=False,
|
||||
cudnn_enabled=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule())
|
||||
def _ConvolutionDeprecated2DBenchmarkModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
|
||||
class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
|
@ -443,19 +548,29 @@ class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module):
|
|||
deterministic=True,
|
||||
cudnn_enabled=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule())
|
||||
def _ConvolutionDeprecated2DDeterministicModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
|
||||
class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
|
@ -471,19 +586,29 @@ class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
|
|||
deterministic=False,
|
||||
cudnn_enabled=True)
|
||||
|
||||
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
|
||||
def _ConvolutionDeprecated2DCudnnModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
|
||||
|
||||
class ConvolutionModule2DGroups(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.convolution(inputVec,
|
||||
|
@ -496,12 +621,15 @@ class ConvolutionModule2DGroups(torch.nn.Module):
|
|||
output_padding=[0, 0],
|
||||
groups=4)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ConvolutionModule2DGroups())
|
||||
def ConvolutionModule2DGroups_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(1, 32, 4, 4), torch.randn(32, 8, 3, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ConvolutionModule2DTranspose(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -510,8 +638,14 @@ class ConvolutionModule2DTranspose(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.convolution(inputVec,
|
||||
|
@ -529,6 +663,7 @@ class ConvolutionModule2DTranspose(torch.nn.Module):
|
|||
def ConvolutionModule2DTranspose_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 4, 4), torch.randn(3, 3, 2, 2))
|
||||
|
||||
|
||||
class ConvolutionModule2DTransposeStrided(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -537,8 +672,14 @@ class ConvolutionModule2DTransposeStrided(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.convolution(inputVec,
|
||||
|
@ -552,10 +693,12 @@ class ConvolutionModule2DTransposeStrided(torch.nn.Module):
|
|||
groups=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStrided())
|
||||
@register_test_case(
|
||||
module_factory=lambda: ConvolutionModule2DTransposeStrided())
|
||||
def ConvolutionModule2DTransposeStrided_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2))
|
||||
|
||||
|
||||
class ConvolutionModule2DTransposeStridedStatic(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -579,7 +722,8 @@ class ConvolutionModule2DTransposeStridedStatic(torch.nn.Module):
|
|||
groups=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStridedStatic())
|
||||
@register_test_case(
|
||||
module_factory=lambda: ConvolutionModule2DTransposeStridedStatic())
|
||||
def ConvolutionModule2DTransposeStridedStatic_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2))
|
||||
|
||||
|
@ -592,8 +736,14 @@ class Conv_Transpose2dModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.conv_transpose2d(inputVec,
|
||||
|
@ -619,19 +769,23 @@ class UpSampleNearest2d(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float64, True),
|
||||
])
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.upsample_nearest2d(input,
|
||||
output_size=[18, 48],
|
||||
scales_h=3.0,
|
||||
scales_w=4.0)
|
||||
output_size=[18, 48],
|
||||
scales_h=3.0,
|
||||
scales_w=4.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: UpSampleNearest2d())
|
||||
def UpSampleNearest2d_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))
|
||||
|
||||
|
||||
class UpSampleNearest2dSameSize(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -640,7 +794,10 @@ class UpSampleNearest2dSameSize(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec):
|
||||
return torch._C._nn.upsample_nearest2d(inputVec,
|
||||
|
@ -660,7 +817,13 @@ class UpSampleNearest2dDiffSize(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True)])
|
||||
@annotate_args([
|
||||
None,
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True)
|
||||
])
|
||||
def forward(self, inputVec):
|
||||
return torch._C._nn.upsample_nearest2d(inputVec,
|
||||
output_size=[8, 11],
|
||||
|
@ -679,7 +842,13 @@ class UpSampleNearest2dDiffFactor(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True)])
|
||||
@annotate_args([
|
||||
None,
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True)
|
||||
])
|
||||
def forward(self, inputVec):
|
||||
return torch._C._nn.upsample_nearest2d(inputVec,
|
||||
output_size=[6, 10],
|
||||
|
@ -700,7 +869,10 @@ class UpSampleNearest2dSameFactor(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec):
|
||||
return torch._C._nn.upsample_nearest2d(inputVec,
|
||||
|
|
|
@ -17,7 +17,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
# the PyTorch op registry permanently.
|
||||
import torch_mlir._torch_mlir_custom_op_example
|
||||
|
||||
|
||||
class CustomOpExampleModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
|
|
@ -118,7 +118,8 @@ class ElementwiseTernaryModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
@ -152,7 +153,8 @@ class ElementwiseAtenWhereSelfModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenWhereSelfModule())
|
||||
def ElementwiseAtenWhereSelfModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.zeros(1, 1, 5, 5, dtype=torch.bool), torch.rand(1, 12, 5, 5), torch.rand(()))
|
||||
module.forward(torch.zeros(1, 1, 5, 5, dtype=torch.bool),
|
||||
torch.rand(1, 12, 5, 5), torch.rand(()))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -166,7 +168,8 @@ class ElementwiseWhereSelfModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
@ -190,7 +193,8 @@ class ElementwiseWhereScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.where(a > 0.5, 4.0, 8.0)
|
||||
|
@ -212,7 +216,8 @@ class ElementwiseWhereScalarOtherModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
|
@ -235,7 +240,8 @@ class ElementwiseWhereScalarSelfModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
|
@ -913,7 +919,8 @@ class ElementwiseAtan2TensorIntModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntModule())
|
||||
def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.randint(4, low=1, high=10).type(torch.int32), tu.randint(4, low=1, high=10))
|
||||
tu.randint(4, low=1, high=10).type(torch.int32),
|
||||
tu.randint(4, low=1, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -936,8 +943,9 @@ class ElementwiseAtan2FloatIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntModule())
|
||||
def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 4, low=1, high=10).to(torch.int32),
|
||||
tu.rand(4, 4).double())
|
||||
module.forward(
|
||||
tu.randint(4, 4, low=1, high=10).to(torch.int32),
|
||||
tu.rand(4, 4).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -983,6 +991,7 @@ class ElementwiseLogIntModule(torch.nn.Module):
|
|||
def ElementwiseLogIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
@ -1200,7 +1209,8 @@ class ElementwisePowTensorBroadcastModule(torch.nn.Module):
|
|||
return torch.pow(a, b)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwisePowTensorBroadcastModule())
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwisePowTensorBroadcastModule())
|
||||
def ElementwisePowTensorBroadcastModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1), tu.rand(3, 4))
|
||||
|
||||
|
@ -1214,7 +1224,10 @@ class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.int64)
|
||||
|
||||
|
@ -1233,7 +1246,10 @@ class ElementwiseToDtypeIdentityModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.float32, False, False)
|
||||
|
||||
|
@ -1342,7 +1358,8 @@ class ElementwiseAbsModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.abs(a)
|
||||
|
@ -1418,6 +1435,7 @@ class ElementwiseDivScalarModule(torch.nn.Module):
|
|||
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
@ -1435,7 +1453,8 @@ class ElementwiseRemainderScalarModule_Int_Float(torch.nn.Module):
|
|||
return torch.remainder(x, 2.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float())
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float())
|
||||
def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, high=10).to(torch.int32))
|
||||
|
||||
|
@ -1457,13 +1476,15 @@ class ElementwiseRemainderScalarModule_Float(torch.nn.Module):
|
|||
return torch.remainder(x, 2.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Float())
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseRemainderScalarModule_Float())
|
||||
def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(10, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -1478,12 +1499,15 @@ class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
|
|||
return torch.remainder(x, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int())
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseRemainderScalarModule_Int())
|
||||
def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 2, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseRemainderScalarModule_Bool(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -1498,7 +1522,8 @@ class ElementwiseRemainderScalarModule_Bool(torch.nn.Module):
|
|||
return torch.remainder(x, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Bool())
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseRemainderScalarModule_Bool())
|
||||
def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([True, False, True, True, True]))
|
||||
|
||||
|
@ -1786,7 +1811,8 @@ class ElementwiseCloneModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.clone(x)
|
||||
|
@ -1808,7 +1834,8 @@ class ElementwiseCloneContiguousModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.clone(x, memory_format=torch.contiguous_format)
|
||||
|
@ -1830,7 +1857,8 @@ class LiftFreshCopyModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.lift_fresh_copy(x)
|
||||
|
@ -2038,9 +2066,12 @@ class ElementwiseNegModule(torch.nn.Module):
|
|||
def ElementwiseNegModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAtenLogicalOrOpModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -2053,11 +2084,14 @@ class ElementwiseAtenLogicalOrOpModule(torch.nn.Module):
|
|||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpModule())
|
||||
def ElementwiseAtenLogicalOrOpModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([False, True]), torch.tensor([False, False]))
|
||||
|
||||
|
||||
class ElementwiseAtenLogicalOrOpDiffArgs1Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -2070,13 +2104,18 @@ class ElementwiseAtenLogicalOrOpDiffArgs1Module(torch.nn.Module):
|
|||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs1Module())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs1Module())
|
||||
def ElementwiseAtenLogicalOrOpDiffArgs1Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([0.2, 0.1]), torch.tensor([0, 1]))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAtenLogicalOrOpDiffArgs2Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -2089,13 +2128,18 @@ class ElementwiseAtenLogicalOrOpDiffArgs2Module(torch.nn.Module):
|
|||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs2Module())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs2Module())
|
||||
def ElementwiseAtenLogicalOrOpDiffArgs2Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([True, False]), torch.tensor([0, 1]))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAtenLogicalOrOpDiffArgs3Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -2108,83 +2152,125 @@ class ElementwiseAtenLogicalOrOpDiffArgs3Module(torch.nn.Module):
|
|||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs3Module())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs3Module())
|
||||
def ElementwiseAtenLogicalOrOpDiffArgs3Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([1, 2]), torch.tensor([False, True]))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAtenLogicalOrOpRandomModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.int64, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseAtenLogicalOrOpRandomModule())
|
||||
def ElementwiseAtenLogicalOrOpRandomModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(2, 3, 4, 5, low=3, high=10), tu.randint(2, 3, 4, 5, low=10, high=100))
|
||||
module.forward(tu.randint(2, 3, 4, 5, low=3, high=10),
|
||||
tu.randint(2, 3, 4, 5, low=10, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAtenLogicalOrOpRandomFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomFloatModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseAtenLogicalOrOpRandomFloatModule())
|
||||
def ElementwiseAtenLogicalOrOpRandomFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(2, 3, 3, 5), torch.rand(2, 3, 3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAtenLogicalOrOpNegativeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.int64, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpNegativeModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseAtenLogicalOrOpNegativeModule())
|
||||
def ElementwiseAtenLogicalOrOpNegativeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.neg(tu.randint(2, 3, 4, 5, low=3, high=10)), torch.neg(tu.randint(2, 3, 4, 5, low=10, high=100)))
|
||||
module.forward(torch.neg(tu.randint(2, 3, 4, 5, low=3, high=10)),
|
||||
torch.neg(tu.randint(2, 3, 4, 5, low=10, high=100)))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAtenLogicalOrOpBrodcastModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpBrodcastModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseAtenLogicalOrOpBrodcastModule())
|
||||
def ElementwiseAtenLogicalOrOpBrodcastModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, high=3), tu.randint(4, 3, high=3))
|
||||
|
||||
|
@ -2244,7 +2330,10 @@ class AtenTriuModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808, -9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.triu(x)
|
||||
|
@ -2266,7 +2355,8 @@ class AtenTriuWithPosDiagonalModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.triu(x, diagonal=2)
|
||||
|
@ -2288,7 +2378,10 @@ class AtenTriuWithNegDiagonalModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.triu(x, diagonal=-4)
|
||||
|
@ -2298,6 +2391,7 @@ class AtenTriuWithNegDiagonalModule(torch.nn.Module):
|
|||
def AtenTriuWithNegDiagonalModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 5, 9))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
@ -2317,7 +2411,8 @@ class AtenRoundFloatModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: AtenRoundFloatModule())
|
||||
def AtenRoundFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 5, low = -3.0, high = 3.0))
|
||||
module.forward(tu.rand(5, 5, low=-3.0, high=3.0))
|
||||
|
||||
|
||||
class AtenRoundIntModule(torch.nn.Module):
|
||||
|
||||
|
@ -2335,7 +2430,7 @@ class AtenRoundIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: AtenRoundIntModule())
|
||||
def AtenRoundIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(5, 5, low = -10))
|
||||
module.forward(tu.randint(5, 5, low=-10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2349,7 +2444,8 @@ class Fill_TensorFloat64WithFloat32(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten.fill_(tensor, 3.0)
|
||||
|
@ -2368,7 +2464,8 @@ class Fill_TensorFloat64WithFloat64(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten.fill_(tensor, 3.0)
|
||||
|
@ -2387,7 +2484,8 @@ class Fill_TensorFloat64WithInt64(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten.fill_(tensor, 3)
|
||||
|
@ -2409,12 +2507,14 @@ class Fill_TensorFloat32WithFloat32(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor, value):
|
||||
return torch.ops.aten.fill_(tensor, value)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Fill_TensorFloat32WithFloat32())
|
||||
def Fill_TensorFloat32WithFloat32_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 2, 4), tu.rand())
|
||||
|
@ -2428,12 +2528,14 @@ class Fill_TensorFloat32WithFloat64(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([], torch.float64, True),
|
||||
])
|
||||
def forward(self, tensor, value):
|
||||
return torch.ops.aten.fill_(tensor, value)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Fill_TensorFloat32WithFloat64())
|
||||
def Fill_TensorFloat32WithFloat64_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 2, 4), tu.rand().to(torch.float64))
|
||||
|
@ -2447,12 +2549,14 @@ class Fill_TensorFloat32WithInt64(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
def forward(self, tensor, value):
|
||||
return torch.ops.aten.fill_(tensor, value)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Fill_TensorFloat32WithInt64())
|
||||
def Fill_TensorFloat32WithInt64_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 2, 4), tu.randint())
|
||||
|
|
|
@ -11,7 +11,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGtFloatScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -28,9 +30,12 @@ class ElementwiseGtFloatScalarModule(torch.nn.Module):
|
|||
def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGtIntScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -47,9 +52,12 @@ class ElementwiseGtIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -66,9 +74,12 @@ class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
|
|||
def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGeFloatScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -85,9 +96,12 @@ class ElementwiseGeFloatScalarModule(torch.nn.Module):
|
|||
def ElementwiseGeFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGeIntScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -104,9 +118,12 @@ class ElementwiseGeIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseGeIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGeMixedIntScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -123,9 +140,12 @@ class ElementwiseGeMixedIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseGeMixedIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGeFloatIntScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -142,9 +162,12 @@ class ElementwiseGeFloatIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseGeFloatIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGtFloatTensorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -162,9 +185,12 @@ class ElementwiseGtFloatTensorModule(torch.nn.Module):
|
|||
def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5), tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGtIntTensorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -182,9 +208,12 @@ class ElementwiseGtIntTensorModule(torch.nn.Module):
|
|||
def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLtFloatScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -201,9 +230,12 @@ class ElementwiseLtFloatScalarModule(torch.nn.Module):
|
|||
def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLtIntScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -220,9 +252,12 @@ class ElementwiseLtIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -240,9 +275,12 @@ class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
|
|||
def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLeFloatScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -259,9 +297,12 @@ class ElementwiseLeFloatScalarModule(torch.nn.Module):
|
|||
def ElementwiseLeFloatScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLeIntScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -278,9 +319,12 @@ class ElementwiseLeIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseLeIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLeMixedIntScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -297,9 +341,12 @@ class ElementwiseLeMixedIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseLeMixedIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLeFloatIntScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -316,9 +363,12 @@ class ElementwiseLeFloatIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseLeFloatIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLtFloatTensorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -336,9 +386,12 @@ class ElementwiseLtFloatTensorModule(torch.nn.Module):
|
|||
def ElementwiseLtFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5), tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLtIntTensorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -356,9 +409,12 @@ class ElementwiseLtIntTensorModule(torch.nn.Module):
|
|||
def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseEqFloatScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -376,9 +432,12 @@ def ElementwiseEqFloatScalarModule_basic(module, tu: TestUtils):
|
|||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseEqIntScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -395,9 +454,12 @@ class ElementwiseEqIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(5, 8, low=2, high=4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -415,9 +477,12 @@ class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
|
|||
def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(5, 8, low=2, high=4).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseEqFloatTensorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -437,9 +502,12 @@ def ElementwiseEqFloatTensorModule_basic(module, tu: TestUtils):
|
|||
torch.tensor([[1.0, 2.2, 6.0], [6.0, 2.0, 3.1]]).to(torch.float32),
|
||||
torch.tensor([1.0, 2.4, 6.0]).to(torch.float32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseEqIntTensorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -455,11 +523,16 @@ class ElementwiseEqIntTensorModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule())
|
||||
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))
|
||||
module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5,
|
||||
low=2,
|
||||
high=4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseNeFloatScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -477,9 +550,12 @@ def ElementwiseNeFloatTensorModule_basic(module, tu: TestUtils):
|
|||
module.forward(
|
||||
torch.tensor([[1.0, 2.2, 2.0], [6.0, 2.0, 3.1]]).to(torch.float32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseNeIntScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -496,9 +572,12 @@ class ElementwiseNeIntScalarModule(torch.nn.Module):
|
|||
def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(8, 5, low=2, high=4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AnyBoolTrueModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -517,6 +596,7 @@ def AnyBoolTrueModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class AnyBoolFalseModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -533,8 +613,10 @@ class AnyBoolFalseModule(torch.nn.Module):
|
|||
def AnyBoolFalseModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# =================================================================================
|
||||
|
||||
|
||||
class AllBoolTrueModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -553,8 +635,10 @@ class AllBoolTrueModule(torch.nn.Module):
|
|||
def AllBoolTrueModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# =================================================================================
|
||||
|
||||
|
||||
class AllBoolFalseModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -567,7 +651,8 @@ class AllBoolFalseModule(torch.nn.Module):
|
|||
def forward(self):
|
||||
input = [True, False, True, True, False]
|
||||
return torch.ops.aten.all(input)
|
||||
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AllBoolFalseModule())
|
||||
def AllBoolFalseModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
|
|
@ -70,9 +70,11 @@ class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input, (index, ),
|
||||
|
@ -84,8 +86,7 @@ class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutImpl3DFloatNonAccumulateModule())
|
||||
def IndexPutImpl3DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
|
||||
tu.rand(5, 8, 6))
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -178,9 +179,11 @@ class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input.clone(), (index, ),
|
||||
|
@ -192,8 +195,7 @@ class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutImpl3DFloatAccumulateModule())
|
||||
def IndexPutImpl3DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
|
||||
tu.rand(5, 8, 6))
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -283,9 +285,11 @@ class IndexPut3DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -296,8 +300,7 @@ class IndexPut3DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPut3DFloatNonAccumulateModule())
|
||||
def IndexPut3DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
|
||||
tu.rand(5, 8, 6))
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -359,9 +362,11 @@ class IndexPut3DIntNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -432,9 +437,11 @@ class IndexPut3DFloatAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -444,8 +451,7 @@ class IndexPut3DFloatAccumulateModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexPut3DFloatAccumulateModule())
|
||||
def IndexPut3DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
|
||||
tu.rand(5, 8, 6))
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -507,9 +513,11 @@ class IndexPut3DIntAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -583,9 +591,11 @@ class IndexPutHackedTwin3DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index],
|
||||
|
@ -596,8 +606,7 @@ class IndexPutHackedTwin3DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin3DFloatNonAccumulateModule())
|
||||
def IndexPutHackedTwin3DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
|
||||
tu.rand(5, 8, 6))
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -661,9 +670,11 @@ class IndexPutHackedTwin3DIntNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index],
|
||||
|
@ -733,9 +744,11 @@ class IndexPutHackedTwin3DFloatAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index], value, accumulate=True)
|
||||
|
@ -744,8 +757,7 @@ class IndexPutHackedTwin3DFloatAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin3DFloatAccumulateModule())
|
||||
def IndexPutHackedTwin3DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
|
||||
tu.rand(5, 8, 6))
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -805,9 +817,11 @@ class IndexPutHackedTwin3DIntAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index], value, accumulate=True)
|
||||
|
|
|
@ -13,6 +13,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class IndexSelectSingleIdxModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -22,16 +23,17 @@ class IndexSelectSingleIdxModule(torch.nn.Module):
|
|||
([4, 5, 6], torch.float32, True),
|
||||
([1], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input, indices):
|
||||
return torch.index_select(input, 1, indices)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexSelectSingleIdxModule())
|
||||
def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.tensor([2]))
|
||||
|
||||
|
||||
class IndexSelectTwoIdxModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -41,16 +43,17 @@ class IndexSelectTwoIdxModule(torch.nn.Module):
|
|||
([4, 5, 6], torch.float32, True),
|
||||
([2], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input, indices):
|
||||
return torch.index_select(input, 2, indices)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexSelectTwoIdxModule())
|
||||
def IndexSelectTwoIdxModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.tensor([2, 4]))
|
||||
|
||||
|
||||
class IndexSelectWholeDimensionModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -60,16 +63,17 @@ class IndexSelectWholeDimensionModule(torch.nn.Module):
|
|||
([4, 5, 6], torch.float32, True),
|
||||
([4], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input, indices):
|
||||
return torch.index_select(input, 0, indices)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexSelectWholeDimensionModule())
|
||||
def IndexSelectWholeDimensionModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.tensor([0, 1, 2, 3]))
|
||||
|
||||
|
||||
class IndexSelectWholeTensorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -79,54 +83,59 @@ class IndexSelectWholeTensorModule(torch.nn.Module):
|
|||
([3], torch.float32, True),
|
||||
([3], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input, indices):
|
||||
return torch.index_select(input, 0, indices)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexSelectWholeTensorModule())
|
||||
def IndexSelectWholeTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3), torch.tensor([0, 1, 2]))
|
||||
|
||||
|
||||
class IndexSelectDynamicModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input, indices):
|
||||
return torch.index_select(input, 2, indices)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexSelectDynamicModule())
|
||||
def IndexSelectDynamicModulebasic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.tensor([0, 4]))
|
||||
|
||||
|
||||
class IndexSelectDynamicInputSizeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([2], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input, indices):
|
||||
return torch.index_select(input, 2, indices)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexSelectDynamicInputSizeModule())
|
||||
def IndexSelectDynamicInputSizeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.tensor([0, 2]))
|
||||
|
||||
|
||||
class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -136,10 +145,10 @@ class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
|
|||
([4, 5, 6], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input, indices):
|
||||
return torch.index_select(input, 1, indices)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexSelectDynamicIndexSizeModule())
|
||||
def IndexSelectDynamicIndexSizeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.tensor([1, 2]))
|
||||
|
|
|
@ -11,7 +11,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MatmulDot(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -28,10 +30,13 @@ class MatmulDot(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: MatmulDot())
|
||||
def Matmul_dot(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3), tu.rand(3))
|
||||
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Matmul2D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -48,10 +53,13 @@ class Matmul2D(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: Matmul2D())
|
||||
def Matmul_2d(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4), tu.rand(4, 5))
|
||||
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MatmulVecMat(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -68,10 +76,13 @@ class MatmulVecMat(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: MatmulVecMat())
|
||||
def Matmul_vecmat(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand(4, 5))
|
||||
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MatmulMatVec(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -88,18 +99,23 @@ class MatmulMatVec(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: MatmulMatVec())
|
||||
def Matmul_matvec(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5), tu.rand(5))
|
||||
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Matmul3D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
@ -108,18 +124,27 @@ class Matmul3D(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: Matmul3D())
|
||||
def Matmul_3d(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
|
||||
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Matmul4d(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
@ -128,10 +153,13 @@ class Matmul4d(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: Matmul4d())
|
||||
def Matmul_4d(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
||||
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Matmul4dStatic(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -149,9 +177,12 @@ class Matmul4dStatic(torch.nn.Module):
|
|||
def Matmul4dStatic_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MatmulStaticBroadcast(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -169,17 +200,22 @@ class MatmulStaticBroadcast(torch.nn.Module):
|
|||
def MatmulStaticBroadcast_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 1, 6, 7), tu.rand(8, 1, 5, 7, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MatmulSingleDynamicBatchDim(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([4, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([4, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([4, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
@ -188,18 +224,23 @@ class MatmulSingleDynamicBatchDim(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: MatmulSingleDynamicBatchDim())
|
||||
def MatmulSingleDynamicBatchDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
||||
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MatmulBroadcastBatchDim(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([4, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
@ -208,9 +249,11 @@ class MatmulBroadcastBatchDim(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: MatmulBroadcastBatchDim())
|
||||
def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
|
||||
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Mv(torch.nn.Module):
|
||||
|
||||
@export
|
||||
|
@ -225,4 +268,4 @@ class Mv(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Mv())
|
||||
def Mv_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 2), tu.rand(2))
|
||||
module.forward(tu.rand(2, 2), tu.rand(2))
|
||||
|
|
|
@ -14,13 +14,16 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# Multi-layer perceptron (MLP) models.
|
||||
|
||||
|
||||
class Mlp1LayerModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Reset seed to make model deterministic.
|
||||
torch.manual_seed(0)
|
||||
self.fc0 = nn.Linear(3, 5)
|
||||
self.tanh0 = nn.Tanh()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
|
@ -29,11 +32,14 @@ class Mlp1LayerModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return self.tanh0(self.fc0(x))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Mlp1LayerModule())
|
||||
def Mlp1LayerModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 3))
|
||||
|
||||
|
||||
class Mlp2LayerModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Reset seed to make model deterministic.
|
||||
|
@ -43,6 +49,7 @@ class Mlp2LayerModule(torch.nn.Module):
|
|||
self.tanh0 = nn.Tanh()
|
||||
self.fc1 = nn.Linear(N_HIDDEN, 2)
|
||||
self.tanh1 = nn.Tanh()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
|
@ -53,11 +60,14 @@ class Mlp2LayerModule(torch.nn.Module):
|
|||
x = self.tanh1(self.fc1(x))
|
||||
return x
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Mlp2LayerModule())
|
||||
def Mlp2LayerModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 3))
|
||||
|
||||
|
||||
class Mlp2LayerModuleNoBias(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Reset seed to make model deterministic.
|
||||
|
@ -67,6 +77,7 @@ class Mlp2LayerModuleNoBias(torch.nn.Module):
|
|||
self.tanh0 = nn.Tanh()
|
||||
self.fc1 = nn.Linear(N_HIDDEN, 2, bias=False)
|
||||
self.tanh1 = nn.Tanh()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
|
@ -77,25 +88,31 @@ class Mlp2LayerModuleNoBias(torch.nn.Module):
|
|||
x = self.tanh1(self.fc1(x))
|
||||
return x
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Mlp2LayerModuleNoBias())
|
||||
def Mlp2LayerModuleNoBias_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 3))
|
||||
|
||||
|
||||
class BatchMlpLayerModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Reset seed to make model deterministic.
|
||||
torch.manual_seed(0)
|
||||
self.fc0 = nn.Linear(3, 5)
|
||||
self.tanh0 = nn.Tanh()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.tanh0(self.fc0(x))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BatchMlpLayerModule())
|
||||
def BatchMlpLayerModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(7, 5, 3))
|
||||
|
|
|
@ -14,505 +14,510 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
class NllLossModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=2)[0]
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=2)[0]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule())
|
||||
def NllLossModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||
|
||||
|
||||
class NllLossModule_mean(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=1,
|
||||
ignore_index=2)[0]
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=1,
|
||||
ignore_index=2)[0]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_mean())
|
||||
def NllLossModule_mean_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||
|
||||
|
||||
class NllLossModule_sum(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=2,
|
||||
ignore_index=2)[0]
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=2,
|
||||
ignore_index=2)[0]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_sum())
|
||||
def NllLossModule_sum_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||
|
||||
|
||||
class NllLossModule_1D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=2)[0]
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=2)[0]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_1D())
|
||||
def NllLossModule_1D_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3), tu.randint(high=3))
|
||||
module.forward(tu.rand(3), tu.randint(high=3))
|
||||
|
||||
|
||||
class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
# None of the index is ignored here, since the ignored index is out of bounds.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=10)[0]
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
# None of the index is ignored here, since the ignored index is out of bounds.
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.nll_loss_forward(x,
|
||||
target=y,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=10)[0]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
|
||||
@register_test_case(
|
||||
module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
|
||||
def NllLossModule_ignore_index_out_of_bounds_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||
|
||||
|
||||
class NllLossModule_backward(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward())
|
||||
def NllLossModuleBackward_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backwardWeight(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backwardWeight())
|
||||
def NllLossModuleBackwardWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.rand(4), torch.tensor(3.))
|
||||
|
||||
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.rand(4), torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backward_ignore_index(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: NllLossModule_backward_ignore_index())
|
||||
def NllLossModuleBackward_ignore_index(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backwardMean(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=1,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=1,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backwardMean())
|
||||
def NllLossModuleBackwardMean_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backwardMeanWeight(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=1,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=1,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backwardMeanWeight())
|
||||
def NllLossModuleBackwardMeanWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.rand(4), torch.tensor(3.))
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.rand(4), torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backwardSum(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=2,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=2,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backwardSum())
|
||||
def NllLossModuleBackwardSum_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backwardSumWeight(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=2,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=2,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backwardSumWeight())
|
||||
def NllLossModuleBackwardSumWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.rand(4), torch.tensor(3.))
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.rand(4), torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backward1D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward1D())
|
||||
def NllLossModuleBackward1D_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backward1DWeight(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=0,
|
||||
ignore_index=10,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DWeight())
|
||||
def NllLossModuleBackward1DWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.rand(3), torch.tensor(3.))
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.rand(3), torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backward1DMean(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=1,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=1,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DMean())
|
||||
def NllLossModuleBackward1DMean_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backward1DMeanWeight(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=1,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=1,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DMeanWeight())
|
||||
@register_test_case(
|
||||
module_factory=lambda: NllLossModule_backward1DMeanWeight())
|
||||
def NllLossModuleBackward1DMeanWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.rand(3), torch.tensor(3.))
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.rand(3), torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backward1DSum(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=2,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=None,
|
||||
reduction=2,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DSum())
|
||||
def NllLossModuleBackward1DSum_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backward1DSumWeight(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=2,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
return torch.ops.aten.nll_loss_backward(grad_output,
|
||||
input,
|
||||
target=target,
|
||||
weight=weight,
|
||||
reduction=2,
|
||||
ignore_index=1,
|
||||
total_weight=total_weight)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_backward1DSumWeight())
|
||||
def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.rand(3), torch.tensor(3.))
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.rand(3), torch.tensor(3.))
|
||||
|
|
|
@ -11,7 +11,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BatchNorm1DModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bn1d = torch.nn.BatchNorm1d(4)
|
||||
|
@ -35,9 +37,12 @@ class BatchNorm1DModule(torch.nn.Module):
|
|||
def BatchNorm1DModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 4, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BatchNorm1DWith2DInputModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bn1d = torch.nn.BatchNorm1d(4)
|
||||
|
@ -61,9 +66,12 @@ class BatchNorm1DWith2DInputModule(torch.nn.Module):
|
|||
def BatchNorm1DWith2DInputModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BatchNorm2DModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bn2d = torch.nn.BatchNorm2d(2)
|
||||
|
@ -86,9 +94,12 @@ class BatchNorm2DModule(torch.nn.Module):
|
|||
def BatchNorm2DModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 2, 3, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BatchNorm3DModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bn3d = torch.nn.BatchNorm3d(5)
|
||||
|
@ -113,111 +124,156 @@ class BatchNorm3DModule(torch.nn.Module):
|
|||
def BatchNorm3DModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 3, 6, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeBatchNorm1DModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, weight, bias, running_mean, running_var):
|
||||
return torch.ops.aten.native_batch_norm(
|
||||
x, weight, bias, running_mean, running_var, training=False,
|
||||
momentum=0.1, eps=0.00001)
|
||||
return torch.ops.aten.native_batch_norm(x,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training=False,
|
||||
momentum=0.1,
|
||||
eps=0.00001)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeBatchNorm1DModule())
|
||||
def NativeBatchNorm1DModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(2, 5, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||
module.forward(tu.rand(2, 5, 3), tu.rand(5), tu.rand(5), tu.rand(5),
|
||||
tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeBatchNorm2DModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, weight, bias, running_mean, running_var):
|
||||
return torch.ops.aten.native_batch_norm(
|
||||
x, weight, bias, running_mean, running_var, training=False,
|
||||
momentum=0.1, eps=0.00001)
|
||||
return torch.ops.aten.native_batch_norm(x,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training=False,
|
||||
momentum=0.1,
|
||||
eps=0.00001)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeBatchNorm2DModule())
|
||||
def NativeBatchNorm2DModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(2, 5, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||
module.forward(tu.rand(2, 5, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5),
|
||||
tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeBatchNorm3DModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808, -9223372036854775808
|
||||
], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, weight, bias, running_mean, running_var):
|
||||
return torch.ops.aten.native_batch_norm(
|
||||
x, weight, bias, running_mean, running_var, training=False,
|
||||
momentum=0.1, eps=0.00001)
|
||||
return torch.ops.aten.native_batch_norm(x,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training=False,
|
||||
momentum=0.1,
|
||||
eps=0.00001)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeBatchNorm3DModule())
|
||||
def NativeBatchNorm3DModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5),
|
||||
tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeBatchNormNoneWeightModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808, -9223372036854775808
|
||||
], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, bias, running_mean, running_var):
|
||||
return torch.ops.aten.native_batch_norm(
|
||||
x, None, bias, running_mean, running_var, training=False,
|
||||
momentum=0.1, eps=0.00001)
|
||||
return torch.ops.aten.native_batch_norm(x,
|
||||
None,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training=False,
|
||||
momentum=0.1,
|
||||
eps=0.00001)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeBatchNormNoneWeightModule())
|
||||
def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(5), tu.rand(5), tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeLayerNormModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -230,38 +286,46 @@ class NativeLayerNormModule(torch.nn.Module):
|
|||
])
|
||||
def forward(self, x, weight, bias):
|
||||
list = [2, 2, 3]
|
||||
return torch.ops.aten.native_layer_norm(
|
||||
x, list, weight, bias, eps=0.5)
|
||||
return torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeLayerNormModule())
|
||||
def NativeLayerNormModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
||||
|
||||
|
||||
class NativeLayerNormDynamicModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808, -9223372036854775808
|
||||
], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, weight, bias):
|
||||
list = [2, 2, 3]
|
||||
return torch.ops.aten.native_layer_norm(
|
||||
x, list, weight, bias, eps=0.5)
|
||||
return torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeLayerNormDynamicModule())
|
||||
def NativeLayerNormDynamicModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NativeLayerNormModule4D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -274,17 +338,20 @@ class NativeLayerNormModule4D(torch.nn.Module):
|
|||
])
|
||||
def forward(self, x, weight, bias):
|
||||
list = [2, 2, 3]
|
||||
return torch.ops.aten.native_layer_norm(
|
||||
x, list, weight, bias, eps=0.5)[0]
|
||||
return torch.ops.aten.native_layer_norm(x, list, weight, bias,
|
||||
eps=0.5)[0]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NativeLayerNormModule4D())
|
||||
def NativeLayerNormModule4D_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class LayerNormModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ly = torch.nn.LayerNorm([2, 2, 3])
|
||||
|
@ -309,9 +376,12 @@ class LayerNormModule(torch.nn.Module):
|
|||
def LayerNormModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 2, 2, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class LayerNormLastDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ly = torch.nn.LayerNorm([3])
|
||||
|
@ -332,9 +402,12 @@ class LayerNormLastDimModule(torch.nn.Module):
|
|||
def LayerNormLastDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 5, 2, 2, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ly = torch.nn.LayerNorm([2, 2, 3])
|
||||
|
@ -355,6 +428,7 @@ class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
|
|||
return self.ly(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: LayerNormNormalizeOverAllDimsModule())
|
||||
@register_test_case(
|
||||
module_factory=lambda: LayerNormNormalizeOverAllDimsModule())
|
||||
def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 2, 3))
|
||||
|
|
|
@ -43,7 +43,10 @@ class AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.aap2d(x)
|
||||
|
@ -86,7 +89,10 @@ class AdaptiveAvgPool2dUnitOutputSizeDynamicModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.aap2d(x)
|
||||
|
@ -113,7 +119,10 @@ class MaxPool2dModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.mp2d(x)
|
||||
|
@ -160,7 +169,10 @@ class MaxPool2dCeilModeTrueModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.mp2d(x)
|
||||
|
@ -182,7 +194,10 @@ class MaxPool2dWithIndicesModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.max_pool2d_with_indices(x,
|
||||
|
@ -205,7 +220,10 @@ class MaxPool2dWithIndicesFullSizeKernelModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.max_pool2d_with_indices(x,
|
||||
|
@ -229,7 +247,10 @@ class MaxPool2dWithIndicesNonDefaultPaddingModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.max_pool2d_with_indices(x,
|
||||
|
@ -253,7 +274,10 @@ class MaxPool2dWithIndicesNonDefaultStrideModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.max_pool2d_with_indices(x,
|
||||
|
@ -277,7 +301,10 @@ class MaxPool2dWithIndicesNonDefaultDilationModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.max_pool2d_with_indices(x,
|
||||
|
@ -301,7 +328,10 @@ class MaxPool2dWithIndicesNonDefaultParamsModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.max_pool2d_with_indices(x,
|
||||
|
@ -325,7 +355,10 @@ class MaxPool2dWithIndicesAllNegativeValuesModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.max_pool2d_with_indices(x,
|
||||
|
@ -372,7 +405,10 @@ class MaxPool2dWithIndicesAllOnesModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.max_pool2d_with_indices(x,
|
||||
|
@ -395,7 +431,10 @@ class MaxPool2dWithIndicesCeilModeTrueModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.max_pool2d_with_indices(x,
|
||||
|
@ -483,9 +522,18 @@ class MaxPool2dWithIndicesBackwardDynamic4DModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.int64, True),
|
||||
])
|
||||
def forward(self, output, input, indices):
|
||||
kernel_size = [2, 2]
|
||||
|
@ -513,9 +561,12 @@ class MaxPool2dWithIndicesBackwardDynamic3DModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, output, input, indices):
|
||||
kernel_size = [2, 2]
|
||||
|
@ -552,15 +603,20 @@ class AvgPool2dFloatModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.ap2d(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AvgPool2dFloatModule())
|
||||
def AvgPool2dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 20, 20, low=-1))
|
||||
|
||||
|
||||
class AvgPool2dIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -575,7 +631,10 @@ class AvgPool2dIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.ap2d(x)
|
||||
|
@ -650,11 +709,15 @@ class AvgPool2dCeilModeTrueModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.ap2d(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AvgPool2dCeilModeTrueModule())
|
||||
def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))
|
||||
|
|
|
@ -14,6 +14,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class QuantizedMLP(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
torch.random.manual_seed(0)
|
||||
|
|
|
@ -11,14 +11,17 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a)
|
||||
|
@ -28,16 +31,20 @@ class ReduceSumFloatModule(torch.nn.Module):
|
|||
def ReduceSumFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumDtypeFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a, dtype=torch.float32)
|
||||
|
@ -47,16 +54,20 @@ class ReduceSumDtypeFloatModule(torch.nn.Module):
|
|||
def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumElementTypeBoolModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.bool, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.bool, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a)
|
||||
|
@ -66,16 +77,20 @@ class ReduceSumElementTypeBoolModule(torch.nn.Module):
|
|||
def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumDimIntListFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a, (0, 1))
|
||||
|
@ -85,47 +100,60 @@ class ReduceSumDimIntListFloatModule(torch.nn.Module):
|
|||
def ReduceSumDimIntListFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumDimIntListDtypeFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a, (0, 1), dtype=torch.float32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceSumDimIntListDtypeFloatModule())
|
||||
@register_test_case(
|
||||
module_factory=lambda: ReduceSumDimIntListDtypeFloatModule())
|
||||
def ReduceSumDimIntListDtypeFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumDimIntListKeepDimFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a, (1, 2), keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimFloatModule())
|
||||
@register_test_case(
|
||||
module_factory=lambda: ReduceSumDimIntListKeepDimFloatModule())
|
||||
def ReduceSumDimIntListKeepDimFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumDimIntListKeepDimNegativeDimStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -138,20 +166,26 @@ class ReduceSumDimIntListKeepDimNegativeDimStaticModule(torch.nn.Module):
|
|||
return torch.sum(a, dim=(-1), keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimNegativeDimStaticModule())
|
||||
def ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic(module, tu: TestUtils):
|
||||
@register_test_case(
|
||||
module_factory=lambda: ReduceSumDimIntListKeepDimNegativeDimStaticModule())
|
||||
def ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic(
|
||||
module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 12, 7, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumDimIntListEmptyDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a, dim=[])
|
||||
|
@ -161,9 +195,12 @@ class ReduceSumDimIntListEmptyDimModule(torch.nn.Module):
|
|||
def ReduceSumDimIntListEmptyDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumDimIntListElementTypeBoolModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -176,20 +213,25 @@ class ReduceSumDimIntListElementTypeBoolModule(torch.nn.Module):
|
|||
return torch.sum(a, dim=(-1), keepdim=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceSumDimIntListElementTypeBoolModule())
|
||||
@register_test_case(
|
||||
module_factory=lambda: ReduceSumDimIntListElementTypeBoolModule())
|
||||
def ReduceSumDimIntListElementTypeBoolModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(1, 128, high=2).to(dtype=torch.bool))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumUnsignedIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a)
|
||||
|
@ -199,16 +241,20 @@ class ReduceSumUnsignedIntModule(torch.nn.Module):
|
|||
def ReduceSumUnsignedIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, 5, low=0, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumSignedIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a)
|
||||
|
@ -218,16 +264,20 @@ class ReduceSumSignedIntModule(torch.nn.Module):
|
|||
def ReduceSumSignedIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, 5, low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumDtypeIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a, dtype=torch.int64)
|
||||
|
@ -237,16 +287,20 @@ class ReduceSumDtypeIntModule(torch.nn.Module):
|
|||
def ReduceSumDtypeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumDimIntListIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a, (0, 1))
|
||||
|
@ -256,16 +310,20 @@ class ReduceSumDimIntListIntModule(torch.nn.Module):
|
|||
def ReduceSumDimIntListIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, 5, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumDimIntListDtypeIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a, (0, 1), dtype=torch.int64)
|
||||
|
@ -275,35 +333,44 @@ class ReduceSumDimIntListDtypeIntModule(torch.nn.Module):
|
|||
def ReduceSumDimIntListDtypeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceSumDimIntListKeepDimIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a, (1, 2), keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimIntModule())
|
||||
@register_test_case(
|
||||
module_factory=lambda: ReduceSumDimIntListKeepDimIntModule())
|
||||
def ReduceSumDimIntListKeepDimIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, 5, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceMaxAlongDim(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a, 1)[0]
|
||||
|
@ -313,16 +380,20 @@ class ReduceMaxAlongDim(torch.nn.Module):
|
|||
def ReduceMaxAlongDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceMaxAlongDimNegative(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a, 1)[0]
|
||||
|
@ -332,16 +403,20 @@ class ReduceMaxAlongDimNegative(torch.nn.Module):
|
|||
def ReduceMaxAlongDimNegative_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5, low=-10, high=10).to(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceMaxKeepDim(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a, 1, keepdim=True)[1]
|
||||
|
@ -351,311 +426,385 @@ class ReduceMaxKeepDim(torch.nn.Module):
|
|||
def ReduceMaxKeepDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceMaxKeepDimReturnBoth(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a, 1, keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceMaxKeepDimReturnBoth())
|
||||
def ReduceMaxKeepDimReturnBoth_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5, low=-10, high=-5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceMaxAllDims(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a)
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceMaxAllDims())
|
||||
def ReduceMaxAllDims_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5, low=-10, high=-5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceMaxNegativeDim(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a, -1, keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceMaxNegativeDim())
|
||||
def ReduceMaxNegativeDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceMaxFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceMaxFloatModule())
|
||||
def ReduceMaxFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceMaxSignedIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceMaxSignedIntModule())
|
||||
def ReduceMaxSignedIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, 5, low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceMaxUnsignedIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.max(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceMaxUnsignedIntModule())
|
||||
def ReduceMaxUnsignedIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, 5, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceL1NormModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.linalg.vector_norm(a, dim=0, ord=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceL1NormModule())
|
||||
def ReduceL1NormModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceL1NormWithDTypeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.linalg.vector_norm(a, dim=0, ord=1, dtype=torch.float64)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceL1NormWithDTypeModule())
|
||||
def ReduceL1NormWithDTypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceL2NormModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.linalg.vector_norm(a, dim=0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceL2NormModule())
|
||||
def ReduceL2NormModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceLN3NormModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.linalg.vector_norm(a, dim=0, ord=-3)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceLN3NormModule())
|
||||
def ReduceLN3NormModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceL3NormAllDimsModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.linalg.vector_norm(a, dim=None, ord=3)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceL3NormAllDimsModule())
|
||||
def ReduceL3NormAllDimsModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReduceL3NormKeepDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.linalg.vector_norm(a, keepdim=True, ord=3)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceL3NormKeepDimModule())
|
||||
def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
class ReduceFrobeniusNormModule(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.frobenius_norm(a, dim=[0, 1], keepdim=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceFrobeniusNormModule())
|
||||
def ReduceFrobeniusNormModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
class ReduceFrobeniusNormKeepDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.frobenius_norm(a, dim=[0, 1], keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceFrobeniusNormKeepDimModule())
|
||||
def ReduceFrobeniusNormKeepDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MseLossNoReductionModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808 , -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808 , -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.mse_loss(x, y, reduction=0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MseLossNoReductionModule())
|
||||
def MseLossNoReductionModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4), tu.rand(2, 4))
|
||||
|
||||
|
||||
class MseLossMeanReductionModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808 , -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808 , -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.mse_loss(x, y, reduction=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MseLossMeanReductionModule())
|
||||
def MseLossMeanReductionModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4), tu.rand(2, 4))
|
||||
|
||||
|
||||
class MseLossSumReductionWithDifferentElemTypeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808 , -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808 , -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
])
|
||||
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.mse_loss(x, y, reduction=2)
|
||||
|
||||
@register_test_case(module_factory=lambda: MseLossSumReductionWithDifferentElemTypeModule())
|
||||
def MseLossSumReductionWithDifferentElemTypeModule_basic(module, tu: TestUtils):
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: MseLossSumReductionWithDifferentElemTypeModule())
|
||||
def MseLossSumReductionWithDifferentElemTypeModule_basic(
|
||||
module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4), tu.rand(2, 4).to(torch.float64))
|
||||
|
|
|
@ -10,7 +10,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewExpandModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -19,17 +21,20 @@ class ViewExpandModule(torch.nn.Module):
|
|||
None,
|
||||
([6, 4], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, 3, 4)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandModule())
|
||||
def ViewExpandModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewExpandOnesModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -38,17 +43,20 @@ class ViewExpandOnesModule(torch.nn.Module):
|
|||
None,
|
||||
([1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(1, 1, 1, 1, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandOnesModule())
|
||||
def ViewExpandOnesModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewExpandOnesBeforeAndAfterModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -57,17 +65,21 @@ class ViewExpandOnesBeforeAndAfterModule(torch.nn.Module):
|
|||
None,
|
||||
([2, 1, 16, 1, 1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(1, 2, 1, 16, 1, 1, 1, 1)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandOnesBeforeAndAfterModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ViewExpandOnesBeforeAndAfterModule())
|
||||
def ViewExpandOnesBeforeAndAfterModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 1, 16, 1, 1))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewExpandOnesMiddleModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -76,17 +88,19 @@ class ViewExpandOnesMiddleModule(torch.nn.Module):
|
|||
None,
|
||||
([3, 1, 2], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(3, 1, 1, 1, 1, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandOnesMiddleModule())
|
||||
def ViewExpandOnesMiddleModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 2))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewCollapseOnesMiddleModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -95,55 +109,67 @@ class ViewCollapseOnesMiddleModule(torch.nn.Module):
|
|||
None,
|
||||
([3, 1, 1, 1, 1, 2], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(3, 1, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewCollapseOnesMiddleModule())
|
||||
def ViewCollapseOnesMiddleModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 1, 1, 1, 2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewDynamicExpandModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, 30, 384], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, 30,
|
||||
384], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, 4, 5, 6, 12, 32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewDynamicExpandModule())
|
||||
def ViewDynamicExpandModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 30, 384))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(a.size(0), a.size(1), 12, 32)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewDynamicExpandWithAtenSizeIntModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ViewDynamicExpandWithAtenSizeIntModule())
|
||||
def ViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 384))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewCollapseModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -152,38 +178,49 @@ class ViewCollapseModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(8)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewCollapseModule())
|
||||
def ViewCollapseModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808
|
||||
], torch.float32, True),
|
||||
([], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, a, b, c):
|
||||
return a.view(a.size(0), int(b), int(c), a.size(3), 384)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewCollapseDynamicWithAtenSizeIntModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ViewCollapseDynamicWithAtenSizeIntModule())
|
||||
def ViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5))
|
||||
module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3),
|
||||
torch.tensor(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewExpandCollapseWithOnesModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -192,17 +229,20 @@ class ViewExpandCollapseWithOnesModule(torch.nn.Module):
|
|||
None,
|
||||
([2, 4, 8, 8], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, 1, 1, 4, 64)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandCollapseWithOnesModule())
|
||||
def ViewExpandCollapseWithOnesModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 8, 8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewExpandCollapseModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -211,55 +251,69 @@ class ViewExpandCollapseModule(torch.nn.Module):
|
|||
None,
|
||||
([2, 4, 8, 16, 4], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(8, 2, 4, 16, 2, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandCollapseModule())
|
||||
def ViewExpandCollapseModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 8, 16, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewDynamicExpandCollapseModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, 4, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, 4, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, 1, 4, 64)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewDynamicExpandCollapseModule())
|
||||
def ViewDynamicExpandCollapseModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 8, 8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewDynamicExpandCollapseWithAtenIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, 1, a.size(1), 64)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewDynamicExpandCollapseWithAtenIntModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ViewDynamicExpandCollapseWithAtenIntModule())
|
||||
def ViewDynamicExpandCollapseWithAtenIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 8, 8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewTwoToThreeStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -268,17 +322,20 @@ class ViewTwoToThreeStaticModule(torch.nn.Module):
|
|||
None,
|
||||
([3, 2], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, 3)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewTwoToThreeStaticModule())
|
||||
def ViewTwoToThreeStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewTwoFiveThreeStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -287,17 +344,20 @@ class ViewTwoFiveThreeStaticModule(torch.nn.Module):
|
|||
None,
|
||||
([3, 5, 2], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, 5, 3)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewTwoFiveThreeStaticModule())
|
||||
def ViewTwoFiveThreeStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5, 2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewOffsetTestStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -306,17 +366,20 @@ class ViewOffsetTestStaticModule(torch.nn.Module):
|
|||
None,
|
||||
([2, 3, 2, 2, 5, 6], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, 3, 4, 6, 5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewOffsetTestStaticModule())
|
||||
def ViewOffsetTestStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 2, 2, 5, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewOffsetBackwardTestStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -325,17 +388,21 @@ class ViewOffsetBackwardTestStaticModule(torch.nn.Module):
|
|||
None,
|
||||
([2, 3, 4, 5, 6], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(2, 3, 2, 2, 6, 5)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewOffsetBackwardTestStaticModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ViewOffsetBackwardTestStaticModule())
|
||||
def ViewOffsetBackwardTestStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4, 5, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class View1DFoldModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -344,17 +411,20 @@ class View1DFoldModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(-9223372036854775808)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: View1DFoldModule())
|
||||
def View1DFoldModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewCollapseInferredDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -363,17 +433,20 @@ class ViewCollapseInferredDimModule(torch.nn.Module):
|
|||
None,
|
||||
([2, 3, 4], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(-9223372036854775808, 4)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewCollapseInferredDimModule())
|
||||
def ViewCollapseInferredDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewExpandInferredDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -382,17 +455,20 @@ class ViewExpandInferredDimModule(torch.nn.Module):
|
|||
None,
|
||||
([2, 6], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(3, -9223372036854775808, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandInferredDimModule())
|
||||
def ViewExpandInferredDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewExpandDynamicDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -401,17 +477,20 @@ class ViewExpandDynamicDimModule(torch.nn.Module):
|
|||
None,
|
||||
([1, -9223372036854775808, 128], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(16, 1, 128)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandDynamicDimModule())
|
||||
def ViewExpandDynamicDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 16, 128))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewFlattenAndExpandModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -420,17 +499,20 @@ class ViewFlattenAndExpandModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(a.size(0), a.size(1))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewFlattenAndExpandModule())
|
||||
def ViewFlattenAndExpandModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(64,128))
|
||||
module.forward(tu.rand(64, 128))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class UnsafeViewExpandModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -439,55 +521,67 @@ class UnsafeViewExpandModule(torch.nn.Module):
|
|||
None,
|
||||
([6, 4], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.ops.aten._unsafe_view(a, [2, 3, 4])
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: UnsafeViewExpandModule())
|
||||
def UnsafeViewExpandModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class UnsafeViewDynamicExpandModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, 30, 384], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, 30,
|
||||
384], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.ops.aten._unsafe_view(a,[2, 4, 5, 6, 12, 32])
|
||||
return torch.ops.aten._unsafe_view(a, [2, 4, 5, 6, 12, 32])
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: UnsafeViewDynamicExpandModule())
|
||||
def UnsafeViewDynamicExpandModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 30, 384))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class UnsafeViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.ops.aten._unsafe_view(a, [a.size(0), a.size(1), 12, 32])
|
||||
|
||||
@register_test_case(module_factory=lambda: UnsafeViewDynamicExpandWithAtenSizeIntModule())
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: UnsafeViewDynamicExpandWithAtenSizeIntModule())
|
||||
def UnsafeViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 384))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class UnsafeViewCollapseModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -496,38 +590,52 @@ class UnsafeViewCollapseModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.ops.aten._unsafe_view(a,[8])
|
||||
return torch.ops.aten._unsafe_view(a, [8])
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: UnsafeViewCollapseModule())
|
||||
def UnsafeViewCollapseModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class UnsafeViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808
|
||||
], torch.float32, True),
|
||||
([], torch.int64, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, a, b, c):
|
||||
return torch.ops.aten._unsafe_view(a, [a.size(0), int(b), int(c), a.size(3), 384])
|
||||
return torch.ops.aten._unsafe_view(
|
||||
a, [a.size(0), int(b), int(c),
|
||||
a.size(3), 384])
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: UnsafeViewCollapseDynamicWithAtenSizeIntModule())
|
||||
def UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic(
|
||||
module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3),
|
||||
torch.tensor(5))
|
||||
|
||||
@register_test_case(module_factory=lambda: UnsafeViewCollapseDynamicWithAtenSizeIntModule())
|
||||
def UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class UnsafeView1DFoldModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -536,17 +644,20 @@ class UnsafeView1DFoldModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.ops.aten._unsafe_view(a, [-9223372036854775808])
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: UnsafeView1DFoldModule())
|
||||
def UnsafeView1DFoldModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReshapeExpandModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -555,17 +666,20 @@ class ReshapeExpandModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.reshape(12, 32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReshapeExpandModule())
|
||||
def ReshapeExpandModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(384))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReshapeCollapseModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -574,17 +688,20 @@ class ReshapeCollapseModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.reshape(a, (-1,))
|
||||
return torch.reshape(a, (-1, ))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReshapeCollapseModule())
|
||||
def ReshapeCollapseModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ViewNoChange1dModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -593,16 +710,17 @@ class ViewNoChange1dModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(6)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewNoChange1dModule())
|
||||
def ViewNoChange1dModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6))
|
||||
|
||||
|
||||
class ViewNoChange2dModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -611,34 +729,37 @@ class ViewNoChange2dModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(5, 6)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewNoChange2dModule())
|
||||
def ViewNoChange2dModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 6))
|
||||
|
||||
|
||||
class ViewNoChange3dModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(4, 5, 6)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewNoChange3dModule())
|
||||
def ViewNoChange3dModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6))
|
||||
|
||||
|
||||
class ViewNoChangeStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -647,17 +768,20 @@ class ViewNoChangeStaticModule(torch.nn.Module):
|
|||
None,
|
||||
([4, 5, 6], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(4, 5, 6)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewNoChangeStaticModule())
|
||||
def ViewNoChangeStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReshapeAliasExpandModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -668,17 +792,20 @@ class ReshapeAliasExpandModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.ops.aten._reshape_alias(a, size=(12, 32), stride=(32, 1))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReshapeAliasExpandModule())
|
||||
def ReshapeAliasExpandModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(384))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ReshapeAliasCollapseModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -687,10 +814,10 @@ class ReshapeAliasCollapseModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.ops.aten._reshape_alias(a, (8,), (1,))
|
||||
return torch.ops.aten._reshape_alias(a, (8, ), (1, ))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
|
||||
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4))
|
||||
module.forward(tu.rand(2, 4))
|
||||
|
|
|
@ -6,6 +6,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class UniformModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -14,9 +15,12 @@ class UniformModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
a = torch.ops.aten.uniform_(x, 1.0, 10.0)
|
||||
|
@ -42,8 +46,10 @@ def UniformModule_basic(module, tu: TestUtils):
|
|||
tu.rand(512, 1024, 4).double(),
|
||||
tu.rand(512, 256, 4).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class UniformStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -80,18 +86,24 @@ def UniformStaticModule_basic(module, tu: TestUtils):
|
|||
tu.rand(512, 1024, 4).double(),
|
||||
tu.rand(512, 256, 4).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BernoulliModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
a = torch.bernoulli(x)
|
||||
|
@ -107,7 +119,7 @@ class BernoulliModule(torch.nn.Module):
|
|||
torch.flatten(torch.std(b)),
|
||||
torch.flatten(torch.std(c))
|
||||
])
|
||||
return mean, std
|
||||
return mean, std
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BernoulliModule())
|
||||
|
@ -117,9 +129,12 @@ def BernoulliModule_basic(module, tu: TestUtils):
|
|||
tu.rand(1024, 2048, 4).double(),
|
||||
tu.rand(1024, 256, 4).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BernoulliZerosModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -136,9 +151,12 @@ class BernoulliZerosModule(torch.nn.Module):
|
|||
def BernoulliZerosModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.zeros(4, 8).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BernoulliOnesModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -155,18 +173,24 @@ class BernoulliOnesModule(torch.nn.Module):
|
|||
def BernoulliOnesModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.ones(4, 8).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BernoulliFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
a = torch.ops.aten.bernoulli_(x, 0.4)
|
||||
|
@ -182,7 +206,7 @@ class BernoulliFloatModule(torch.nn.Module):
|
|||
torch.flatten(torch.std(b)),
|
||||
torch.flatten(torch.std(c))
|
||||
])
|
||||
return mean, std
|
||||
return mean, std
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BernoulliFloatModule())
|
||||
|
@ -192,21 +216,30 @@ def BernoulliFloatModule_basic(module, tu: TestUtils):
|
|||
tu.rand(1024, 2048, 4).double(),
|
||||
tu.rand(1024, 512, 4).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class BernoulliTensorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, x, px, y, py, z, pz):
|
||||
a = torch.ops.aten.bernoulli_(x, px)
|
||||
|
@ -222,7 +255,7 @@ class BernoulliTensorModule(torch.nn.Module):
|
|||
torch.flatten(torch.std(b)),
|
||||
torch.flatten(torch.std(c))
|
||||
])
|
||||
return mean, std
|
||||
return mean, std
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BernoulliTensorModule())
|
||||
|
@ -235,9 +268,12 @@ def BernoulliTensorModule_basic(module, tu: TestUtils):
|
|||
tu.rand(1024, 512, 8).double(),
|
||||
tu.rand(1024, 512, 8).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RandLikeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -256,9 +292,12 @@ class RandLikeModule(torch.nn.Module):
|
|||
def RandLikeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1024, 1024).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RandLikeDtypeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -277,9 +316,12 @@ class RandLikeDtypeModule(torch.nn.Module):
|
|||
def RandLikeDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1024, 1024).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RandIntLowModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -297,9 +339,12 @@ class RandIntLowModule(torch.nn.Module):
|
|||
def RandIntLowModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RandIntLowDtypeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -308,7 +353,10 @@ class RandIntLowDtypeModule(torch.nn.Module):
|
|||
None,
|
||||
])
|
||||
def forward(self):
|
||||
a = torch.ops.aten.randint(low=1, high=1000, size=[128, 256, 512], dtype=torch.float64)
|
||||
a = torch.ops.aten.randint(low=1,
|
||||
high=1000,
|
||||
size=[128, 256, 512],
|
||||
dtype=torch.float64)
|
||||
mean = torch.mean(a)
|
||||
return mean
|
||||
|
||||
|
|
|
@ -29,7 +29,8 @@ class AddIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: AddIntModule())
|
||||
def AddIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
module.forward(tu.randint(low=-100, high=100),
|
||||
tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -52,7 +53,8 @@ class SubIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SubIntModule())
|
||||
def SubIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
module.forward(tu.randint(low=-100, high=100),
|
||||
tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -98,7 +100,8 @@ class MulIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: MulIntModule())
|
||||
def MulIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
module.forward(tu.randint(low=-100, high=100),
|
||||
tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -337,9 +340,12 @@ class BoolIntConstantModule(torch.nn.Module):
|
|||
def BoolIntConstantModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenIntTensorByteDtypeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -348,10 +354,10 @@ class AtenIntTensorByteDtypeModule(torch.nn.Module):
|
|||
None,
|
||||
([], torch.uint8, True),
|
||||
])
|
||||
|
||||
def forward(self, val):
|
||||
return int(val)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenIntTensorByteDtypeModule())
|
||||
def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.uint8))
|
||||
|
@ -359,7 +365,9 @@ def AtenIntTensorByteDtypeModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenIntTensorCharDtypeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -368,10 +376,10 @@ class AtenIntTensorCharDtypeModule(torch.nn.Module):
|
|||
None,
|
||||
([], torch.int8, True),
|
||||
])
|
||||
|
||||
def forward(self, val):
|
||||
return int(val)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule())
|
||||
def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8))
|
||||
|
|
|
@ -29,7 +29,8 @@ class NeIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NeIntModule())
|
||||
def NeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
module.forward(tu.randint(low=-100, high=100),
|
||||
tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -52,7 +53,8 @@ class EqIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: EqIntModule())
|
||||
def EqIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
module.forward(tu.randint(low=-100, high=100),
|
||||
tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -75,7 +77,8 @@ class GtIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: GtIntModule())
|
||||
def GtIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
module.forward(tu.randint(low=-100, high=100),
|
||||
tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -98,7 +101,8 @@ class GeIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: GeIntModule())
|
||||
def GeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
module.forward(tu.randint(low=-100, high=100),
|
||||
tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
|
|
@ -11,14 +11,17 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[0:5:1, 1:3:1, 2:4:1]
|
||||
|
@ -26,12 +29,14 @@ class SliceModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SliceModule())
|
||||
def SliceModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,7))
|
||||
module.forward(tu.rand(6, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -46,125 +51,148 @@ class SliceStaticModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SliceStaticModule())
|
||||
def SliceStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,7))
|
||||
module.forward(tu.rand(6, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
|
||||
result = x[:8, :5, 8:]
|
||||
cat_tensor = torch.ones((6,4,1), dtype=torch.float32)
|
||||
return torch.cat((result,cat_tensor), dim=2)
|
||||
result = x[:8, :5, 8:]
|
||||
cat_tensor = torch.ones((6, 4, 1), dtype=torch.float32)
|
||||
return torch.cat((result, cat_tensor), dim=2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexModule())
|
||||
def SliceOutOfUpperBoundIndexModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,7))
|
||||
module.forward(tu.rand(6, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceOutOfLowerBoundEndIndexModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[:-8,-7:,:]
|
||||
return x[:-8, -7:, :]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceOutOfLowerBoundEndIndexModule())
|
||||
@register_test_case(
|
||||
module_factory=lambda: SliceOutOfLowerBoundEndIndexModule())
|
||||
def SliceOutOfLowerBoundEndIndexModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,7))
|
||||
module.forward(tu.rand(6, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceOutOfLowerBoundStartIndexModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[-8:3:1, 1:3:1, 2:4:1]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceOutOfLowerBoundStartIndexModule())
|
||||
@register_test_case(
|
||||
module_factory=lambda: SliceOutOfLowerBoundStartIndexModule())
|
||||
def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,7))
|
||||
module.forward(tu.rand(6, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceEndSleStartModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
|
||||
result = x[:, 4:3, :]
|
||||
cat_tensor = torch.ones((6,1,7), dtype=torch.float32)
|
||||
cat_tensor = torch.ones((6, 1, 7), dtype=torch.float32)
|
||||
return torch.cat((result, cat_tensor), dim=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceEndSleStartModule())
|
||||
def SliceEndSleStartModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,7))
|
||||
module.forward(tu.rand(6, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceStartEqEndModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
|
||||
result = x[5:5, :, :]
|
||||
cat_tensor = torch.ones((1,4,7), dtype=torch.float32)
|
||||
cat_tensor = torch.ones((1, 4, 7), dtype=torch.float32)
|
||||
return torch.cat((result, cat_tensor), dim=0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceStartEqEndModule())
|
||||
def SliceStartEqEndModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,7))
|
||||
module.forward(tu.rand(6, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceSizeTwoStepModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[0:5:2, 0:3:2, 0:4:2]
|
||||
|
@ -172,11 +200,14 @@ class SliceSizeTwoStepModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SliceSizeTwoStepModule())
|
||||
def SliceSizeTwoStepModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10,5,17))
|
||||
module.forward(tu.rand(10, 5, 17))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceNegIdxModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -191,11 +222,14 @@ class SliceNegIdxModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SliceNegIdxModule())
|
||||
def SliceNegIdxModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3,9))
|
||||
module.forward(tu.rand(3, 9))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceSingleIdxModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -210,11 +244,14 @@ class SliceSingleIdxModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SliceSingleIdxModule())
|
||||
def SliceSingleIdxModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,8))
|
||||
module.forward(tu.rand(6, 8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SliceWholeTensorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -229,11 +266,14 @@ class SliceWholeTensorModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SliceWholeTensorModule())
|
||||
def SliceWholeTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,8))
|
||||
module.forward(tu.rand(6, 8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SelectIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -243,18 +283,21 @@ class SelectIntModule(torch.nn.Module):
|
|||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.select(0,0)
|
||||
return x.select(0, 0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SelectIntModule())
|
||||
def SelectIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(5,5, high=10))
|
||||
module.forward(tu.randint(5, 5, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
# For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
|
||||
# For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
|
||||
class SliceScatterModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -265,13 +308,21 @@ class SliceScatterModule(torch.nn.Module):
|
|||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, src):
|
||||
return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 1)
|
||||
return torch.ops.aten.slice_scatter(x,
|
||||
src,
|
||||
dim=1,
|
||||
start=0,
|
||||
end=1,
|
||||
step=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceScatterModule())
|
||||
def SliceScatterModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 8), tu.rand(6, 1))
|
||||
|
||||
|
||||
class SliceScatterZeroDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -282,7 +333,12 @@ class SliceScatterZeroDimModule(torch.nn.Module):
|
|||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, src):
|
||||
return torch.ops.aten.slice_scatter(x, src, dim = 0, start = 0, end = 1, step = 1)
|
||||
return torch.ops.aten.slice_scatter(x,
|
||||
src,
|
||||
dim=0,
|
||||
start=0,
|
||||
end=1,
|
||||
step=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceScatterZeroDimModule())
|
||||
|
@ -314,7 +370,9 @@ class SliceScatterNegativeDimModule(torch.nn.Module):
|
|||
def SliceScatterNegativeDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 8), tu.rand(1, 8))
|
||||
|
||||
|
||||
class SliceScatterStepVariationModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -325,14 +383,21 @@ class SliceScatterStepVariationModule(torch.nn.Module):
|
|||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, src):
|
||||
return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 2)
|
||||
return torch.ops.aten.slice_scatter(x,
|
||||
src,
|
||||
dim=1,
|
||||
start=0,
|
||||
end=1,
|
||||
step=2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceScatterStepVariationModule())
|
||||
def SliceScatterStepVariationModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 8), tu.rand(6, 1))
|
||||
|
||||
|
||||
class SliceScatterStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -343,32 +408,42 @@ class SliceScatterStaticModule(torch.nn.Module):
|
|||
([6, 1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, src):
|
||||
return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 1)
|
||||
return torch.ops.aten.slice_scatter(x,
|
||||
src,
|
||||
dim=1,
|
||||
start=0,
|
||||
end=1,
|
||||
step=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceScatterStaticModule())
|
||||
def SliceScatterStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 8), tu.rand(6, 1))
|
||||
|
||||
|
||||
class SelectScatterModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, src):
|
||||
return torch.ops.aten.select_scatter(x, src, dim = 0, index = 0)
|
||||
return torch.ops.aten.select_scatter(x, src, dim=0, index=0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SelectScatterModule())
|
||||
def SelectScattertModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(6, 8, 5), torch.rand(8, 5))
|
||||
|
||||
|
||||
class SelectScatterStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -379,43 +454,50 @@ class SelectScatterStaticModule(torch.nn.Module):
|
|||
([6, 5], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, src):
|
||||
return torch.ops.aten.select_scatter(x, src, dim = 1, index = 0)
|
||||
return torch.ops.aten.select_scatter(x, src, dim=1, index=0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SelectScatterStaticModule())
|
||||
def SelectScattertStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(6, 8, 5), torch.rand(6, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NarrowHorizontalTest(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.narrow(x, dim=0, start=0, length=2)
|
||||
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NarrowHorizontalTest())
|
||||
def NarrowHorizontalTest_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,3))
|
||||
module.forward(tu.rand(6, 4, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NarrowVerticalTest(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.narrow(x, dim=1, start=0, length=2)
|
||||
|
@ -423,11 +505,14 @@ class NarrowVerticalTest(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NarrowVerticalTest())
|
||||
def NarrowVerticalTest_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,3))
|
||||
module.forward(tu.rand(6, 4, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NarrowHorizontalTest2(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -438,16 +523,18 @@ class NarrowHorizontalTest2(torch.nn.Module):
|
|||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.narrow(x, dim=0, start=0, length=2)
|
||||
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NarrowHorizontalTest2())
|
||||
def NarrowHorizontalTest2_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4))
|
||||
module.forward(tu.rand(6, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NarrowVerticalTest2(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -462,4 +549,4 @@ class NarrowVerticalTest2(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NarrowVerticalTest2())
|
||||
def NarrowVerticalTest2_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4))
|
||||
module.forward(tu.rand(6, 4))
|
||||
|
|
|
@ -13,6 +13,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class SqueezeStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -25,8 +26,7 @@ class SqueezeStaticModule(torch.nn.Module):
|
|||
return torch.squeeze(a)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeStaticModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeStaticModule())
|
||||
def SqueezeModule_static(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 7, 1, 3, 1))
|
||||
|
||||
|
@ -35,6 +35,7 @@ def SqueezeModule_static(module, tu: TestUtils):
|
|||
|
||||
|
||||
class SqueezeAllUnitDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -47,8 +48,7 @@ class SqueezeAllUnitDimModule(torch.nn.Module):
|
|||
return torch.squeeze(a)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeAllUnitDimModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeAllUnitDimModule())
|
||||
def SqueezeModule_allUnitDim(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1))
|
||||
|
||||
|
@ -57,6 +57,7 @@ def SqueezeModule_allUnitDim(module, tu: TestUtils):
|
|||
|
||||
|
||||
class SqueezeBroadcastModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -70,8 +71,7 @@ class SqueezeBroadcastModule(torch.nn.Module):
|
|||
return a * b.squeeze()
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeBroadcastModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeBroadcastModule())
|
||||
def SqueezeModule_broadcast(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 3), tu.rand())
|
||||
|
||||
|
@ -80,6 +80,7 @@ def SqueezeModule_broadcast(module, tu: TestUtils):
|
|||
|
||||
|
||||
class SqueezeDimStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -92,30 +93,30 @@ class SqueezeDimStaticModule(torch.nn.Module):
|
|||
return torch.squeeze(a, 0)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimStaticModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeDimStaticModule())
|
||||
def SqueezeDimModule_static(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 7))
|
||||
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SqueezeDimDynamicModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, 1, 384, -9223372036854775808, 1], torch.float32, True),
|
||||
([-9223372036854775808, 1, 384, -9223372036854775808,
|
||||
1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a, 4)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimDynamicModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeDimDynamicModule())
|
||||
def SqueezeDimModule_dynamic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(8, 1, 384, 12, 1))
|
||||
|
||||
|
@ -124,20 +125,21 @@ def SqueezeDimModule_dynamic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class SqueezeDimNegDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, -9223372036854775808, 1, 384, -9223372036854775808, 1], torch.float32, True),
|
||||
([1, -9223372036854775808, 1, 384, -9223372036854775808,
|
||||
1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.squeeze(a, -6)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimNegDimModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeDimNegDimModule())
|
||||
def SqueezeDimModule_negDim(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 8, 1, 384, 12, 1))
|
||||
|
||||
|
@ -146,6 +148,7 @@ def SqueezeDimModule_negDim(module, tu: TestUtils):
|
|||
|
||||
|
||||
class SqueezeDimIdentityModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -158,8 +161,7 @@ class SqueezeDimIdentityModule(torch.nn.Module):
|
|||
return torch.squeeze(a, 0)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimIdentityModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeDimIdentityModule())
|
||||
def SqueezeDimModule_identity(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 1, 3))
|
||||
|
||||
|
@ -168,6 +170,7 @@ def SqueezeDimModule_identity(module, tu: TestUtils):
|
|||
|
||||
|
||||
class SqueezeDimUnitDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -180,7 +183,6 @@ class SqueezeDimUnitDimModule(torch.nn.Module):
|
|||
return torch.squeeze(a, 0)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: SqueezeDimUnitDimModule())
|
||||
@register_test_case(module_factory=lambda: SqueezeDimUnitDimModule())
|
||||
def SqueezeDimModule_unitDim(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1))
|
||||
|
|
|
@ -11,7 +11,9 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -28,9 +30,12 @@ class MeanModule(torch.nn.Module):
|
|||
def MeanModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanDynamicSizesModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -47,16 +52,20 @@ class MeanDynamicSizesModule(torch.nn.Module):
|
|||
def MeanDynamicSizesModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanDtypeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, dtype=torch.float32)
|
||||
|
@ -66,16 +75,22 @@ class MeanDtypeModule(torch.nn.Module):
|
|||
def MeanDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanLargeInputModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x)
|
||||
|
@ -85,16 +100,20 @@ class MeanLargeInputModule(torch.nn.Module):
|
|||
def MeanLargeInputModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 128, 1024, low=100, high=200))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, (0, 2))
|
||||
|
@ -104,16 +123,22 @@ class MeanDimModule(torch.nn.Module):
|
|||
def MeanDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 7))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanDimLargeInputModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, (0, 2))
|
||||
|
@ -126,33 +151,40 @@ def MeanDimLargeInputModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanDimDtypeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, (0,), dtype=torch.float32)
|
||||
return torch.ops.aten.mean(x, (0, ), dtype=torch.float32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanDimDtypeModule())
|
||||
def MeanDimDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanDimKeepdimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, (1, 2), keepdim=True)
|
||||
|
@ -162,16 +194,20 @@ class MeanDimKeepdimModule(torch.nn.Module):
|
|||
def MeanDimKeepdimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanDimAllReduceModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, (0, 1, 2))
|
||||
|
@ -181,16 +217,20 @@ class MeanDimAllReduceModule(torch.nn.Module):
|
|||
def MeanDimAllReduceModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanDimAllReduceKeepdimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, (0, 1, 2), keepdim=True)
|
||||
|
@ -200,16 +240,20 @@ class MeanDimAllReduceKeepdimModule(torch.nn.Module):
|
|||
def MeanDimAllReduceKeepdimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanDimNegativeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, (-1, 1))
|
||||
|
@ -222,14 +266,17 @@ def MeanDimNegativeModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanDimEmptyDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, dim=[])
|
||||
|
@ -239,16 +286,20 @@ class MeanDimEmptyDimModule(torch.nn.Module):
|
|||
def MeanDimEmptyDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanDimNoneDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, dim=None)
|
||||
|
@ -258,74 +309,94 @@ class MeanDimNoneDimModule(torch.nn.Module):
|
|||
def MeanDimNoneDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class VarUnbiasedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, unbiased=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: VarUnbiasedModule())
|
||||
def VarUnbiasedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class VarBiasedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, unbiased=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: VarBiasedModule())
|
||||
def VarBiasedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class StdUnbiasedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, unbiased=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: StdUnbiasedModule())
|
||||
def StdUnbiasedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class StdBiasedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, unbiased=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: StdBiasedModule())
|
||||
def StdBiasedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
@ -342,7 +413,8 @@ class StdDimKeepDimFalseModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, dim=(1, 2), keepdim=False)
|
||||
|
@ -364,7 +436,8 @@ class StdDimKeepDimTrueModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, dim=(0, 1, 2), keepdim=True)
|
||||
|
@ -386,7 +459,8 @@ class StdDimBiasedModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, dim=(0, 2), unbiased=False)
|
||||
|
@ -408,7 +482,8 @@ class StdDimEmptyDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, dim=[], keepdim=False)
|
||||
|
@ -430,7 +505,8 @@ class StdDimNoneDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, dim=None, keepdim=False)
|
||||
|
@ -452,7 +528,8 @@ class VarDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=(0, 2), keepdim=True)
|
||||
|
@ -474,7 +551,8 @@ class VarDimUnbiasedModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=(0, 2), unbiased=True, keepdim=True)
|
||||
|
@ -496,10 +574,11 @@ class VarDimBiasedModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=(0,1), unbiased=False, keepdim=True)
|
||||
return torch.ops.aten.var(x, dim=(0, 1), unbiased=False, keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: VarDimBiasedModule())
|
||||
|
@ -518,10 +597,11 @@ class VarDimSingleDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=(0,), keepdim=True)
|
||||
return torch.ops.aten.var(x, dim=(0, ), keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: VarDimSingleDimModule())
|
||||
|
@ -540,7 +620,8 @@ class VarDimMultiDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=[0, 2], keepdim=False)
|
||||
|
@ -562,7 +643,8 @@ class VarDimAllDimReduceModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=(0, 1, 2), keepdim=True)
|
||||
|
@ -584,7 +666,8 @@ class VarDimNegativeModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=(-1, 1), keepdim=True)
|
||||
|
@ -606,7 +689,8 @@ class VarDimEmptyDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=[], keepdim=False)
|
||||
|
@ -628,7 +712,8 @@ class VarDimNoneDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=None, keepdim=False)
|
||||
|
@ -650,7 +735,8 @@ class VarCorrectionModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=None, correction=2)
|
||||
|
@ -672,13 +758,15 @@ class VarCorrectionSingleDimReduceModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=[1], correction=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: VarCorrectionSingleDimReduceModule())
|
||||
@register_test_case(
|
||||
module_factory=lambda: VarCorrectionSingleDimReduceModule())
|
||||
def VarCorrectionSingleDimReduceModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 7))
|
||||
|
||||
|
@ -694,7 +782,8 @@ class VarCorrectionAllDimReduceModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x,
|
||||
|
@ -719,7 +808,8 @@ class VarCorrectionKeepDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=[0, 1], correction=None, keepdim=True)
|
||||
|
@ -741,7 +831,8 @@ class VarCorrectionNoneModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=None, correction=None)
|
||||
|
@ -763,7 +854,8 @@ class VarCorrectionEmptyDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=[], correction=2)
|
||||
|
@ -785,7 +877,10 @@ class VarCorrectionLargeInputModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=[2, 3], correction=2)
|
||||
|
@ -807,10 +902,16 @@ class VarMeanCorrectionModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([
|
||||
-9223372036854775808, -9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808
|
||||
], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var_mean(x, dim=[1, 2], correction=2, keepdim=True)
|
||||
return torch.ops.aten.var_mean(x,
|
||||
dim=[1, 2],
|
||||
correction=2,
|
||||
keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: VarMeanCorrectionModule())
|
||||
|
@ -829,10 +930,14 @@ class VarMeanCorrectionNoneModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var_mean(x, dim=None, correction=None, keepdim=False)
|
||||
return torch.ops.aten.var_mean(x,
|
||||
dim=None,
|
||||
correction=None,
|
||||
keepdim=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: VarMeanCorrectionNoneModule())
|
||||
|
|
|
@ -13,6 +13,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class Threshold1dIntI32Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -21,16 +22,17 @@ class Threshold1dIntI32Module(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold1dIntI32Module())
|
||||
def Threshold1dIntI32Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, high=10).to(torch.int32))
|
||||
|
||||
|
||||
class Threshold1dIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -39,16 +41,17 @@ class Threshold1dIntModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold1dIntModule())
|
||||
def Threshold1dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, high=10))
|
||||
|
||||
|
||||
class Threshold2dIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -57,34 +60,37 @@ class Threshold2dIntModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 0.5, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold2dIntModule())
|
||||
def Threshold2dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 5, high=10))
|
||||
|
||||
|
||||
class Threshold3dIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1, 2.2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold3dIntModule())
|
||||
def Threshold3dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 5, 6, high=10))
|
||||
|
||||
|
||||
class Threshold1dFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -93,16 +99,17 @@ class Threshold1dFloatModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold1dFloatModule())
|
||||
def Threshold1dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4))
|
||||
|
||||
|
||||
class Threshold2dFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -111,34 +118,37 @@ class Threshold2dFloatModule(torch.nn.Module):
|
|||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 0.5, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold2dFloatModule())
|
||||
def Threshold2dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5))
|
||||
|
||||
|
||||
class Threshold3dFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1.4, 2.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold3dFloatModule())
|
||||
def Threshold3dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6))
|
||||
|
||||
|
||||
class ThresholdBackward1dIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -148,16 +158,17 @@ class ThresholdBackward1dIntModule(torch.nn.Module):
|
|||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward1dIntModule())
|
||||
def ThresholdBackward1dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, high=10), tu.randint(4, high=8))
|
||||
|
||||
|
||||
class ThresholdBackward2dIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -167,35 +178,39 @@ class ThresholdBackward2dIntModule(torch.nn.Module):
|
|||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward2dIntModule())
|
||||
def ThresholdBackward2dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 5, high=10), tu.randint(4, 5, high=8))
|
||||
|
||||
|
||||
class ThresholdBackward3dIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward3dIntModule())
|
||||
def ThresholdBackward3dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 5, 6, high=10), tu.randint(4, 5, 6, high=8))
|
||||
|
||||
|
||||
class ThresholdBackward1dFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -205,16 +220,17 @@ class ThresholdBackward1dFloatModule(torch.nn.Module):
|
|||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward1dFloatModule())
|
||||
def ThresholdBackward1dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4), torch.randn(4))
|
||||
|
||||
|
||||
class ThresholdBackward2dFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -224,35 +240,39 @@ class ThresholdBackward2dFloatModule(torch.nn.Module):
|
|||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward2dFloatModule())
|
||||
def ThresholdBackward2dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5), torch.randn(4, 5))
|
||||
|
||||
|
||||
class ThresholdBackward3dFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1.4)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward3dFloatModule())
|
||||
def ThresholdBackward3dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.randn(4, 5, 6))
|
||||
|
||||
|
||||
class ThresholdBackward1dMixedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -262,16 +282,17 @@ class ThresholdBackward1dMixedModule(torch.nn.Module):
|
|||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule())
|
||||
def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4), tu.randint(4, high=10))
|
||||
|
||||
|
||||
class ThresholdBackward2dMixedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -281,29 +302,32 @@ class ThresholdBackward2dMixedModule(torch.nn.Module):
|
|||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule())
|
||||
def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 5, high=20), torch.randn(4, 5))
|
||||
|
||||
|
||||
class ThresholdBackward3dMixedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808,
|
||||
-9223372036854775808], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1.4)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule())
|
||||
def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), tu.randint(4, 5, 6, high=10))
|
||||
|
|
|
@ -18,7 +18,10 @@ class TypeConversionF32ToF64Module(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.float64)
|
||||
|
||||
|
@ -34,7 +37,10 @@ class TypeConversionF64ToF32Module(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float64, True)])
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float64, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.float32)
|
||||
|
||||
|
@ -50,7 +56,9 @@ class TypeConversionI32ToI64Module(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.int32, True)])
|
||||
@annotate_args([
|
||||
None, ([-9223372036854775808, -9223372036854775808], torch.int32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.int64)
|
||||
|
||||
|
@ -66,7 +74,9 @@ class TypeConversionI64ToI32Module(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.int64, True)])
|
||||
@annotate_args([
|
||||
None, ([-9223372036854775808, -9223372036854775808], torch.int64, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.int32)
|
||||
|
||||
|
@ -82,7 +92,9 @@ class TypeConversionI1ToI32Module(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)])
|
||||
@annotate_args([
|
||||
None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.int32)
|
||||
|
||||
|
@ -99,7 +111,9 @@ class TypeConversionI1ToI64Module(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)])
|
||||
@annotate_args([
|
||||
None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.int64)
|
||||
|
||||
|
@ -116,7 +130,9 @@ class TypeConversionI1ToF32Module(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)])
|
||||
@annotate_args([
|
||||
None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.float32)
|
||||
|
||||
|
@ -133,7 +149,9 @@ class TypeConversionI1ToF64Module(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)])
|
||||
@annotate_args([
|
||||
None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.to(torch.float64)
|
||||
|
||||
|
@ -153,7 +171,10 @@ class ToDtypeLayoutNoneModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.to(x,
|
||||
dtype=torch.float64,
|
||||
|
@ -176,7 +197,10 @@ class ToDtypeLayoutStridedModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.to(x,
|
||||
dtype=torch.float64,
|
||||
|
@ -245,7 +269,10 @@ class PrimsConvertElementTypeModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.prims.convert_element_type(x, dtype=torch.int64)
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -30,11 +31,11 @@ class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module):
|
|||
module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule())
|
||||
def TypePromotionSameCategoryDifferentWidthModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.randint(4, high=10).type(torch.int32),
|
||||
tu.randint(4, high=10))
|
||||
tu.randint(4, high=10).type(torch.int32), tu.randint(4, high=10))
|
||||
|
||||
|
||||
class TypePromotionDifferentCategoryModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -55,6 +56,7 @@ def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -75,6 +77,7 @@ def TypePromotionSameCategoryZeroRankWider_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class TypePromotionZeroRankHigherCategoryModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -95,6 +98,7 @@ def TypePromotionZeroRankHigherCategoryModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class TypePromotionAlphaWiderModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class ResNet18Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Reset seed to make model deterministic.
|
||||
|
@ -24,7 +25,8 @@ class ResNet18Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, 3, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, 3, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, img):
|
||||
return self.resnet.forward(img)
|
||||
|
@ -36,6 +38,7 @@ def ResNet18Module_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class ResNet18StaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Reset seed to make model deterministic.
|
||||
|
@ -58,6 +61,7 @@ def ResNet18StaticModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class IouOfModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -84,7 +88,9 @@ class IouOfModule(torch.nn.Module):
|
|||
def IouOfModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1024, 4), tu.rand(1024, 4))
|
||||
|
||||
|
||||
class MobilenetV2Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Reset seed to make model deterministic.
|
||||
|
@ -95,11 +101,13 @@ class MobilenetV2Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, 3, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, 3, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, img):
|
||||
return self.mobilenetv2.forward(img)
|
||||
|
||||
|
||||
# TODO (cathyzhyi) The runtime assertion for conv2d with group != 1 is exposed
|
||||
# after aten.hardtanh is implemented. Reenable once the the runtime assertion
|
||||
# is fixed.
|
||||
|
@ -107,7 +115,9 @@ class MobilenetV2Module(torch.nn.Module):
|
|||
def MobilenetV2Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 3, 224, 224))
|
||||
|
||||
|
||||
class MobilenetV3Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Reset seed to make model deterministic.
|
||||
|
@ -118,7 +128,8 @@ class MobilenetV3Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, 3, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, 3, -9223372036854775808,
|
||||
-9223372036854775808], torch.float32, True),
|
||||
])
|
||||
def forward(self, img):
|
||||
return self.mobilenetv3.forward(img)
|
||||
|
|
Loading…
Reference in New Issue