mirror of https://github.com/llvm/torch-mlir
Add TestUtils.randint + replace torch.randint with tu.randint (#1276)
This commit adds a method to `TestUtils` that generates random integer tensors with a similar interface to the `TestUtils.rand`. This commit also replaces with `tu.randint` all test inputs generated with `torch.randint`.pull/1292/head
parent
e869e68559
commit
e153694c94
|
@ -80,7 +80,7 @@ class IsFloatingPointInt(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IsFloatingPointInt())
|
||||
def IsFloatingPointInt_False(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 3)))
|
||||
module.forward(tu.randint(3, 3, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -736,7 +736,7 @@ class EmbeddingModuleI64(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: EmbeddingModuleI64())
|
||||
def EmbeddingModuleI64_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 3)))
|
||||
module.forward(tu.randint(3, 3, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -762,7 +762,7 @@ class EmbeddingModuleI32(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: EmbeddingModuleI32())
|
||||
def EmbeddingModuleI32_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 3)).to(torch.int32))
|
||||
module.forward(tu.randint(3, 3, high=100).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -787,7 +787,7 @@ class EmbeddingModuleI32Static(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: EmbeddingModuleI32Static())
|
||||
def EmbeddingModuleI32Static_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 3)).to(torch.int32))
|
||||
module.forward(tu.randint(3, 3, high=100).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -813,7 +813,7 @@ class EmbeddingModule1DIndices(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: EmbeddingModule1DIndices())
|
||||
def EmbeddingModule1DIndices_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3,)).to(torch.int32))
|
||||
module.forward(tu.randint(3, high=100).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1331,7 +1331,7 @@ class DropoutEvalIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: DropoutEvalIntModule())
|
||||
def DropoutEvalIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(5, 10, (3, 4)))
|
||||
module.forward(tu.randint(3, 4, low=5, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1420,7 +1420,7 @@ class NumelZeroRankModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NumelZeroRankModule())
|
||||
def NumelZeroRankModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, []))
|
||||
module.forward(tu.randint(high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1646,7 +1646,7 @@ class ReturnTwoTensorF32I64(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReturnTwoTensorF32I64())
|
||||
def ReturnTwoTensorF32I64_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3), torch.randint(5, (2, 3)))
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, 3, high=5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1669,7 +1669,7 @@ class IndexTensorModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexTensorModule())
|
||||
def IndexTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5), torch.randint(4, (2, 3)))
|
||||
module.forward(tu.rand(5), tu.randint(2, 3, high=4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1692,7 +1692,7 @@ class IndexTensorModule3dInput(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexTensorModule3dInput())
|
||||
def IndexTensorModule3dInput_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3), torch.randint(3, (2, 3)))
|
||||
module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1715,7 +1715,7 @@ class IndexTensorSelectDimModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexTensorSelectDimModule())
|
||||
def IndexTensorSelectDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 6), torch.randint(3, (2, 3)))
|
||||
module.forward(tu.rand(2, 4, 6), tu.randint(2, 3, high=3))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -1738,7 +1738,7 @@ class IndexTensorMultiInput(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexTensorMultiInput())
|
||||
def IndexTensorMultiInput_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3), torch.randint(3, (3, 3)), torch.randint(3, (3,)))
|
||||
module.forward(tu.rand(5, 4, 3), tu.randint(3, 3, high=3), tu.randint(3, high=3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1762,7 +1762,7 @@ class IndexTensorMultiInputOneDim(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexTensorMultiInputOneDim())
|
||||
def IndexTensorMultiInputOneDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)), torch.randint(3, (3,)))
|
||||
module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4), tu.randint(3, high=3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1791,8 +1791,8 @@ class IndexTensorMultiInputContiguousOneDimDynamic(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexTensorMultiInputContiguousOneDimDynamic())
|
||||
def IndexTensorMultiInputContiguousOneDimDynamic_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)),
|
||||
torch.randint(3, (3, )))
|
||||
module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4),
|
||||
tu.randint(3, high=3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1822,8 +1822,8 @@ class IndexTensorMultiInputNonContiguousOneDimDynamic(torch.nn.Module):
|
|||
module_factory=lambda: IndexTensorMultiInputNonContiguousOneDimDynamic())
|
||||
def IndexTensorMultiInputNonContiguousOneDimDynamic_basic(
|
||||
module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)),
|
||||
torch.randint(3, (3, )))
|
||||
module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4),
|
||||
tu.randint(3, high=3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1852,8 +1852,8 @@ class IndexTensorMultiInputNonContiguousDynamic(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexTensorMultiInputNonContiguousDynamic())
|
||||
def IndexTensorMultiInputNonContiguousDynamic_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3), torch.randint(2, (6, 2)),
|
||||
torch.randint(3, (2, )))
|
||||
module.forward(tu.rand(5, 4, 3), tu.randint(6, 2, high=2),
|
||||
tu.randint(2, high=3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1880,8 +1880,8 @@ class IndexTensorMultiInputNonContiguousMultipleStaticDims(torch.nn.Module):
|
|||
IndexTensorMultiInputNonContiguousMultipleStaticDims())
|
||||
def IndexTensorMultiInputNonContiguousMultipleStaticDims_basic(
|
||||
module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (4, 1)),
|
||||
torch.randint(1, (1, 3)), torch.randint(1, (4, 3)))
|
||||
module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 1, high=3),
|
||||
tu.randint(1, 3, high=1), tu.randint(4, 3, high=1))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1905,7 +1905,7 @@ class IndexTensorMultiInputNonContiguous(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexTensorMultiInputNonContiguous())
|
||||
def IndexTensorMultiInputNonContiguous_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (4, 2)), torch.randint(1, (4, 2,)))
|
||||
module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 2, high=3), tu.randint(4, 2, high=1))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1931,9 +1931,9 @@ class IndexTensorMultiInputThreeIndexers(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: IndexTensorMultiInputThreeIndexers())
|
||||
def IndexTensorMultiInputThreeIndexers_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 2, 4, 4, 5, 3),
|
||||
torch.randint(3, (8, 4, 2,)),
|
||||
torch.randint(4, (8, 1, 1,)),
|
||||
torch.randint(2, (4, 2,)))
|
||||
tu.randint(8, 4, 2, high=3),
|
||||
tu.randint(8, 1, 1, high=4),
|
||||
tu.randint(4, 2, high=2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1957,7 +1957,7 @@ class IndexTensorMultiInputContiguousCenter(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexTensorMultiInputContiguousCenter())
|
||||
def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (2, 2)), torch.randint(2, [2]))
|
||||
module.forward(tu.rand(5, 4, 3, 2), tu.randint(2, 2, high=3), tu.randint(2, high=2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2089,7 +2089,7 @@ class HardTanhIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: HardTanhIntModule())
|
||||
def HardTanhIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-5, 5, (100, 100)))
|
||||
module.forward(tu.randint(100, 100, low=-5, high=5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2111,7 +2111,7 @@ class BincountModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: BincountModule())
|
||||
def BincountModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (1000, )))
|
||||
module.forward(tu.randint(1000, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2133,7 +2133,7 @@ class BincountStaticSizeModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: BincountStaticSizeModule())
|
||||
def BincountStaticSizeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (200, )))
|
||||
module.forward(tu.randint(200, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2155,7 +2155,7 @@ class BincountMinlengthModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: BincountMinlengthModule())
|
||||
def BincountMinlengthModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(5, (20, )))
|
||||
module.forward(tu.randint(20, high=5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2198,8 +2198,8 @@ class ExpandAsIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ExpandAsIntModule())
|
||||
def ExpandAsIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (1, 1, 1)),
|
||||
torch.randint(200, (4, 5, 6)))
|
||||
module.forward(tu.randint(1, 1, 1, high=100),
|
||||
tu.randint(4, 5, 6, high=200))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2262,7 +2262,7 @@ class CopyWithDifferentDTypesModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: CopyWithDifferentDTypesModule())
|
||||
def CopyWithDifferentDTypesModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 2, 4)), tu.rand(3, 2, 4))
|
||||
module.forward(tu.randint(3, 2, 4, high=100), tu.rand(3, 2, 4))
|
||||
|
||||
|
||||
class CopyWithDifferentDTypesAndSizesModule(torch.nn.Module):
|
||||
|
@ -2283,7 +2283,7 @@ class CopyWithDifferentDTypesAndSizesModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: CopyWithDifferentDTypesAndSizesModule())
|
||||
def CopyWithDifferentDTypesAndSizesModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 2, 4), torch.randint(1000, (3, 2, 1)))
|
||||
module.forward(tu.rand(3, 2, 4), tu.randint(3, 2, 1, high=1000))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2451,7 +2451,7 @@ class ScalarImplicitIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ScalarImplicitIntModule())
|
||||
def ScalarImplicitIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-100, 100, ()))
|
||||
module.forward(tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2517,7 +2517,7 @@ class BaddbmmDifferentDtypesModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: BaddbmmDifferentDtypesModule())
|
||||
def BaddbmmDifferentDtypesModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4, 5)), tu.rand(3, 4, 6),
|
||||
module.forward(tu.randint(3, 4, 5, high=10), tu.rand(3, 4, 6),
|
||||
tu.rand(3, 6, 5))
|
||||
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ class TensorToIntZeroRank(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TensorToIntZeroRank())
|
||||
def TensorToIntZeroRank_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, ()))
|
||||
module.forward(tu.randint(high=10))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -45,7 +45,7 @@ class TensorToInt(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TensorToInt())
|
||||
def TensorToInt_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (1, 1)))
|
||||
module.forward(tu.randint(1, 1, high=10))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
|
@ -346,7 +346,7 @@ class EmptyLikeIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: EmptyLikeIntModule())
|
||||
def EmptyLikeModule_int(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)))
|
||||
module.forward(tu.randint(3, 5, high=10))
|
||||
|
||||
|
||||
class EmptyLikeMemoryFormatModule(torch.nn.Module):
|
||||
|
@ -446,7 +446,7 @@ class ZerosLikeIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ZerosLikeIntModule())
|
||||
def ZerosLikeModule_int(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)))
|
||||
module.forward(tu.randint(3, 5, high=10))
|
||||
|
||||
|
||||
class ZerosLikeFloatModule(torch.nn.Module):
|
||||
|
@ -525,7 +525,7 @@ class OnesLikeIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: OnesLikeIntModule())
|
||||
def OnesLikeModule_int(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)))
|
||||
module.forward(tu.randint(3, 5, high=10))
|
||||
|
||||
|
||||
class OnesLikeFloatModule(torch.nn.Module):
|
||||
|
@ -702,7 +702,7 @@ class NewZerosModuleFloat2D(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NewZerosModuleFloat2D())
|
||||
def NewZerosModuleFloat2D_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3, 4)))
|
||||
module.forward(tu.randint(2, 3, 4, high=10))
|
||||
|
||||
|
||||
class NewZerosModuleFloat3D(torch.nn.Module):
|
||||
|
@ -721,7 +721,7 @@ class NewZerosModuleFloat3D(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NewZerosModuleFloat3D())
|
||||
def NewZerosModuleFloat3D_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3)))
|
||||
module.forward(tu.randint(2, 3, high=10))
|
||||
|
||||
|
||||
class NewZerosModuleFalsePinMemory(torch.nn.Module):
|
||||
|
@ -742,7 +742,7 @@ class NewZerosModuleFalsePinMemory(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NewZerosModuleFalsePinMemory())
|
||||
def NewZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3)))
|
||||
module.forward(tu.randint(2, 3, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -821,7 +821,7 @@ class NewOnesModuleFloat2D(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NewOnesModuleFloat2D())
|
||||
def NewOnesModuleFloat2D_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3, 4)))
|
||||
module.forward(tu.randint(2, 3, 4, high=10))
|
||||
|
||||
|
||||
class NewOnesModuleFloat3D(torch.nn.Module):
|
||||
|
@ -840,7 +840,7 @@ class NewOnesModuleFloat3D(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NewOnesModuleFloat3D())
|
||||
def NewOnesModuleFloat3D_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3)))
|
||||
module.forward(tu.randint(2, 3, high=10))
|
||||
|
||||
|
||||
class NewOnesModuleFalsePinMemory(torch.nn.Module):
|
||||
|
@ -861,7 +861,7 @@ class NewOnesModuleFalsePinMemory(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NewOnesModuleFalsePinMemory())
|
||||
def NewOnesModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3)))
|
||||
module.forward(tu.randint(2, 3, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1016,7 +1016,7 @@ class FullLikeModuleInt2D(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: FullLikeModuleInt2D())
|
||||
def FullLikeModuleInt2D_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4, 5)))
|
||||
module.forward(tu.randint(4, 5, high=10))
|
||||
|
||||
|
||||
class FullLikeModuleInt3D(torch.nn.Module):
|
||||
|
@ -1035,7 +1035,7 @@ class FullLikeModuleInt3D(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: FullLikeModuleInt3D())
|
||||
def FullLikeModuleInt3D_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10, 4, 5)).to(torch.int32))
|
||||
module.forward(tu.randint(10, 4, 5, high=100).to(torch.int32))
|
||||
|
||||
|
||||
class FullLikeModuleInt2DStatic(torch.nn.Module):
|
||||
|
@ -1054,7 +1054,7 @@ class FullLikeModuleInt2DStatic(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: FullLikeModuleInt2DStatic())
|
||||
def FullLikeModuleInt2DStatic_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4, 5)))
|
||||
module.forward(tu.randint(4, 5, high=10))
|
||||
|
||||
|
||||
class FullLikeModuleFloat2D(torch.nn.Module):
|
||||
|
@ -1133,7 +1133,7 @@ class FullLikeModuleFalsePinMemory(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: FullLikeModuleFalsePinMemory())
|
||||
def FullLikeModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10, 4)))
|
||||
module.forward(tu.randint(10, 4, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1174,7 +1174,7 @@ class ZeroInt32Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ZeroInt32Module())
|
||||
def ZeroInt32Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10, 4), dtype=torch.int32))
|
||||
module.forward(tu.randint(10, 4, high=100).to(dtype=torch.int32))
|
||||
|
||||
|
||||
class ZeroInt64Module(torch.nn.Module):
|
||||
|
@ -1193,7 +1193,7 @@ class ZeroInt64Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ZeroInt64Module())
|
||||
def ZeroInt64Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10, 4)))
|
||||
module.forward(tu.randint(10, 4, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1274,7 +1274,7 @@ class NewEmptyModuleFloat2D(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NewEmptyModuleFloat2D())
|
||||
def NewEmptyModuleFloat2D_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3, 4)))
|
||||
module.forward(tu.randint(2, 3, 4, high=10))
|
||||
|
||||
|
||||
class NewEmptyModuleFloat3D(torch.nn.Module):
|
||||
|
@ -1294,7 +1294,7 @@ class NewEmptyModuleFloat3D(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NewEmptyModuleFloat3D())
|
||||
def NewEmptyModuleFloat3D_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3)))
|
||||
module.forward(tu.randint(2, 3, high=10))
|
||||
|
||||
|
||||
class NewEmptyModuleFalsePinMemory(torch.nn.Module):
|
||||
|
@ -1315,7 +1315,7 @@ class NewEmptyModuleFalsePinMemory(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NewEmptyModuleFalsePinMemory())
|
||||
def NewEmptyModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3)))
|
||||
module.forward(tu.randint(2, 3, high=10))
|
||||
|
||||
|
||||
class NewEmptyModuleNonDefaultFloatDtype(torch.nn.Module):
|
||||
|
@ -1354,7 +1354,7 @@ class NewEmptyModuleNonDefaultIntDtype(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultIntDtype())
|
||||
def NewEmptyModuleNonDefaultIntDtype_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3)).to(torch.int32))
|
||||
module.forward(tu.randint(2, 3, high=10).to(torch.int32))
|
||||
|
||||
|
||||
class NewEmptyModuleLayoutIntDtype(torch.nn.Module):
|
||||
|
@ -1373,7 +1373,7 @@ class NewEmptyModuleLayoutIntDtype(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NewEmptyModuleLayoutIntDtype())
|
||||
def NewEmptyModuleLayoutIntDtype_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3)).to(torch.int32))
|
||||
module.forward(tu.randint(2, 3, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1397,7 +1397,7 @@ class MaskedFillScalarDefaultModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: MaskedFillScalarDefaultModule())
|
||||
def MaskedFillScalarDefaultModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3),
|
||||
torch.randint(0, 2, (2, 3)).to(dtype=torch.bool))
|
||||
tu.randint(2, 3, high=2).to(dtype=torch.bool))
|
||||
|
||||
|
||||
class MaskedFillScalarIntValueModule(torch.nn.Module):
|
||||
|
@ -1418,7 +1418,7 @@ class MaskedFillScalarIntValueModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: MaskedFillScalarIntValueModule())
|
||||
def MaskedFillScalarIntValueModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3),
|
||||
torch.randint(0, 2, (2, 3)).to(dtype=torch.bool))
|
||||
tu.randint(2, 3, high=2).to(dtype=torch.bool))
|
||||
|
||||
|
||||
class MaskedFillScalarFloatValueModule(torch.nn.Module):
|
||||
|
@ -1438,8 +1438,8 @@ class MaskedFillScalarFloatValueModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: MaskedFillScalarFloatValueModule())
|
||||
def MaskedFillScalarFloatValueModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 10, (2, 3)),
|
||||
torch.randint(0, 2, (2, 3)).to(dtype=torch.bool))
|
||||
module.forward(tu.randint(2, 3, low=-10, high=10),
|
||||
tu.randint(2, 3, high=2).to(dtype=torch.bool))
|
||||
|
||||
|
||||
class MaskedFillTensorFloatValueModule(torch.nn.Module):
|
||||
|
@ -1460,5 +1460,5 @@ class MaskedFillTensorFloatValueModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: MaskedFillTensorFloatValueModule())
|
||||
def MaskedFillTensorFloatValueModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 10, (2, 3)),
|
||||
torch.randint(0, 2, (2, 3)).to(dtype=torch.bool), tu.rand())
|
||||
module.forward(tu.randint(2, 3, low=-10, high=10),
|
||||
tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.rand())
|
||||
|
|
|
@ -32,7 +32,7 @@ class TorchPrimLoopForLikeModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TorchPrimLoopForLikeModule())
|
||||
def TorchPrimLoopForLikeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(0, 10, (6, 8)))
|
||||
module.forward(tu.randint(6, 8, high=10))
|
||||
|
||||
# ==============================================================================
|
||||
class TorchPrimLoopWhileLikeModule(torch.nn.Module):
|
||||
|
@ -54,4 +54,4 @@ class TorchPrimLoopWhileLikeModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeModule())
|
||||
def TorchPrimLoopWhileLikeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(0, 10, (6, 8)))
|
||||
module.forward(tu.randint(6, 8, high=10))
|
||||
|
|
|
@ -57,7 +57,7 @@ class ElementwiseUnaryIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseUnaryIntModule())
|
||||
def ElementwiseUnaryIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -428,7 +428,7 @@ class ElementwiseSigmoidIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseSigmoidIntModule())
|
||||
def ElementwiseSigmoidIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 5), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 5, low=1, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -474,7 +474,7 @@ class ElementwiseMinimumIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMinimumIntModule())
|
||||
def ElementwiseMinimumIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (3, 5)))
|
||||
module.forward(tu.randint(3, 5, high=10), tu.randint(3, 5, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -520,7 +520,7 @@ class ElementwiseMaximumIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMaximumIntModule())
|
||||
def ElementwiseMaximumIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (3, 5)))
|
||||
module.forward(tu.randint(3, 5, high=10), tu.randint(3, 5, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -663,7 +663,7 @@ class RsubIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: RsubIntModule())
|
||||
def RsubIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 4)))
|
||||
module.forward(tu.randint(3, 4, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -685,7 +685,7 @@ class RsubIntModule_noalpha(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: RsubIntModule_noalpha())
|
||||
def RsubIntModule_noalpha_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 4)))
|
||||
module.forward(tu.randint(3, 4, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -707,7 +707,7 @@ class ElementwiseMulScalarIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMulScalarIntModule())
|
||||
def ElementwiseMulScalarModule_int(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4)))
|
||||
module.forward(tu.randint(3, 4, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -751,7 +751,7 @@ class ElementwiseMulScalarModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
|
||||
def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 4, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -798,7 +798,7 @@ class ElementwiseMulTensorIntModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ElementwiseMulTensorIntModule())
|
||||
def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.randint(10, [4]).type(torch.int32), torch.randint(10, [4]))
|
||||
tu.randint(4, high=10).type(torch.int32), tu.randint(4, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -845,7 +845,7 @@ class ElementwiseAtan2TensorIntModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntModule())
|
||||
def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.randint(1, 10, [4]).type(torch.int32), torch.randint(1, 10, [4]))
|
||||
tu.randint(4, low=1, high=10).type(torch.int32), tu.randint(4, low=1, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -868,8 +868,8 @@ class ElementwiseAtan2FloatIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntModule())
|
||||
def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, [4, 4], dtype=torch.int32),
|
||||
tu.rand(4, 4).double())
|
||||
module.forward(tu.randint(4, 4, low=1, high=10).to(torch.int32),
|
||||
tu.rand(4, 4).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -913,7 +913,7 @@ class ElementwiseLogIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLogIntModule())
|
||||
def ElementwiseLogIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -978,7 +978,7 @@ class ElementwiseErfIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseErfIntModule())
|
||||
def ElementwiseErfIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1022,7 +1022,7 @@ class ElementwiseSqrtIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseSqrtIntModule())
|
||||
def ElementwiseSqrtIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1170,7 +1170,7 @@ class ElementwiseLog2IntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLog2IntModule())
|
||||
def ElementwiseLog2IntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1214,7 +1214,7 @@ class ElementwiseRsqrtIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRsqrtIntModule())
|
||||
def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1280,7 +1280,7 @@ class ElementwiseReciprocalIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseReciprocalIntModule())
|
||||
def ElementwiseReciprocalIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (4,), dtype=torch.int32))
|
||||
module.forward(tu.randint(4, low=1, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1323,7 +1323,7 @@ class ElementwiseRemainderScalarModule_Int_Float(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float())
|
||||
def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3,), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1366,7 +1366,7 @@ class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int())
|
||||
def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 2), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 2, high=10).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -1478,8 +1478,8 @@ class ElementwiseAndIntegerModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ElementwiseAndIntegerModule())
|
||||
def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.randint(-10, 10, (3, 4)).to(torch.int32),
|
||||
torch.randint(-10, 10, (3, 4)))
|
||||
tu.randint(3, 4, low=-10, high=10).to(torch.int32),
|
||||
tu.randint(3, 4, low=-10, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1501,7 +1501,7 @@ class ElementwiseSubScalarIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseSubScalarIntModule())
|
||||
def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 4, high=10).to(dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1545,7 +1545,7 @@ class ElementwiseAddScalarInt64Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAddScalarInt64Module())
|
||||
def ElementwiseAddScalarInt64Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4)))
|
||||
module.forward(tu.randint(3, 4, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1567,7 +1567,7 @@ class ElementwiseAddScalarIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAddScalarIntModule())
|
||||
def ElementwiseAddScalarIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3), dtype=torch.int32))
|
||||
module.forward(tu.randint(2, 3, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1677,7 +1677,7 @@ class ElementwiseExpIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseExpIntModule())
|
||||
def ElementwiseExpIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1721,7 +1721,7 @@ class ElementwiseExpm1IntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseExpm1IntModule())
|
||||
def ElementwiseExpm1IntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1765,7 +1765,7 @@ class ElementwiseSinIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseSinIntModule())
|
||||
def ElementwiseSinIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1809,7 +1809,7 @@ class ElementwiseCosIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseCosIntModule())
|
||||
def ElementwiseCosIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1924,7 +1924,7 @@ class ElementwiseAtenLogicalOrOpRandomModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomModule())
|
||||
def ElementwiseAtenLogicalOrOpRandomModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(3, 10, (2, 3, 4, 5)), torch.randint(10, 100, (2, 3, 4, 5)))
|
||||
module.forward(tu.randint(2, 3, 4, 5, low=3, high=10), tu.randint(2, 3, 4, 5, low=10, high=100))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -1962,7 +1962,7 @@ class ElementwiseAtenLogicalOrOpNegativeModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpNegativeModule())
|
||||
def ElementwiseAtenLogicalOrOpNegativeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.neg(torch.randint(3, 10, (2, 3, 4, 5))), torch.neg(torch.randint(10, 100, (2, 3, 4, 5))))
|
||||
module.forward(torch.neg(tu.randint(2, 3, 4, 5, low=3, high=10)), torch.neg(tu.randint(2, 3, 4, 5, low=10, high=100)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -1981,7 +1981,7 @@ class ElementwiseAtenLogicalOrOpBrodcastModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpBrodcastModule())
|
||||
def ElementwiseAtenLogicalOrOpBrodcastModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(3, (3,)), torch.randint(3, (4, 3)))
|
||||
module.forward(tu.randint(3, high=3), tu.randint(4, 3, high=3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
|
|
@ -45,7 +45,7 @@ class ElementwiseGtIntScalarModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGtIntScalarModule())
|
||||
def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)))
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -64,7 +64,7 @@ class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGtMixed2ScalarModule())
|
||||
def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -102,7 +102,7 @@ class ElementwiseGeIntScalarModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGeIntScalarModule())
|
||||
def ElementwiseGeIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)))
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -121,7 +121,7 @@ class ElementwiseGeMixedIntScalarModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGeMixedIntScalarModule())
|
||||
def ElementwiseGeMixedIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -180,7 +180,7 @@ class ElementwiseGtIntTensorModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseGtIntTensorModule())
|
||||
def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, )))
|
||||
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -218,7 +218,7 @@ class ElementwiseLtIntScalarModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLtIntScalarModule())
|
||||
def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)))
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -238,7 +238,7 @@ class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseLtDiffWidthScalarModule())
|
||||
def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -276,7 +276,7 @@ class ElementwiseLeIntScalarModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLeIntScalarModule())
|
||||
def ElementwiseLeIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)))
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -295,7 +295,7 @@ class ElementwiseLeMixedIntScalarModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLeMixedIntScalarModule())
|
||||
def ElementwiseLeMixedIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32))
|
||||
module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -354,7 +354,7 @@ class ElementwiseLtIntTensorModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLtIntTensorModule())
|
||||
def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, )))
|
||||
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -393,7 +393,7 @@ class ElementwiseEqIntScalarModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseEqIntScalarModule())
|
||||
def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(2, 4, (5, 8)))
|
||||
module.forward(tu.randint(5, 8, low=2, high=4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -413,7 +413,7 @@ class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseEqDiffWidthScalarModule())
|
||||
def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(2, 4, (5, 8)).to(torch.int32))
|
||||
module.forward(tu.randint(5, 8, low=2, high=4).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -455,7 +455,7 @@ class ElementwiseEqIntTensorModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule())
|
||||
def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, )))
|
||||
module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -494,7 +494,7 @@ class ElementwiseNeIntScalarModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseNeIntScalarModule())
|
||||
def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(2, 4, (8, 5)))
|
||||
module.forward(tu.randint(8, 5, low=2, high=4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
|
@ -90,17 +90,12 @@ class HistogramBinningCalibrationByFeature(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: HistogramBinningCalibrationByFeature())
|
||||
def HBC_basic(module, tu: TestUtils):
|
||||
logits = torch.rand(NUM_LOGITS, dtype=torch.float)
|
||||
segment_lengths: Tensor = torch.randint(
|
||||
0, 2, (NUM_LOGITS,), dtype=torch.int)
|
||||
segment_lengths: Tensor = tu.randint(NUM_LOGITS, high=2).to(torch.int)
|
||||
segment_offsets: Tensor = torch.cumsum(segment_lengths, 0)
|
||||
segment_offsets: Tensor = torch.cat(
|
||||
(torch.tensor([0]), segment_offsets), 0)
|
||||
num_values: int = int(torch.sum(segment_lengths).item())
|
||||
segment_values: Tensor = torch.randint(
|
||||
0,
|
||||
NUM_SEGMENTS,
|
||||
(num_values,),
|
||||
)
|
||||
segment_values: Tensor = tu.randint(num_values, high=NUM_SEGMENTS)
|
||||
segment_values = torch.cat(
|
||||
(segment_values, torch.zeros(NUM_LOGITS-segment_values.numel())), 0)
|
||||
module.forward(segment_values.int(), segment_offsets.int(), logits)
|
||||
|
|
|
@ -34,7 +34,7 @@ class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutImpl1DFloatNonAccumulateModule())
|
||||
def IndexPutImpl1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(100), torch.randint(100, (250, )), tu.rand(250))
|
||||
module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250))
|
||||
|
||||
|
||||
class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
@ -59,7 +59,7 @@ class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutImpl2DFloatNonAccumulateModule())
|
||||
def IndexPutImpl2DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8))
|
||||
module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8))
|
||||
|
||||
|
||||
class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
@ -84,7 +84,7 @@ class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutImpl3DFloatNonAccumulateModule())
|
||||
def IndexPutImpl3DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )),
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
|
||||
tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
|
@ -113,8 +113,8 @@ class IndexPutImpl1DIntNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutImpl1DIntNonAccumulateModule())
|
||||
def IndexPutImpl1DIntNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (200, )), torch.randint(100, (300, )),
|
||||
torch.randint(10000, (300, )))
|
||||
module.forward(tu.randint(200, high=1000), tu.randint(300, high=100),
|
||||
tu.randint(300, high=10000))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -142,7 +142,7 @@ class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutImpl1DFloatAccumulateModule())
|
||||
def IndexPutImpl1DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1000), torch.randint(10, (500, )), tu.rand(500))
|
||||
module.forward(tu.rand(1000), tu.randint(500, high=10), tu.rand(500))
|
||||
|
||||
|
||||
class IndexPutImpl2DFloatAccumulateModule(torch.nn.Module):
|
||||
|
@ -167,7 +167,7 @@ class IndexPutImpl2DFloatAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutImpl2DFloatAccumulateModule())
|
||||
def IndexPutImpl2DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8))
|
||||
module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8))
|
||||
|
||||
|
||||
class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module):
|
||||
|
@ -192,7 +192,7 @@ class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutImpl3DFloatAccumulateModule())
|
||||
def IndexPutImpl3DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )),
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
|
||||
tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
|
@ -220,8 +220,8 @@ class IndexPutImpl1DIntAccumulateModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexPutImpl1DIntAccumulateModule())
|
||||
def IndexPutImpl1DIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10, )), torch.randint(10, (10, )),
|
||||
torch.randint(1000, (10, )))
|
||||
module.forward(tu.randint(10, high=100), tu.randint(10, high=10),
|
||||
tu.randint(10, high=1000))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -248,7 +248,7 @@ class IndexPut1DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPut1DFloatNonAccumulateModule())
|
||||
def IndexPut1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(100), torch.randint(100, (250, )), tu.rand(250))
|
||||
module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250))
|
||||
|
||||
|
||||
class IndexPut2DFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
@ -272,7 +272,7 @@ class IndexPut2DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPut2DFloatNonAccumulateModule())
|
||||
def IndexPut2DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8))
|
||||
module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8))
|
||||
|
||||
|
||||
class IndexPut3DFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
@ -296,7 +296,7 @@ class IndexPut3DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPut3DFloatNonAccumulateModule())
|
||||
def IndexPut3DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )),
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
|
||||
tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
|
@ -323,8 +323,8 @@ class IndexPut1DIntNonAccumulateModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexPut1DIntNonAccumulateModule())
|
||||
def IndexPut1DIntNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (200, )), torch.randint(100, (300, )),
|
||||
torch.randint(10000, (300, )))
|
||||
module.forward(tu.randint(200, high=1000), tu.randint(300, high=100),
|
||||
tu.randint(300, high=10000))
|
||||
|
||||
|
||||
class IndexPut2DIntNonAccumulateModule(torch.nn.Module):
|
||||
|
@ -347,8 +347,8 @@ class IndexPut2DIntNonAccumulateModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexPut2DIntNonAccumulateModule())
|
||||
def IndexPut2DIntNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (10, 8)), torch.randint(4, (5, )),
|
||||
torch.randint(1000, (5, 8)))
|
||||
module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4),
|
||||
tu.randint(5, 8, high=1000))
|
||||
|
||||
|
||||
class IndexPut3DIntNonAccumulateModule(torch.nn.Module):
|
||||
|
@ -371,8 +371,8 @@ class IndexPut3DIntNonAccumulateModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexPut3DIntNonAccumulateModule())
|
||||
def IndexPut3DIntNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (10, 8, 6)), torch.randint(4, (5, )),
|
||||
torch.randint(1000, (5, 8, 6)))
|
||||
module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4),
|
||||
tu.randint(5, 8, 6, high=1000))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -398,7 +398,7 @@ class IndexPut1DFloatAccumulateModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexPut1DFloatAccumulateModule())
|
||||
def IndexPut1DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1000), torch.randint(10, (500, )), tu.rand(500))
|
||||
module.forward(tu.rand(1000), tu.randint(500, high=10), tu.rand(500))
|
||||
|
||||
|
||||
class IndexPut2DFloatAccumulateModule(torch.nn.Module):
|
||||
|
@ -421,7 +421,7 @@ class IndexPut2DFloatAccumulateModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexPut2DFloatAccumulateModule())
|
||||
def IndexPut2DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8))
|
||||
module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8))
|
||||
|
||||
|
||||
class IndexPut3DFloatAccumulateModule(torch.nn.Module):
|
||||
|
@ -444,7 +444,7 @@ class IndexPut3DFloatAccumulateModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexPut3DFloatAccumulateModule())
|
||||
def IndexPut3DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )),
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
|
||||
tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
|
@ -471,8 +471,8 @@ class IndexPut1DIntAccumulateModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexPut1DIntAccumulateModule())
|
||||
def IndexPut1DIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10, )), torch.randint(10, (10, )),
|
||||
torch.randint(1000, (10, )))
|
||||
module.forward(tu.randint(10, high=100), tu.randint(10, high=10),
|
||||
tu.randint(10, high=1000))
|
||||
|
||||
|
||||
class IndexPut2DIntAccumulateModule(torch.nn.Module):
|
||||
|
@ -495,8 +495,8 @@ class IndexPut2DIntAccumulateModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexPut2DIntAccumulateModule())
|
||||
def IndexPut2DIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (10, 8)), torch.randint(4, (5, )),
|
||||
torch.randint(1000, (5, 8)))
|
||||
module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4),
|
||||
tu.randint(5, 8, high=1000))
|
||||
|
||||
|
||||
class IndexPut3DIntAccumulateModule(torch.nn.Module):
|
||||
|
@ -519,8 +519,8 @@ class IndexPut3DIntAccumulateModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexPut3DIntAccumulateModule())
|
||||
def IndexPut3DIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (10, 8, 6)), torch.randint(4, (5, )),
|
||||
torch.randint(1000, (5, 8, 6)))
|
||||
module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4),
|
||||
tu.randint(5, 8, 6, high=1000))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -548,7 +548,7 @@ class IndexPutHackedTwin1DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin1DFloatNonAccumulateModule())
|
||||
def IndexPutHackedTwin1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(100), torch.randint(100, (250, )), tu.rand(250))
|
||||
module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250))
|
||||
|
||||
|
||||
class IndexPutHackedTwin2DFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
@ -572,7 +572,7 @@ class IndexPutHackedTwin2DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin2DFloatNonAccumulateModule())
|
||||
def IndexPutHackedTwin2DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8))
|
||||
module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8))
|
||||
|
||||
|
||||
class IndexPutHackedTwin3DFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
@ -596,7 +596,7 @@ class IndexPutHackedTwin3DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin3DFloatNonAccumulateModule())
|
||||
def IndexPutHackedTwin3DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )),
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
|
||||
tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
|
@ -624,8 +624,8 @@ class IndexPutHackedTwin1DIntNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin1DIntNonAccumulateModule())
|
||||
def IndexPutHackedTwin1DIntNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (200, )), torch.randint(100, (300, )),
|
||||
torch.randint(10000, (300, )))
|
||||
module.forward(tu.randint(200, high=1000), tu.randint(300, high=100),
|
||||
tu.randint(300, high=10000))
|
||||
|
||||
|
||||
class IndexPutHackedTwin2DIntNonAccumulateModule(torch.nn.Module):
|
||||
|
@ -649,8 +649,8 @@ class IndexPutHackedTwin2DIntNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin2DIntNonAccumulateModule())
|
||||
def IndexPutHackedTwin2DIntNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (10, 8)), torch.randint(4, (5, )),
|
||||
torch.randint(1000, (5, 8)))
|
||||
module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4),
|
||||
tu.randint(5, 8, high=1000))
|
||||
|
||||
|
||||
class IndexPutHackedTwin3DIntNonAccumulateModule(torch.nn.Module):
|
||||
|
@ -674,8 +674,8 @@ class IndexPutHackedTwin3DIntNonAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin3DIntNonAccumulateModule())
|
||||
def IndexPutHackedTwin3DIntNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (10, 8, 6)), torch.randint(4, (5, )),
|
||||
torch.randint(1000, (5, 8, 6)))
|
||||
module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4),
|
||||
tu.randint(5, 8, 6, high=1000))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -700,7 +700,7 @@ class IndexPutHackedTwin1DFloatAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin1DFloatAccumulateModule())
|
||||
def IndexPutHackedTwin1DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1000), torch.randint(10, (500, )), tu.rand(500))
|
||||
module.forward(tu.rand(1000), tu.randint(500, high=10), tu.rand(500))
|
||||
|
||||
|
||||
class IndexPutHackedTwin2DFloatAccumulateModule(torch.nn.Module):
|
||||
|
@ -722,7 +722,7 @@ class IndexPutHackedTwin2DFloatAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin2DFloatAccumulateModule())
|
||||
def IndexPutHackedTwin2DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8))
|
||||
module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8))
|
||||
|
||||
|
||||
class IndexPutHackedTwin3DFloatAccumulateModule(torch.nn.Module):
|
||||
|
@ -744,7 +744,7 @@ class IndexPutHackedTwin3DFloatAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin3DFloatAccumulateModule())
|
||||
def IndexPutHackedTwin3DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )),
|
||||
module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4),
|
||||
tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
|
@ -770,8 +770,8 @@ class IndexPutHackedTwin1DIntAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin1DIntAccumulateModule())
|
||||
def IndexPutHackedTwin1DIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10, )), torch.randint(10, (10, )),
|
||||
torch.randint(1000, (10, )))
|
||||
module.forward(tu.randint(10, high=100), tu.randint(10, high=10),
|
||||
tu.randint(10, high=1000))
|
||||
|
||||
|
||||
class IndexPutHackedTwin2DIntAccumulateModule(torch.nn.Module):
|
||||
|
@ -793,8 +793,8 @@ class IndexPutHackedTwin2DIntAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin2DIntAccumulateModule())
|
||||
def IndexPutHackedTwin2DIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (10, 8)), torch.randint(4, (5, )),
|
||||
torch.randint(1000, (5, 8)))
|
||||
module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4),
|
||||
tu.randint(5, 8, high=1000))
|
||||
|
||||
|
||||
class IndexPutHackedTwin3DIntAccumulateModule(torch.nn.Module):
|
||||
|
@ -816,5 +816,5 @@ class IndexPutHackedTwin3DIntAccumulateModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: IndexPutHackedTwin3DIntAccumulateModule())
|
||||
def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (10, 8, 6)), torch.randint(4, (5, )),
|
||||
torch.randint(1000, (5, 8, 6)))
|
||||
module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4),
|
||||
tu.randint(5, 8, 6, high=1000))
|
||||
|
|
|
@ -34,7 +34,7 @@ class NllLossModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NllLossModule())
|
||||
def NllLossModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||
|
||||
|
||||
class NllLossModule_mean(torch.nn.Module):
|
||||
|
@ -58,7 +58,7 @@ class NllLossModule_mean(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_mean())
|
||||
def NllLossModule_mean_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||
|
||||
|
||||
class NllLossModule_sum(torch.nn.Module):
|
||||
|
@ -82,7 +82,7 @@ class NllLossModule_sum(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_sum())
|
||||
def NllLossModule_sum_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||
|
||||
|
||||
class NllLossModule_1D(torch.nn.Module):
|
||||
|
@ -106,7 +106,7 @@ class NllLossModule_1D(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_1D())
|
||||
def NllLossModule_1D_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3), torch.randint(0, 3, ()))
|
||||
module.forward(tu.rand(3), tu.randint(high=3))
|
||||
|
||||
|
||||
class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
|
||||
|
@ -131,7 +131,7 @@ class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
|
||||
def NllLossModule_ignore_index_out_of_bounds_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,)))
|
||||
module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
|
||||
|
||||
class NllLossModule_backward(torch.nn.Module):
|
||||
|
||||
|
|
|
@ -489,7 +489,7 @@ class MaxPool2dWithIndicesBackwardStatic4DModule(torch.nn.Module):
|
|||
module_factory=lambda: MaxPool2dWithIndicesBackwardStatic4DModule())
|
||||
def MaxPool2dWithIndicesBackwardStatic4DModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 7, 6), tu.rand(2, 4, 6, 5),
|
||||
torch.randint(16, (2, 4, 7, 6)))
|
||||
tu.randint(2, 4, 7, 6, high=16))
|
||||
|
||||
|
||||
class MaxPool2dWithIndicesBackwardStatic3DModule(torch.nn.Module):
|
||||
|
@ -519,7 +519,7 @@ class MaxPool2dWithIndicesBackwardStatic3DModule(torch.nn.Module):
|
|||
module_factory=lambda: MaxPool2dWithIndicesBackwardStatic3DModule())
|
||||
def MaxPool2dWithIndicesBackwardStatic3DModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 7, 6), tu.rand(4, 6, 5),
|
||||
torch.randint(16, (4, 7, 6)))
|
||||
tu.randint(4, 7, 6, high=16))
|
||||
|
||||
|
||||
class MaxPool2dWithIndicesBackwardDynamic4DModule(torch.nn.Module):
|
||||
|
@ -549,7 +549,7 @@ class MaxPool2dWithIndicesBackwardDynamic4DModule(torch.nn.Module):
|
|||
module_factory=lambda: MaxPool2dWithIndicesBackwardDynamic4DModule())
|
||||
def MaxPool2dWithIndicesBackwardDynamic4DModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 7, 6), tu.rand(2, 4, 6, 5),
|
||||
torch.randint(16, (2, 4, 7, 6)))
|
||||
tu.randint(2, 4, 7, 6, high=16))
|
||||
|
||||
|
||||
class MaxPool2dWithIndicesBackwardDynamic3DModule(torch.nn.Module):
|
||||
|
@ -579,7 +579,7 @@ class MaxPool2dWithIndicesBackwardDynamic3DModule(torch.nn.Module):
|
|||
module_factory=lambda: MaxPool2dWithIndicesBackwardDynamic3DModule())
|
||||
def MaxPool2dWithIndicesBackwardDynamic3DModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 7, 6), tu.rand(2, 6, 5),
|
||||
torch.randint(16, (2, 7, 6)))
|
||||
tu.randint(2, 7, 6, high=16))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -632,7 +632,7 @@ class AvgPool2dIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: AvgPool2dIntModule())
|
||||
def AvgPool2dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (2, 4, 20, 20)))
|
||||
module.forward(tu.randint(2, 4, 20, 20, high=100))
|
||||
|
||||
|
||||
class AvgPool2dStaticModule(torch.nn.Module):
|
||||
|
|
|
@ -140,7 +140,7 @@ class ReduceSumUnsignedIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceSumUnsignedIntModule())
|
||||
def ReduceSumUnsignedIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(0, 100, (3, 4, 5)))
|
||||
module.forward(tu.randint(3, 4, 5, low=0, high=100))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -159,7 +159,7 @@ class ReduceSumSignedIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceSumSignedIntModule())
|
||||
def ReduceSumSignedIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-100, 100, (3, 4, 5)))
|
||||
module.forward(tu.randint(3, 4, 5, low=-100, high=100))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -178,7 +178,7 @@ class ReduceSumDtypeIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceSumDtypeIntModule())
|
||||
def ReduceSumDtypeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 4, 5)).to(torch.int32))
|
||||
module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -197,7 +197,7 @@ class ReduceSumDimIntListIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceSumDimIntListIntModule())
|
||||
def ReduceSumDimIntListIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 4, 5)))
|
||||
module.forward(tu.randint(3, 4, 5, high=100))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -216,7 +216,7 @@ class ReduceSumDimIntListDtypeIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceSumDimIntListDtypeIntModule())
|
||||
def ReduceSumDimIntListDtypeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 4, 5)).to(torch.int32))
|
||||
module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -235,7 +235,7 @@ class ReduceSumDimIntListKeepDimIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimIntModule())
|
||||
def ReduceSumDimIntListKeepDimIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 4, 5)))
|
||||
module.forward(tu.randint(3, 4, 5, high=100))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -383,7 +383,7 @@ class ReduceMaxSignedIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceMaxSignedIntModule())
|
||||
def ReduceMaxSignedIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-100, 100, (3, 4, 5)))
|
||||
module.forward(tu.randint(3, 4, 5, low=-100, high=100))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -401,7 +401,7 @@ class ReduceMaxUnsignedIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceMaxUnsignedIntModule())
|
||||
def ReduceMaxUnsignedIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 4, 5)))
|
||||
module.forward(tu.randint(3, 4, 5, high=100))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ class AddIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: AddIntModule())
|
||||
def AddIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -52,7 +52,7 @@ class SubIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SubIntModule())
|
||||
def SubIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -98,7 +98,7 @@ class MulIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: MulIntModule())
|
||||
def MulIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -172,7 +172,7 @@ class SqrtIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SqrtIntModule())
|
||||
def SqrtIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, ()))
|
||||
module.forward(tu.randint(high=10))
|
||||
|
||||
|
||||
class SqrtIntConstantModule(torch.nn.Module):
|
||||
|
@ -273,7 +273,7 @@ class BoolIntFalseModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: BoolIntFalseModule())
|
||||
def BoolIntFalseModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 100, ()))
|
||||
module.forward(tu.randint(low=1, high=100))
|
||||
|
||||
|
||||
class BoolIntTrueModule(torch.nn.Module):
|
||||
|
@ -292,7 +292,7 @@ class BoolIntTrueModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: BoolIntTrueModule())
|
||||
def BoolIntTrueModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 100, ()))
|
||||
module.forward(tu.randint(low=1, high=100))
|
||||
|
||||
|
||||
class BoolIntConstantModule(torch.nn.Module):
|
||||
|
|
|
@ -29,7 +29,7 @@ class NeIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NeIntModule())
|
||||
def NeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -52,7 +52,7 @@ class EqIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: EqIntModule())
|
||||
def EqIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -75,7 +75,7 @@ class GtIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: GtIntModule())
|
||||
def GtIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -98,7 +98,7 @@ class GeIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: GeIntModule())
|
||||
def GeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ()))
|
||||
module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -144,7 +144,7 @@ class GeFloatIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: GeFloatIntModule())
|
||||
def GeFloatIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(()).double(), torch.randint(-100, 100, ()))
|
||||
module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -167,7 +167,7 @@ class NeFloatIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NeFloatIntModule())
|
||||
def NeFloatIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(()).double(), torch.randint(-100, 100, ()))
|
||||
module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -190,4 +190,4 @@ class GtFloatIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: GtFloatIntModule())
|
||||
def GtFloatIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(()).double(), torch.randint(-100, 100, ()))
|
||||
module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100))
|
||||
|
|
|
@ -229,7 +229,7 @@ class SelectIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SelectIntModule())
|
||||
def SelectIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (5,5)))
|
||||
module.forward(tu.randint(5,5, high=10))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ class TableBatchEmbeddingModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TableBatchEmbeddingModule())
|
||||
def TableBatchEmbeddingModule_basic(module, tu: TestUtils):
|
||||
indices = torch.randint(0, NUM_EMBEDDINGS, (NUM_TABLES * BATCH_SIZE * BAG_SIZE,))
|
||||
indices = tu.randint(NUM_TABLES * BATCH_SIZE * BAG_SIZE, high=NUM_EMBEDDINGS)
|
||||
offsets = torch.cumsum(
|
||||
torch.tensor([0] + [BAG_SIZE for _ in range(BATCH_SIZE - 1)], dtype=torch.int64), 0)
|
||||
module.forward(indices, offsets)
|
||||
|
|
|
@ -27,7 +27,7 @@ class Threshold1dIntI32Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Threshold1dIntI32Module())
|
||||
def Threshold1dIntI32Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4,), dtype=torch.int32))
|
||||
module.forward(tu.randint(4, high=10).to(torch.int32))
|
||||
|
||||
|
||||
class Threshold1dIntModule(torch.nn.Module):
|
||||
|
@ -45,7 +45,7 @@ class Threshold1dIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Threshold1dIntModule())
|
||||
def Threshold1dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4,)))
|
||||
module.forward(tu.randint(4, high=10))
|
||||
|
||||
|
||||
class Threshold2dIntModule(torch.nn.Module):
|
||||
|
@ -63,7 +63,7 @@ class Threshold2dIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Threshold2dIntModule())
|
||||
def Threshold2dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4, 5)))
|
||||
module.forward(tu.randint(4, 5, high=10))
|
||||
|
||||
|
||||
class Threshold3dIntModule(torch.nn.Module):
|
||||
|
@ -81,7 +81,7 @@ class Threshold3dIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Threshold3dIntModule())
|
||||
def Threshold3dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4, 5, 6)))
|
||||
module.forward(tu.randint(4, 5, 6, high=10))
|
||||
|
||||
|
||||
class Threshold1dFloatModule(torch.nn.Module):
|
||||
|
@ -154,7 +154,7 @@ class ThresholdBackward1dIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward1dIntModule())
|
||||
def ThresholdBackward1dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4,)), torch.randint(8, (4,)))
|
||||
module.forward(tu.randint(4, high=10), tu.randint(4, high=8))
|
||||
|
||||
|
||||
class ThresholdBackward2dIntModule(torch.nn.Module):
|
||||
|
@ -173,7 +173,7 @@ class ThresholdBackward2dIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward2dIntModule())
|
||||
def ThresholdBackward2dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4, 5)), torch.randint(8, (4, 5)))
|
||||
module.forward(tu.randint(4, 5, high=10), tu.randint(4, 5, high=8))
|
||||
|
||||
|
||||
class ThresholdBackward3dIntModule(torch.nn.Module):
|
||||
|
@ -192,7 +192,7 @@ class ThresholdBackward3dIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward3dIntModule())
|
||||
def ThresholdBackward3dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4, 5, 6)), torch.randint(8, (4, 5, 6)))
|
||||
module.forward(tu.randint(4, 5, 6, high=10), tu.randint(4, 5, 6, high=8))
|
||||
|
||||
|
||||
class ThresholdBackward1dFloatModule(torch.nn.Module):
|
||||
|
@ -268,7 +268,7 @@ class ThresholdBackward1dMixedModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule())
|
||||
def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4), torch.randint(10, (4,)))
|
||||
module.forward(torch.randn(4), tu.randint(4, high=10))
|
||||
|
||||
|
||||
class ThresholdBackward2dMixedModule(torch.nn.Module):
|
||||
|
@ -287,7 +287,7 @@ class ThresholdBackward2dMixedModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule())
|
||||
def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(20, (4, 5)), torch.randn(4, 5))
|
||||
module.forward(tu.randint(4, 5, high=20), torch.randn(4, 5))
|
||||
|
||||
|
||||
class ThresholdBackward3dMixedModule(torch.nn.Module):
|
||||
|
@ -306,4 +306,4 @@ class ThresholdBackward3dMixedModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule())
|
||||
def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.randint(10, (4, 5, 6)))
|
||||
module.forward(torch.randn(4, 5, 6), tu.randint(4, 5, 6, high=10))
|
||||
|
|
|
@ -57,7 +57,7 @@ class TypeConversionI32ToI64Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TypeConversionI32ToI64Module())
|
||||
def TypeConversionI32ToI64Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(5, [2, 3]).type(torch.int32))
|
||||
module.forward(tu.randint(2, 3, high=5).type(torch.int32))
|
||||
|
||||
|
||||
class TypeConversionI64ToI32Module(torch.nn.Module):
|
||||
|
@ -73,7 +73,7 @@ class TypeConversionI64ToI32Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TypeConversionI64ToI32Module())
|
||||
def TypeConversionI64ToI32Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(5, [2, 3]))
|
||||
module.forward(tu.randint(2, 3, high=5))
|
||||
|
||||
|
||||
class TypeConversionI1ToI32Module(torch.nn.Module):
|
||||
|
@ -89,7 +89,7 @@ class TypeConversionI1ToI32Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TypeConversionI1ToI32Module())
|
||||
def TypeConversionI1ToI32Module_basic(module, tu: TestUtils):
|
||||
tensor = torch.randint(0, 2, (3, 4), dtype=torch.bool)
|
||||
tensor = tu.randint(3, 4, low=0, high=2).to(torch.bool)
|
||||
module.forward(tensor)
|
||||
|
||||
|
||||
|
@ -106,7 +106,7 @@ class TypeConversionI1ToI64Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TypeConversionI1ToI64Module())
|
||||
def TypeConversionI1ToI64Module_basic(module, tu: TestUtils):
|
||||
tensor = torch.randint(0, 2, (3, 4), dtype=torch.bool)
|
||||
tensor = tu.randint(3, 4, low=0, high=2).to(torch.bool)
|
||||
module.forward(tensor)
|
||||
|
||||
|
||||
|
@ -123,7 +123,7 @@ class TypeConversionI1ToF32Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TypeConversionI1ToF32Module())
|
||||
def TypeConversionI1ToF32Module_basic(module, tu: TestUtils):
|
||||
tensor = torch.randint(0, 2, (3, 4), dtype=torch.bool)
|
||||
tensor = tu.randint(3, 4, low=0, high=2).to(torch.bool)
|
||||
module.forward(tensor)
|
||||
|
||||
|
||||
|
@ -140,7 +140,7 @@ class TypeConversionI1ToF64Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TypeConversionI1ToF64Module())
|
||||
def TypeConversionI1ToF64Module_basic(module, tu: TestUtils):
|
||||
tensor = torch.randint(0, 2, (3, 4), dtype=torch.bool)
|
||||
tensor = tu.randint(3, 4, low=0, high=2).to(torch.bool)
|
||||
module.forward(tensor)
|
||||
|
||||
|
||||
|
|
|
@ -30,8 +30,8 @@ class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module):
|
|||
module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule())
|
||||
def TypePromotionSameCategoryDifferentWidthModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.randint(10, [4]).type(torch.int32),
|
||||
torch.randint(10, [4]))
|
||||
tu.randint(4, high=10).type(torch.int32),
|
||||
tu.randint(4, high=10))
|
||||
|
||||
|
||||
class TypePromotionDifferentCategoryModule(torch.nn.Module):
|
||||
|
@ -51,7 +51,7 @@ class TypePromotionDifferentCategoryModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: TypePromotionDifferentCategoryModule())
|
||||
def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, [4]), torch.randn(4))
|
||||
module.forward(tu.randint(4, high=10), torch.randn(4))
|
||||
|
||||
|
||||
class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module):
|
||||
|
@ -91,7 +91,7 @@ class TypePromotionZeroRankHigherCategoryModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: TypePromotionZeroRankHigherCategoryModule())
|
||||
def TypePromotionZeroRankHigherCategoryModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, [4]), tu.rand())
|
||||
module.forward(tu.randint(4, high=10), tu.rand())
|
||||
|
||||
|
||||
class TypePromotionAlphaWiderModule(torch.nn.Module):
|
||||
|
|
|
@ -182,6 +182,9 @@ class TestUtils:
|
|||
def rand(self, *sizes, low=0.0, high=1.0):
|
||||
return torch.empty(sizes).uniform_(low, high)
|
||||
|
||||
def randint(self, *sizes, low=0, high=10):
|
||||
return torch.randint(low, high, sizes)
|
||||
|
||||
def nans(self, *sizes):
|
||||
vals = torch.empty(sizes)
|
||||
vals[...] = torch.nan
|
||||
|
|
Loading…
Reference in New Issue