diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 6c155fa1e..3ec6130d5 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -80,7 +80,7 @@ class IsFloatingPointInt(torch.nn.Module): @register_test_case(module_factory=lambda: IsFloatingPointInt()) def IsFloatingPointInt_False(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 3))) + module.forward(tu.randint(3, 3, high=100)) # ============================================================================== @@ -736,7 +736,7 @@ class EmbeddingModuleI64(torch.nn.Module): @register_test_case(module_factory=lambda: EmbeddingModuleI64()) def EmbeddingModuleI64_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 3))) + module.forward(tu.randint(3, 3, high=100)) # ============================================================================== @@ -762,7 +762,7 @@ class EmbeddingModuleI32(torch.nn.Module): @register_test_case(module_factory=lambda: EmbeddingModuleI32()) def EmbeddingModuleI32_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 3)).to(torch.int32)) + module.forward(tu.randint(3, 3, high=100).to(torch.int32)) # ============================================================================== @@ -787,7 +787,7 @@ class EmbeddingModuleI32Static(torch.nn.Module): @register_test_case(module_factory=lambda: EmbeddingModuleI32Static()) def EmbeddingModuleI32Static_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 3)).to(torch.int32)) + module.forward(tu.randint(3, 3, high=100).to(torch.int32)) # ============================================================================== @@ -813,7 +813,7 @@ class EmbeddingModule1DIndices(torch.nn.Module): @register_test_case(module_factory=lambda: EmbeddingModule1DIndices()) def EmbeddingModule1DIndices_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3,)).to(torch.int32)) + module.forward(tu.randint(3, high=100).to(torch.int32)) # ============================================================================== @@ -1331,7 +1331,7 @@ class DropoutEvalIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: DropoutEvalIntModule()) def DropoutEvalIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(5, 10, (3, 4))) + module.forward(tu.randint(3, 4, low=5, high=10)) # ============================================================================== @@ -1420,7 +1420,7 @@ class NumelZeroRankModule(torch.nn.Module): @register_test_case(module_factory=lambda: NumelZeroRankModule()) def NumelZeroRankModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, [])) + module.forward(tu.randint(high=10)) # ============================================================================== @@ -1646,7 +1646,7 @@ class ReturnTwoTensorF32I64(torch.nn.Module): @register_test_case(module_factory=lambda: ReturnTwoTensorF32I64()) def ReturnTwoTensorF32I64_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3), torch.randint(5, (2, 3))) + module.forward(tu.rand(2, 3), tu.randint(2, 3, high=5)) # ============================================================================== @@ -1669,7 +1669,7 @@ class IndexTensorModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexTensorModule()) def IndexTensorModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5), torch.randint(4, (2, 3))) + module.forward(tu.rand(5), tu.randint(2, 3, high=4)) # ============================================================================== @@ -1692,7 +1692,7 @@ class IndexTensorModule3dInput(torch.nn.Module): @register_test_case(module_factory=lambda: IndexTensorModule3dInput()) def IndexTensorModule3dInput_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3), torch.randint(3, (2, 3))) + module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3)) # ============================================================================== @@ -1715,7 +1715,7 @@ class IndexTensorSelectDimModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexTensorSelectDimModule()) def IndexTensorSelectDimModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 4, 6), torch.randint(3, (2, 3))) + module.forward(tu.rand(2, 4, 6), tu.randint(2, 3, high=3)) # ============================================================================== @@ -1738,7 +1738,7 @@ class IndexTensorMultiInput(torch.nn.Module): @register_test_case(module_factory=lambda: IndexTensorMultiInput()) def IndexTensorMultiInput_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3), torch.randint(3, (3, 3)), torch.randint(3, (3,))) + module.forward(tu.rand(5, 4, 3), tu.randint(3, 3, high=3), tu.randint(3, high=3)) # ============================================================================== @@ -1762,7 +1762,7 @@ class IndexTensorMultiInputOneDim(torch.nn.Module): @register_test_case(module_factory=lambda: IndexTensorMultiInputOneDim()) def IndexTensorMultiInputOneDim_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)), torch.randint(3, (3,))) + module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4), tu.randint(3, high=3)) # ============================================================================== @@ -1791,8 +1791,8 @@ class IndexTensorMultiInputContiguousOneDimDynamic(torch.nn.Module): @register_test_case( module_factory=lambda: IndexTensorMultiInputContiguousOneDimDynamic()) def IndexTensorMultiInputContiguousOneDimDynamic_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)), - torch.randint(3, (3, ))) + module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4), + tu.randint(3, high=3)) # ============================================================================== @@ -1822,8 +1822,8 @@ class IndexTensorMultiInputNonContiguousOneDimDynamic(torch.nn.Module): module_factory=lambda: IndexTensorMultiInputNonContiguousOneDimDynamic()) def IndexTensorMultiInputNonContiguousOneDimDynamic_basic( module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)), - torch.randint(3, (3, ))) + module.forward(tu.rand(5, 4, 3), tu.randint(6, 1, high=4), + tu.randint(3, high=3)) # ============================================================================== @@ -1852,8 +1852,8 @@ class IndexTensorMultiInputNonContiguousDynamic(torch.nn.Module): @register_test_case( module_factory=lambda: IndexTensorMultiInputNonContiguousDynamic()) def IndexTensorMultiInputNonContiguousDynamic_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3), torch.randint(2, (6, 2)), - torch.randint(3, (2, ))) + module.forward(tu.rand(5, 4, 3), tu.randint(6, 2, high=2), + tu.randint(2, high=3)) # ============================================================================== @@ -1880,8 +1880,8 @@ class IndexTensorMultiInputNonContiguousMultipleStaticDims(torch.nn.Module): IndexTensorMultiInputNonContiguousMultipleStaticDims()) def IndexTensorMultiInputNonContiguousMultipleStaticDims_basic( module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (4, 1)), - torch.randint(1, (1, 3)), torch.randint(1, (4, 3))) + module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 1, high=3), + tu.randint(1, 3, high=1), tu.randint(4, 3, high=1)) # ============================================================================== @@ -1905,7 +1905,7 @@ class IndexTensorMultiInputNonContiguous(torch.nn.Module): @register_test_case(module_factory=lambda: IndexTensorMultiInputNonContiguous()) def IndexTensorMultiInputNonContiguous_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (4, 2)), torch.randint(1, (4, 2,))) + module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 2, high=3), tu.randint(4, 2, high=1)) # ============================================================================== @@ -1931,9 +1931,9 @@ class IndexTensorMultiInputThreeIndexers(torch.nn.Module): @register_test_case(module_factory=lambda: IndexTensorMultiInputThreeIndexers()) def IndexTensorMultiInputThreeIndexers_basic(module, tu: TestUtils): module.forward(tu.rand(1, 2, 4, 4, 5, 3), - torch.randint(3, (8, 4, 2,)), - torch.randint(4, (8, 1, 1,)), - torch.randint(2, (4, 2,))) + tu.randint(8, 4, 2, high=3), + tu.randint(8, 1, 1, high=4), + tu.randint(4, 2, high=2)) # ============================================================================== @@ -1957,7 +1957,7 @@ class IndexTensorMultiInputContiguousCenter(torch.nn.Module): @register_test_case(module_factory=lambda: IndexTensorMultiInputContiguousCenter()) def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (2, 2)), torch.randint(2, [2])) + module.forward(tu.rand(5, 4, 3, 2), tu.randint(2, 2, high=3), tu.randint(2, high=2)) # ============================================================================== @@ -2089,7 +2089,7 @@ class HardTanhIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: HardTanhIntModule()) def HardTanhIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-5, 5, (100, 100))) + module.forward(tu.randint(100, 100, low=-5, high=5)) # ============================================================================== @@ -2111,7 +2111,7 @@ class BincountModule(torch.nn.Module): @register_test_case(module_factory=lambda: BincountModule()) def BincountModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (1000, ))) + module.forward(tu.randint(1000, high=10)) # ============================================================================== @@ -2133,7 +2133,7 @@ class BincountStaticSizeModule(torch.nn.Module): @register_test_case(module_factory=lambda: BincountStaticSizeModule()) def BincountStaticSizeModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (200, ))) + module.forward(tu.randint(200, high=100)) # ============================================================================== @@ -2155,7 +2155,7 @@ class BincountMinlengthModule(torch.nn.Module): @register_test_case(module_factory=lambda: BincountMinlengthModule()) def BincountMinlengthModule_basic(module, tu: TestUtils): - module.forward(torch.randint(5, (20, ))) + module.forward(tu.randint(20, high=5)) # ============================================================================== @@ -2198,8 +2198,8 @@ class ExpandAsIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ExpandAsIntModule()) def ExpandAsIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (1, 1, 1)), - torch.randint(200, (4, 5, 6))) + module.forward(tu.randint(1, 1, 1, high=100), + tu.randint(4, 5, 6, high=200)) # ============================================================================== @@ -2262,7 +2262,7 @@ class CopyWithDifferentDTypesModule(torch.nn.Module): @register_test_case(module_factory=lambda: CopyWithDifferentDTypesModule()) def CopyWithDifferentDTypesModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 2, 4)), tu.rand(3, 2, 4)) + module.forward(tu.randint(3, 2, 4, high=100), tu.rand(3, 2, 4)) class CopyWithDifferentDTypesAndSizesModule(torch.nn.Module): @@ -2283,7 +2283,7 @@ class CopyWithDifferentDTypesAndSizesModule(torch.nn.Module): @register_test_case( module_factory=lambda: CopyWithDifferentDTypesAndSizesModule()) def CopyWithDifferentDTypesAndSizesModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 2, 4), torch.randint(1000, (3, 2, 1))) + module.forward(tu.rand(3, 2, 4), tu.randint(3, 2, 1, high=1000)) # ============================================================================== @@ -2451,7 +2451,7 @@ class ScalarImplicitIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ScalarImplicitIntModule()) def ScalarImplicitIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100, ())) + module.forward(tu.randint(low=-100, high=100)) # ============================================================================== @@ -2517,7 +2517,7 @@ class BaddbmmDifferentDtypesModule(torch.nn.Module): @register_test_case(module_factory=lambda: BaddbmmDifferentDtypesModule()) def BaddbmmDifferentDtypesModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 4, 5)), tu.rand(3, 4, 6), + module.forward(tu.randint(3, 4, 5, high=10), tu.rand(3, 4, 6), tu.rand(3, 6, 5)) diff --git a/python/torch_mlir_e2e_test/test_suite/cast.py b/python/torch_mlir_e2e_test/test_suite/cast.py index 834066949..bbf732329 100644 --- a/python/torch_mlir_e2e_test/test_suite/cast.py +++ b/python/torch_mlir_e2e_test/test_suite/cast.py @@ -26,7 +26,7 @@ class TensorToIntZeroRank(torch.nn.Module): @register_test_case(module_factory=lambda: TensorToIntZeroRank()) def TensorToIntZeroRank_basic(module, tu: TestUtils): - module.forward(torch.randint(10, ())) + module.forward(tu.randint(high=10)) # ============================================================================== @@ -45,7 +45,7 @@ class TensorToInt(torch.nn.Module): @register_test_case(module_factory=lambda: TensorToInt()) def TensorToInt_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (1, 1))) + module.forward(tu.randint(1, 1, high=10)) # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index 809d37f2e..149122540 100644 --- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -346,7 +346,7 @@ class EmptyLikeIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: EmptyLikeIntModule()) def EmptyLikeModule_int(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5))) + module.forward(tu.randint(3, 5, high=10)) class EmptyLikeMemoryFormatModule(torch.nn.Module): @@ -446,7 +446,7 @@ class ZerosLikeIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ZerosLikeIntModule()) def ZerosLikeModule_int(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5))) + module.forward(tu.randint(3, 5, high=10)) class ZerosLikeFloatModule(torch.nn.Module): @@ -525,7 +525,7 @@ class OnesLikeIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: OnesLikeIntModule()) def OnesLikeModule_int(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5))) + module.forward(tu.randint(3, 5, high=10)) class OnesLikeFloatModule(torch.nn.Module): @@ -702,7 +702,7 @@ class NewZerosModuleFloat2D(torch.nn.Module): @register_test_case(module_factory=lambda: NewZerosModuleFloat2D()) def NewZerosModuleFloat2D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3, 4))) + module.forward(tu.randint(2, 3, 4, high=10)) class NewZerosModuleFloat3D(torch.nn.Module): @@ -721,7 +721,7 @@ class NewZerosModuleFloat3D(torch.nn.Module): @register_test_case(module_factory=lambda: NewZerosModuleFloat3D()) def NewZerosModuleFloat3D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3))) + module.forward(tu.randint(2, 3, high=10)) class NewZerosModuleFalsePinMemory(torch.nn.Module): @@ -742,7 +742,7 @@ class NewZerosModuleFalsePinMemory(torch.nn.Module): @register_test_case(module_factory=lambda: NewZerosModuleFalsePinMemory()) def NewZerosModuleFalsePinMemory_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3))) + module.forward(tu.randint(2, 3, high=10)) # ============================================================================== @@ -821,7 +821,7 @@ class NewOnesModuleFloat2D(torch.nn.Module): @register_test_case(module_factory=lambda: NewOnesModuleFloat2D()) def NewOnesModuleFloat2D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3, 4))) + module.forward(tu.randint(2, 3, 4, high=10)) class NewOnesModuleFloat3D(torch.nn.Module): @@ -840,7 +840,7 @@ class NewOnesModuleFloat3D(torch.nn.Module): @register_test_case(module_factory=lambda: NewOnesModuleFloat3D()) def NewOnesModuleFloat3D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3))) + module.forward(tu.randint(2, 3, high=10)) class NewOnesModuleFalsePinMemory(torch.nn.Module): @@ -861,7 +861,7 @@ class NewOnesModuleFalsePinMemory(torch.nn.Module): @register_test_case(module_factory=lambda: NewOnesModuleFalsePinMemory()) def NewOnesModuleFalsePinMemory_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3))) + module.forward(tu.randint(2, 3, high=10)) # ============================================================================== @@ -1016,7 +1016,7 @@ class FullLikeModuleInt2D(torch.nn.Module): @register_test_case(module_factory=lambda: FullLikeModuleInt2D()) def FullLikeModuleInt2D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (4, 5))) + module.forward(tu.randint(4, 5, high=10)) class FullLikeModuleInt3D(torch.nn.Module): @@ -1035,7 +1035,7 @@ class FullLikeModuleInt3D(torch.nn.Module): @register_test_case(module_factory=lambda: FullLikeModuleInt3D()) def FullLikeModuleInt3D_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, 4, 5)).to(torch.int32)) + module.forward(tu.randint(10, 4, 5, high=100).to(torch.int32)) class FullLikeModuleInt2DStatic(torch.nn.Module): @@ -1054,7 +1054,7 @@ class FullLikeModuleInt2DStatic(torch.nn.Module): @register_test_case(module_factory=lambda: FullLikeModuleInt2DStatic()) def FullLikeModuleInt2DStatic_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (4, 5))) + module.forward(tu.randint(4, 5, high=10)) class FullLikeModuleFloat2D(torch.nn.Module): @@ -1133,7 +1133,7 @@ class FullLikeModuleFalsePinMemory(torch.nn.Module): @register_test_case(module_factory=lambda: FullLikeModuleFalsePinMemory()) def FullLikeModuleFalsePinMemory_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, 4))) + module.forward(tu.randint(10, 4, high=100)) # ============================================================================== @@ -1174,7 +1174,7 @@ class ZeroInt32Module(torch.nn.Module): @register_test_case(module_factory=lambda: ZeroInt32Module()) def ZeroInt32Module_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, 4), dtype=torch.int32)) + module.forward(tu.randint(10, 4, high=100).to(dtype=torch.int32)) class ZeroInt64Module(torch.nn.Module): @@ -1193,7 +1193,7 @@ class ZeroInt64Module(torch.nn.Module): @register_test_case(module_factory=lambda: ZeroInt64Module()) def ZeroInt64Module_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, 4))) + module.forward(tu.randint(10, 4, high=100)) # ============================================================================== @@ -1274,7 +1274,7 @@ class NewEmptyModuleFloat2D(torch.nn.Module): @register_test_case(module_factory=lambda: NewEmptyModuleFloat2D()) def NewEmptyModuleFloat2D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3, 4))) + module.forward(tu.randint(2, 3, 4, high=10)) class NewEmptyModuleFloat3D(torch.nn.Module): @@ -1294,7 +1294,7 @@ class NewEmptyModuleFloat3D(torch.nn.Module): @register_test_case(module_factory=lambda: NewEmptyModuleFloat3D()) def NewEmptyModuleFloat3D_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3))) + module.forward(tu.randint(2, 3, high=10)) class NewEmptyModuleFalsePinMemory(torch.nn.Module): @@ -1315,7 +1315,7 @@ class NewEmptyModuleFalsePinMemory(torch.nn.Module): @register_test_case(module_factory=lambda: NewEmptyModuleFalsePinMemory()) def NewEmptyModuleFalsePinMemory_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3))) + module.forward(tu.randint(2, 3, high=10)) class NewEmptyModuleNonDefaultFloatDtype(torch.nn.Module): @@ -1354,7 +1354,7 @@ class NewEmptyModuleNonDefaultIntDtype(torch.nn.Module): @register_test_case(module_factory=lambda: NewEmptyModuleNonDefaultIntDtype()) def NewEmptyModuleNonDefaultIntDtype_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3)).to(torch.int32)) + module.forward(tu.randint(2, 3, high=10).to(torch.int32)) class NewEmptyModuleLayoutIntDtype(torch.nn.Module): @@ -1373,7 +1373,7 @@ class NewEmptyModuleLayoutIntDtype(torch.nn.Module): @register_test_case(module_factory=lambda: NewEmptyModuleLayoutIntDtype()) def NewEmptyModuleLayoutIntDtype_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3)).to(torch.int32)) + module.forward(tu.randint(2, 3, high=10).to(torch.int32)) # ============================================================================== @@ -1397,7 +1397,7 @@ class MaskedFillScalarDefaultModule(torch.nn.Module): @register_test_case(module_factory=lambda: MaskedFillScalarDefaultModule()) def MaskedFillScalarDefaultModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3), - torch.randint(0, 2, (2, 3)).to(dtype=torch.bool)) + tu.randint(2, 3, high=2).to(dtype=torch.bool)) class MaskedFillScalarIntValueModule(torch.nn.Module): @@ -1418,7 +1418,7 @@ class MaskedFillScalarIntValueModule(torch.nn.Module): @register_test_case(module_factory=lambda: MaskedFillScalarIntValueModule()) def MaskedFillScalarIntValueModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3), - torch.randint(0, 2, (2, 3)).to(dtype=torch.bool)) + tu.randint(2, 3, high=2).to(dtype=torch.bool)) class MaskedFillScalarFloatValueModule(torch.nn.Module): @@ -1438,8 +1438,8 @@ class MaskedFillScalarFloatValueModule(torch.nn.Module): @register_test_case(module_factory=lambda: MaskedFillScalarFloatValueModule()) def MaskedFillScalarFloatValueModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 10, (2, 3)), - torch.randint(0, 2, (2, 3)).to(dtype=torch.bool)) + module.forward(tu.randint(2, 3, low=-10, high=10), + tu.randint(2, 3, high=2).to(dtype=torch.bool)) class MaskedFillTensorFloatValueModule(torch.nn.Module): @@ -1460,5 +1460,5 @@ class MaskedFillTensorFloatValueModule(torch.nn.Module): @register_test_case(module_factory=lambda: MaskedFillTensorFloatValueModule()) def MaskedFillTensorFloatValueModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 10, (2, 3)), - torch.randint(0, 2, (2, 3)).to(dtype=torch.bool), tu.rand()) + module.forward(tu.randint(2, 3, low=-10, high=10), + tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.rand()) diff --git a/python/torch_mlir_e2e_test/test_suite/control_flow.py b/python/torch_mlir_e2e_test/test_suite/control_flow.py index df1912e3e..6893f7137 100644 --- a/python/torch_mlir_e2e_test/test_suite/control_flow.py +++ b/python/torch_mlir_e2e_test/test_suite/control_flow.py @@ -32,7 +32,7 @@ class TorchPrimLoopForLikeModule(torch.nn.Module): @register_test_case(module_factory=lambda: TorchPrimLoopForLikeModule()) def TorchPrimLoopForLikeModule_basic(module, tu: TestUtils): - module.forward(torch.randint(0, 10, (6, 8))) + module.forward(tu.randint(6, 8, high=10)) # ============================================================================== class TorchPrimLoopWhileLikeModule(torch.nn.Module): @@ -54,4 +54,4 @@ class TorchPrimLoopWhileLikeModule(torch.nn.Module): @register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeModule()) def TorchPrimLoopWhileLikeModule_basic(module, tu: TestUtils): - module.forward(torch.randint(0, 10, (6, 8))) + module.forward(tu.randint(6, 8, high=10)) diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index fa45fa19f..15912be5a 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -57,7 +57,7 @@ class ElementwiseUnaryIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseUnaryIntModule()) def ElementwiseUnaryIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -428,7 +428,7 @@ class ElementwiseSigmoidIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseSigmoidIntModule()) def ElementwiseSigmoidIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 5), dtype=torch.int32)) + module.forward(tu.randint(3, 5, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -474,7 +474,7 @@ class ElementwiseMinimumIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseMinimumIntModule()) def ElementwiseMinimumIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5)), torch.randint(10, (3, 5))) + module.forward(tu.randint(3, 5, high=10), tu.randint(3, 5, high=10)) # ============================================================================== @@ -520,7 +520,7 @@ class ElementwiseMaximumIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseMaximumIntModule()) def ElementwiseMaximumIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5)), torch.randint(10, (3, 5))) + module.forward(tu.randint(3, 5, high=10), tu.randint(3, 5, high=10)) # ============================================================================== @@ -663,7 +663,7 @@ class RsubIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: RsubIntModule()) def RsubIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 4))) + module.forward(tu.randint(3, 4, high=100)) # ============================================================================== @@ -685,7 +685,7 @@ class RsubIntModule_noalpha(torch.nn.Module): @register_test_case(module_factory=lambda: RsubIntModule_noalpha()) def RsubIntModule_noalpha_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 4))) + module.forward(tu.randint(3, 4, high=100)) # ============================================================================== @@ -707,7 +707,7 @@ class ElementwiseMulScalarIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseMulScalarIntModule()) def ElementwiseMulScalarModule_int(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 4))) + module.forward(tu.randint(3, 4, high=10)) # ============================================================================== @@ -751,7 +751,7 @@ class ElementwiseMulScalarModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseMulScalarModule()) def ElementwiseMulScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, high=10).to(torch.int32)) # ============================================================================== @@ -798,7 +798,7 @@ class ElementwiseMulTensorIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseMulTensorIntModule()) def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils): module.forward( - torch.randint(10, [4]).type(torch.int32), torch.randint(10, [4])) + tu.randint(4, high=10).type(torch.int32), tu.randint(4, high=10)) # ============================================================================== @@ -845,7 +845,7 @@ class ElementwiseAtan2TensorIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntModule()) def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils): module.forward( - torch.randint(1, 10, [4]).type(torch.int32), torch.randint(1, 10, [4])) + tu.randint(4, low=1, high=10).type(torch.int32), tu.randint(4, low=1, high=10)) # ============================================================================== @@ -868,8 +868,8 @@ class ElementwiseAtan2FloatIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntModule()) def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, [4, 4], dtype=torch.int32), - tu.rand(4, 4).double()) + module.forward(tu.randint(4, 4, low=1, high=10).to(torch.int32), + tu.rand(4, 4).double()) # ============================================================================== @@ -913,7 +913,7 @@ class ElementwiseLogIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseLogIntModule()) def ElementwiseLogIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -978,7 +978,7 @@ class ElementwiseErfIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseErfIntModule()) def ElementwiseErfIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1022,7 +1022,7 @@ class ElementwiseSqrtIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseSqrtIntModule()) def ElementwiseSqrtIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1170,7 +1170,7 @@ class ElementwiseLog2IntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseLog2IntModule()) def ElementwiseLog2IntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1214,7 +1214,7 @@ class ElementwiseRsqrtIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseRsqrtIntModule()) def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1280,7 +1280,7 @@ class ElementwiseReciprocalIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseReciprocalIntModule()) def ElementwiseReciprocalIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (4,), dtype=torch.int32)) + module.forward(tu.randint(4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1323,7 +1323,7 @@ class ElementwiseRemainderScalarModule_Int_Float(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float()) def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3,), dtype=torch.int32)) + module.forward(tu.randint(3, high=10).to(torch.int32)) # ============================================================================== @@ -1366,7 +1366,7 @@ class ElementwiseRemainderScalarModule_Int(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int()) def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 2), dtype=torch.int32)) + module.forward(tu.randint(3, 2, high=10).to(torch.int32)) # ============================================================================== @@ -1478,8 +1478,8 @@ class ElementwiseAndIntegerModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseAndIntegerModule()) def ElementwiseAndIntegerModule_basic(module, tu: TestUtils): module.forward( - torch.randint(-10, 10, (3, 4)).to(torch.int32), - torch.randint(-10, 10, (3, 4))) + tu.randint(3, 4, low=-10, high=10).to(torch.int32), + tu.randint(3, 4, low=-10, high=10)) # ============================================================================== @@ -1501,7 +1501,7 @@ class ElementwiseSubScalarIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseSubScalarIntModule()) def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, high=10).to(dtype=torch.int32)) # ============================================================================== @@ -1545,7 +1545,7 @@ class ElementwiseAddScalarInt64Module(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseAddScalarInt64Module()) def ElementwiseAddScalarInt64Module_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 4))) + module.forward(tu.randint(3, 4, high=10)) # ============================================================================== @@ -1567,7 +1567,7 @@ class ElementwiseAddScalarIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseAddScalarIntModule()) def ElementwiseAddScalarIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (2, 3), dtype=torch.int32)) + module.forward(tu.randint(2, 3, high=10).to(torch.int32)) # ============================================================================== @@ -1677,7 +1677,7 @@ class ElementwiseExpIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseExpIntModule()) def ElementwiseExpIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1721,7 +1721,7 @@ class ElementwiseExpm1IntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseExpm1IntModule()) def ElementwiseExpm1IntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1765,7 +1765,7 @@ class ElementwiseSinIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseSinIntModule()) def ElementwiseSinIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1809,7 +1809,7 @@ class ElementwiseCosIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseCosIntModule()) def ElementwiseCosIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== @@ -1924,7 +1924,7 @@ class ElementwiseAtenLogicalOrOpRandomModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomModule()) def ElementwiseAtenLogicalOrOpRandomModule_basic(module, tu: TestUtils): - module.forward(torch.randint(3, 10, (2, 3, 4, 5)), torch.randint(10, 100, (2, 3, 4, 5))) + module.forward(tu.randint(2, 3, 4, 5, low=3, high=10), tu.randint(2, 3, 4, 5, low=10, high=100)) # ============================================================================== @@ -1962,7 +1962,7 @@ class ElementwiseAtenLogicalOrOpNegativeModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpNegativeModule()) def ElementwiseAtenLogicalOrOpNegativeModule_basic(module, tu: TestUtils): - module.forward(torch.neg(torch.randint(3, 10, (2, 3, 4, 5))), torch.neg(torch.randint(10, 100, (2, 3, 4, 5)))) + module.forward(torch.neg(tu.randint(2, 3, 4, 5, low=3, high=10)), torch.neg(tu.randint(2, 3, 4, 5, low=10, high=100))) # ============================================================================== @@ -1981,7 +1981,7 @@ class ElementwiseAtenLogicalOrOpBrodcastModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpBrodcastModule()) def ElementwiseAtenLogicalOrOpBrodcastModule_basic(module, tu: TestUtils): - module.forward(torch.randint(3, (3,)), torch.randint(3, (4, 3))) + module.forward(tu.randint(3, high=3), tu.randint(4, 3, high=3)) # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py index 3e7f8a79a..defd8bc51 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py @@ -45,7 +45,7 @@ class ElementwiseGtIntScalarModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseGtIntScalarModule()) def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4))) + module.forward(tu.randint(3, 4, low=-10, high=15)) # ============================================================================== @@ -64,7 +64,7 @@ class ElementwiseGtMixed2ScalarModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseGtMixed2ScalarModule()) def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32)) + module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32)) # ============================================================================== @@ -102,7 +102,7 @@ class ElementwiseGeIntScalarModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseGeIntScalarModule()) def ElementwiseGeIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4))) + module.forward(tu.randint(3, 4, low=-10, high=15)) # ============================================================================== @@ -121,7 +121,7 @@ class ElementwiseGeMixedIntScalarModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseGeMixedIntScalarModule()) def ElementwiseGeMixedIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32)) + module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32)) # ============================================================================== @@ -180,7 +180,7 @@ class ElementwiseGtIntTensorModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseGtIntTensorModule()) def ElementwiseGtIntTensorModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, ))) + module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10)) # ============================================================================== @@ -218,7 +218,7 @@ class ElementwiseLtIntScalarModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseLtIntScalarModule()) def ElementwiseLtIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4))) + module.forward(tu.randint(3, 4, low=-10, high=15)) # ============================================================================== @@ -238,7 +238,7 @@ class ElementwiseLtDiffWidthScalarModule(torch.nn.Module): @register_test_case( module_factory=lambda: ElementwiseLtDiffWidthScalarModule()) def ElementwiseLtDiffWidthScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32)) + module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32)) # ============================================================================== @@ -276,7 +276,7 @@ class ElementwiseLeIntScalarModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseLeIntScalarModule()) def ElementwiseLeIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4))) + module.forward(tu.randint(3, 4, low=-10, high=15)) # ============================================================================== @@ -295,7 +295,7 @@ class ElementwiseLeMixedIntScalarModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseLeMixedIntScalarModule()) def ElementwiseLeMixedIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-10, 15, (3, 4)).to(torch.int32)) + module.forward(tu.randint(3, 4, low=-10, high=15).to(torch.int32)) # ============================================================================== @@ -354,7 +354,7 @@ class ElementwiseLtIntTensorModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseLtIntTensorModule()) def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (3, 5)), torch.randint(10, (5, ))) + module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10)) # ============================================================================== @@ -393,7 +393,7 @@ class ElementwiseEqIntScalarModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseEqIntScalarModule()) def ElementwiseEqIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(2, 4, (5, 8))) + module.forward(tu.randint(5, 8, low=2, high=4)) # ============================================================================== @@ -413,7 +413,7 @@ class ElementwiseEqDiffWidthScalarModule(torch.nn.Module): @register_test_case( module_factory=lambda: ElementwiseEqDiffWidthScalarModule()) def ElementwiseEqDiffWidthScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(2, 4, (5, 8)).to(torch.int32)) + module.forward(tu.randint(5, 8, low=2, high=4).to(torch.int32)) # ============================================================================== @@ -455,7 +455,7 @@ class ElementwiseEqIntTensorModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseEqIntTensorModule()) def ElementwiseEqIntTensorModule_basic(module, tu: TestUtils): - module.forward(torch.randint(2, 4, (8, 5)), torch.randint(2, 4, (5, ))) + module.forward(tu.randint(8, 5, low=2, high=4), tu.randint(5, low=2, high=4)) # ============================================================================== @@ -494,7 +494,7 @@ class ElementwiseNeIntScalarModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseNeIntScalarModule()) def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils): - module.forward(torch.randint(2, 4, (8, 5))) + module.forward(tu.randint(8, 5, low=2, high=4)) # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py b/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py index 82034d771..1cec5345a 100644 --- a/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py +++ b/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py @@ -90,17 +90,12 @@ class HistogramBinningCalibrationByFeature(torch.nn.Module): @register_test_case(module_factory=lambda: HistogramBinningCalibrationByFeature()) def HBC_basic(module, tu: TestUtils): logits = torch.rand(NUM_LOGITS, dtype=torch.float) - segment_lengths: Tensor = torch.randint( - 0, 2, (NUM_LOGITS,), dtype=torch.int) + segment_lengths: Tensor = tu.randint(NUM_LOGITS, high=2).to(torch.int) segment_offsets: Tensor = torch.cumsum(segment_lengths, 0) segment_offsets: Tensor = torch.cat( (torch.tensor([0]), segment_offsets), 0) num_values: int = int(torch.sum(segment_lengths).item()) - segment_values: Tensor = torch.randint( - 0, - NUM_SEGMENTS, - (num_values,), - ) + segment_values: Tensor = tu.randint(num_values, high=NUM_SEGMENTS) segment_values = torch.cat( (segment_values, torch.zeros(NUM_LOGITS-segment_values.numel())), 0) module.forward(segment_values.int(), segment_offsets.int(), logits) diff --git a/python/torch_mlir_e2e_test/test_suite/index_put.py b/python/torch_mlir_e2e_test/test_suite/index_put.py index e12b2627c..70e8b88c7 100644 --- a/python/torch_mlir_e2e_test/test_suite/index_put.py +++ b/python/torch_mlir_e2e_test/test_suite/index_put.py @@ -34,7 +34,7 @@ class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutImpl1DFloatNonAccumulateModule()) def IndexPutImpl1DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(100), torch.randint(100, (250, )), tu.rand(250)) + module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250)) class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module): @@ -59,7 +59,7 @@ class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutImpl2DFloatNonAccumulateModule()) def IndexPutImpl2DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8)) + module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module): @@ -84,7 +84,7 @@ class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutImpl3DFloatNonAccumulateModule()) def IndexPutImpl3DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )), + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) @@ -113,8 +113,8 @@ class IndexPutImpl1DIntNonAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutImpl1DIntNonAccumulateModule()) def IndexPutImpl1DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (200, )), torch.randint(100, (300, )), - torch.randint(10000, (300, ))) + module.forward(tu.randint(200, high=1000), tu.randint(300, high=100), + tu.randint(300, high=10000)) # ============================================================================== @@ -142,7 +142,7 @@ class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutImpl1DFloatAccumulateModule()) def IndexPutImpl1DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(1000), torch.randint(10, (500, )), tu.rand(500)) + module.forward(tu.rand(1000), tu.randint(500, high=10), tu.rand(500)) class IndexPutImpl2DFloatAccumulateModule(torch.nn.Module): @@ -167,7 +167,7 @@ class IndexPutImpl2DFloatAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutImpl2DFloatAccumulateModule()) def IndexPutImpl2DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8)) + module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module): @@ -192,7 +192,7 @@ class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutImpl3DFloatAccumulateModule()) def IndexPutImpl3DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )), + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) @@ -220,8 +220,8 @@ class IndexPutImpl1DIntAccumulateModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexPutImpl1DIntAccumulateModule()) def IndexPutImpl1DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, )), torch.randint(10, (10, )), - torch.randint(1000, (10, ))) + module.forward(tu.randint(10, high=100), tu.randint(10, high=10), + tu.randint(10, high=1000)) # ============================================================================== @@ -248,7 +248,7 @@ class IndexPut1DFloatNonAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPut1DFloatNonAccumulateModule()) def IndexPut1DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(100), torch.randint(100, (250, )), tu.rand(250)) + module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250)) class IndexPut2DFloatNonAccumulateModule(torch.nn.Module): @@ -272,7 +272,7 @@ class IndexPut2DFloatNonAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPut2DFloatNonAccumulateModule()) def IndexPut2DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8)) + module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPut3DFloatNonAccumulateModule(torch.nn.Module): @@ -296,7 +296,7 @@ class IndexPut3DFloatNonAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPut3DFloatNonAccumulateModule()) def IndexPut3DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )), + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) @@ -323,8 +323,8 @@ class IndexPut1DIntNonAccumulateModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexPut1DIntNonAccumulateModule()) def IndexPut1DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (200, )), torch.randint(100, (300, )), - torch.randint(10000, (300, ))) + module.forward(tu.randint(200, high=1000), tu.randint(300, high=100), + tu.randint(300, high=10000)) class IndexPut2DIntNonAccumulateModule(torch.nn.Module): @@ -347,8 +347,8 @@ class IndexPut2DIntNonAccumulateModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexPut2DIntNonAccumulateModule()) def IndexPut2DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8))) + module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, high=1000)) class IndexPut3DIntNonAccumulateModule(torch.nn.Module): @@ -371,8 +371,8 @@ class IndexPut3DIntNonAccumulateModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexPut3DIntNonAccumulateModule()) def IndexPut3DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8, 6)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8, 6))) + module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, 6, high=1000)) # ============================================================================== @@ -398,7 +398,7 @@ class IndexPut1DFloatAccumulateModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexPut1DFloatAccumulateModule()) def IndexPut1DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(1000), torch.randint(10, (500, )), tu.rand(500)) + module.forward(tu.rand(1000), tu.randint(500, high=10), tu.rand(500)) class IndexPut2DFloatAccumulateModule(torch.nn.Module): @@ -421,7 +421,7 @@ class IndexPut2DFloatAccumulateModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexPut2DFloatAccumulateModule()) def IndexPut2DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8)) + module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPut3DFloatAccumulateModule(torch.nn.Module): @@ -444,7 +444,7 @@ class IndexPut3DFloatAccumulateModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexPut3DFloatAccumulateModule()) def IndexPut3DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )), + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) @@ -471,8 +471,8 @@ class IndexPut1DIntAccumulateModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexPut1DIntAccumulateModule()) def IndexPut1DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, )), torch.randint(10, (10, )), - torch.randint(1000, (10, ))) + module.forward(tu.randint(10, high=100), tu.randint(10, high=10), + tu.randint(10, high=1000)) class IndexPut2DIntAccumulateModule(torch.nn.Module): @@ -495,8 +495,8 @@ class IndexPut2DIntAccumulateModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexPut2DIntAccumulateModule()) def IndexPut2DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8))) + module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, high=1000)) class IndexPut3DIntAccumulateModule(torch.nn.Module): @@ -519,8 +519,8 @@ class IndexPut3DIntAccumulateModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexPut3DIntAccumulateModule()) def IndexPut3DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8, 6)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8, 6))) + module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, 6, high=1000)) # ============================================================================== @@ -548,7 +548,7 @@ class IndexPutHackedTwin1DFloatNonAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutHackedTwin1DFloatNonAccumulateModule()) def IndexPutHackedTwin1DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(100), torch.randint(100, (250, )), tu.rand(250)) + module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250)) class IndexPutHackedTwin2DFloatNonAccumulateModule(torch.nn.Module): @@ -572,7 +572,7 @@ class IndexPutHackedTwin2DFloatNonAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutHackedTwin2DFloatNonAccumulateModule()) def IndexPutHackedTwin2DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8)) + module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPutHackedTwin3DFloatNonAccumulateModule(torch.nn.Module): @@ -596,7 +596,7 @@ class IndexPutHackedTwin3DFloatNonAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutHackedTwin3DFloatNonAccumulateModule()) def IndexPutHackedTwin3DFloatNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )), + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) @@ -624,8 +624,8 @@ class IndexPutHackedTwin1DIntNonAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutHackedTwin1DIntNonAccumulateModule()) def IndexPutHackedTwin1DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (200, )), torch.randint(100, (300, )), - torch.randint(10000, (300, ))) + module.forward(tu.randint(200, high=1000), tu.randint(300, high=100), + tu.randint(300, high=10000)) class IndexPutHackedTwin2DIntNonAccumulateModule(torch.nn.Module): @@ -649,8 +649,8 @@ class IndexPutHackedTwin2DIntNonAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutHackedTwin2DIntNonAccumulateModule()) def IndexPutHackedTwin2DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8))) + module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, high=1000)) class IndexPutHackedTwin3DIntNonAccumulateModule(torch.nn.Module): @@ -674,8 +674,8 @@ class IndexPutHackedTwin3DIntNonAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutHackedTwin3DIntNonAccumulateModule()) def IndexPutHackedTwin3DIntNonAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8, 6)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8, 6))) + module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, 6, high=1000)) # ============================================================================== @@ -700,7 +700,7 @@ class IndexPutHackedTwin1DFloatAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutHackedTwin1DFloatAccumulateModule()) def IndexPutHackedTwin1DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(1000), torch.randint(10, (500, )), tu.rand(500)) + module.forward(tu.rand(1000), tu.randint(500, high=10), tu.rand(500)) class IndexPutHackedTwin2DFloatAccumulateModule(torch.nn.Module): @@ -722,7 +722,7 @@ class IndexPutHackedTwin2DFloatAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutHackedTwin2DFloatAccumulateModule()) def IndexPutHackedTwin2DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8), torch.randint(4, (5, )), tu.rand(5, 8)) + module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8)) class IndexPutHackedTwin3DFloatAccumulateModule(torch.nn.Module): @@ -744,7 +744,7 @@ class IndexPutHackedTwin3DFloatAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutHackedTwin3DFloatAccumulateModule()) def IndexPutHackedTwin3DFloatAccumulateModule_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), torch.randint(4, (5, )), + module.forward(tu.rand(10, 8, 6), tu.randint(5, high=4), tu.rand(5, 8, 6)) @@ -770,8 +770,8 @@ class IndexPutHackedTwin1DIntAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutHackedTwin1DIntAccumulateModule()) def IndexPutHackedTwin1DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (10, )), torch.randint(10, (10, )), - torch.randint(1000, (10, ))) + module.forward(tu.randint(10, high=100), tu.randint(10, high=10), + tu.randint(10, high=1000)) class IndexPutHackedTwin2DIntAccumulateModule(torch.nn.Module): @@ -793,8 +793,8 @@ class IndexPutHackedTwin2DIntAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutHackedTwin2DIntAccumulateModule()) def IndexPutHackedTwin2DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8))) + module.forward(tu.randint(10, 8, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, high=1000)) class IndexPutHackedTwin3DIntAccumulateModule(torch.nn.Module): @@ -816,5 +816,5 @@ class IndexPutHackedTwin3DIntAccumulateModule(torch.nn.Module): @register_test_case( module_factory=lambda: IndexPutHackedTwin3DIntAccumulateModule()) def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1000, (10, 8, 6)), torch.randint(4, (5, )), - torch.randint(1000, (5, 8, 6))) + module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), + tu.randint(5, 8, 6, high=1000)) diff --git a/python/torch_mlir_e2e_test/test_suite/nll_loss.py b/python/torch_mlir_e2e_test/test_suite/nll_loss.py index edbb8f444..bc70fbbec 100644 --- a/python/torch_mlir_e2e_test/test_suite/nll_loss.py +++ b/python/torch_mlir_e2e_test/test_suite/nll_loss.py @@ -34,7 +34,7 @@ class NllLossModule(torch.nn.Module): @register_test_case(module_factory=lambda: NllLossModule()) def NllLossModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,))) + module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) class NllLossModule_mean(torch.nn.Module): @@ -58,7 +58,7 @@ class NllLossModule_mean(torch.nn.Module): @register_test_case(module_factory=lambda: NllLossModule_mean()) def NllLossModule_mean_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,))) + module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) class NllLossModule_sum(torch.nn.Module): @@ -82,7 +82,7 @@ class NllLossModule_sum(torch.nn.Module): @register_test_case(module_factory=lambda: NllLossModule_sum()) def NllLossModule_sum_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,))) + module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) class NllLossModule_1D(torch.nn.Module): @@ -106,7 +106,7 @@ class NllLossModule_1D(torch.nn.Module): @register_test_case(module_factory=lambda: NllLossModule_1D()) def NllLossModule_1D_basic(module, tu: TestUtils): - module.forward(tu.rand(3), torch.randint(0, 3, ())) + module.forward(tu.rand(3), tu.randint(high=3)) class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module): @@ -131,7 +131,7 @@ class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module): @register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds()) def NllLossModule_ignore_index_out_of_bounds_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3), torch.randint(0, 3, (2,))) + module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) class NllLossModule_backward(torch.nn.Module): diff --git a/python/torch_mlir_e2e_test/test_suite/pooling.py b/python/torch_mlir_e2e_test/test_suite/pooling.py index 26a18b0df..9858a304e 100644 --- a/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -489,7 +489,7 @@ class MaxPool2dWithIndicesBackwardStatic4DModule(torch.nn.Module): module_factory=lambda: MaxPool2dWithIndicesBackwardStatic4DModule()) def MaxPool2dWithIndicesBackwardStatic4DModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 7, 6), tu.rand(2, 4, 6, 5), - torch.randint(16, (2, 4, 7, 6))) + tu.randint(2, 4, 7, 6, high=16)) class MaxPool2dWithIndicesBackwardStatic3DModule(torch.nn.Module): @@ -519,7 +519,7 @@ class MaxPool2dWithIndicesBackwardStatic3DModule(torch.nn.Module): module_factory=lambda: MaxPool2dWithIndicesBackwardStatic3DModule()) def MaxPool2dWithIndicesBackwardStatic3DModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 7, 6), tu.rand(4, 6, 5), - torch.randint(16, (4, 7, 6))) + tu.randint(4, 7, 6, high=16)) class MaxPool2dWithIndicesBackwardDynamic4DModule(torch.nn.Module): @@ -549,7 +549,7 @@ class MaxPool2dWithIndicesBackwardDynamic4DModule(torch.nn.Module): module_factory=lambda: MaxPool2dWithIndicesBackwardDynamic4DModule()) def MaxPool2dWithIndicesBackwardDynamic4DModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 7, 6), tu.rand(2, 4, 6, 5), - torch.randint(16, (2, 4, 7, 6))) + tu.randint(2, 4, 7, 6, high=16)) class MaxPool2dWithIndicesBackwardDynamic3DModule(torch.nn.Module): @@ -579,7 +579,7 @@ class MaxPool2dWithIndicesBackwardDynamic3DModule(torch.nn.Module): module_factory=lambda: MaxPool2dWithIndicesBackwardDynamic3DModule()) def MaxPool2dWithIndicesBackwardDynamic3DModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 7, 6), tu.rand(2, 6, 5), - torch.randint(16, (2, 7, 6))) + tu.randint(2, 7, 6, high=16)) # ============================================================================== @@ -632,7 +632,7 @@ class AvgPool2dIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: AvgPool2dIntModule()) def AvgPool2dIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (2, 4, 20, 20))) + module.forward(tu.randint(2, 4, 20, 20, high=100)) class AvgPool2dStaticModule(torch.nn.Module): diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py index 3fcf29534..6481fd553 100644 --- a/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -140,7 +140,7 @@ class ReduceSumUnsignedIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceSumUnsignedIntModule()) def ReduceSumUnsignedIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(0, 100, (3, 4, 5))) + module.forward(tu.randint(3, 4, 5, low=0, high=100)) # ============================================================================== @@ -159,7 +159,7 @@ class ReduceSumSignedIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceSumSignedIntModule()) def ReduceSumSignedIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100, (3, 4, 5))) + module.forward(tu.randint(3, 4, 5, low=-100, high=100)) # ============================================================================== @@ -178,7 +178,7 @@ class ReduceSumDtypeIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceSumDtypeIntModule()) def ReduceSumDtypeIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 4, 5)).to(torch.int32)) + module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32)) # ============================================================================== @@ -197,7 +197,7 @@ class ReduceSumDimIntListIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceSumDimIntListIntModule()) def ReduceSumDimIntListIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 4, 5))) + module.forward(tu.randint(3, 4, 5, high=100)) # ============================================================================== @@ -216,7 +216,7 @@ class ReduceSumDimIntListDtypeIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceSumDimIntListDtypeIntModule()) def ReduceSumDimIntListDtypeIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 4, 5)).to(torch.int32)) + module.forward(tu.randint(3, 4, 5, high=100).to(torch.int32)) # ============================================================================== @@ -235,7 +235,7 @@ class ReduceSumDimIntListKeepDimIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimIntModule()) def ReduceSumDimIntListKeepDimIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 4, 5))) + module.forward(tu.randint(3, 4, 5, high=100)) # ============================================================================== @@ -383,7 +383,7 @@ class ReduceMaxSignedIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceMaxSignedIntModule()) def ReduceMaxSignedIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100, (3, 4, 5))) + module.forward(tu.randint(3, 4, 5, low=-100, high=100)) # ============================================================================== @@ -401,7 +401,7 @@ class ReduceMaxUnsignedIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceMaxUnsignedIntModule()) def ReduceMaxUnsignedIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(100, (3, 4, 5))) + module.forward(tu.randint(3, 4, 5, high=100)) # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/scalar.py b/python/torch_mlir_e2e_test/test_suite/scalar.py index f79ebc206..081690efc 100644 --- a/python/torch_mlir_e2e_test/test_suite/scalar.py +++ b/python/torch_mlir_e2e_test/test_suite/scalar.py @@ -29,7 +29,7 @@ class AddIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: AddIntModule()) def AddIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100)) # ============================================================================== @@ -52,7 +52,7 @@ class SubIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: SubIntModule()) def SubIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100)) # ============================================================================== @@ -98,7 +98,7 @@ class MulIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: MulIntModule()) def MulIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100)) # ============================================================================== @@ -172,7 +172,7 @@ class SqrtIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: SqrtIntModule()) def SqrtIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, ())) + module.forward(tu.randint(high=10)) class SqrtIntConstantModule(torch.nn.Module): @@ -273,7 +273,7 @@ class BoolIntFalseModule(torch.nn.Module): @register_test_case(module_factory=lambda: BoolIntFalseModule()) def BoolIntFalseModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 100, ())) + module.forward(tu.randint(low=1, high=100)) class BoolIntTrueModule(torch.nn.Module): @@ -292,7 +292,7 @@ class BoolIntTrueModule(torch.nn.Module): @register_test_case(module_factory=lambda: BoolIntTrueModule()) def BoolIntTrueModule_basic(module, tu: TestUtils): - module.forward(torch.randint(1, 100, ())) + module.forward(tu.randint(low=1, high=100)) class BoolIntConstantModule(torch.nn.Module): diff --git a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py index 52266cb9d..e062b669b 100644 --- a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py +++ b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py @@ -29,7 +29,7 @@ class NeIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: NeIntModule()) def NeIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100)) # ============================================================================== @@ -52,7 +52,7 @@ class EqIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: EqIntModule()) def EqIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100)) # ============================================================================== @@ -75,7 +75,7 @@ class GtIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: GtIntModule()) def GtIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100)) # ============================================================================== @@ -98,7 +98,7 @@ class GeIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: GeIntModule()) def GeIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + module.forward(tu.randint(low=-100, high=100), tu.randint(low=-100, high=100)) # ============================================================================== @@ -144,7 +144,7 @@ class GeFloatIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: GeFloatIntModule()) def GeFloatIntModule_basic(module, tu: TestUtils): - module.forward(torch.randn(()).double(), torch.randint(-100, 100, ())) + module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100)) # ============================================================================== @@ -167,7 +167,7 @@ class NeFloatIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: NeFloatIntModule()) def NeFloatIntModule_basic(module, tu: TestUtils): - module.forward(torch.randn(()).double(), torch.randint(-100, 100, ())) + module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100)) # ============================================================================== @@ -190,4 +190,4 @@ class GtFloatIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: GtFloatIntModule()) def GtFloatIntModule_basic(module, tu: TestUtils): - module.forward(torch.randn(()).double(), torch.randint(-100, 100, ())) + module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100)) diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py index 772d75e17..c576d900c 100644 --- a/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -229,7 +229,7 @@ class SelectIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: SelectIntModule()) def SelectIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (5,5))) + module.forward(tu.randint(5,5, high=10)) # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/table_batch_embedding.py b/python/torch_mlir_e2e_test/test_suite/table_batch_embedding.py index 1f74c9dc8..7c53c1a1b 100644 --- a/python/torch_mlir_e2e_test/test_suite/table_batch_embedding.py +++ b/python/torch_mlir_e2e_test/test_suite/table_batch_embedding.py @@ -54,7 +54,7 @@ class TableBatchEmbeddingModule(torch.nn.Module): @register_test_case(module_factory=lambda: TableBatchEmbeddingModule()) def TableBatchEmbeddingModule_basic(module, tu: TestUtils): - indices = torch.randint(0, NUM_EMBEDDINGS, (NUM_TABLES * BATCH_SIZE * BAG_SIZE,)) + indices = tu.randint(NUM_TABLES * BATCH_SIZE * BAG_SIZE, high=NUM_EMBEDDINGS) offsets = torch.cumsum( torch.tensor([0] + [BAG_SIZE for _ in range(BATCH_SIZE - 1)], dtype=torch.int64), 0) module.forward(indices, offsets) diff --git a/python/torch_mlir_e2e_test/test_suite/threshold.py b/python/torch_mlir_e2e_test/test_suite/threshold.py index d7a34a89a..ac784abb4 100644 --- a/python/torch_mlir_e2e_test/test_suite/threshold.py +++ b/python/torch_mlir_e2e_test/test_suite/threshold.py @@ -27,7 +27,7 @@ class Threshold1dIntI32Module(torch.nn.Module): @register_test_case(module_factory=lambda: Threshold1dIntI32Module()) def Threshold1dIntI32Module_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (4,), dtype=torch.int32)) + module.forward(tu.randint(4, high=10).to(torch.int32)) class Threshold1dIntModule(torch.nn.Module): @@ -45,7 +45,7 @@ class Threshold1dIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: Threshold1dIntModule()) def Threshold1dIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (4,))) + module.forward(tu.randint(4, high=10)) class Threshold2dIntModule(torch.nn.Module): @@ -63,7 +63,7 @@ class Threshold2dIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: Threshold2dIntModule()) def Threshold2dIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (4, 5))) + module.forward(tu.randint(4, 5, high=10)) class Threshold3dIntModule(torch.nn.Module): @@ -81,7 +81,7 @@ class Threshold3dIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: Threshold3dIntModule()) def Threshold3dIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (4, 5, 6))) + module.forward(tu.randint(4, 5, 6, high=10)) class Threshold1dFloatModule(torch.nn.Module): @@ -154,7 +154,7 @@ class ThresholdBackward1dIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ThresholdBackward1dIntModule()) def ThresholdBackward1dIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (4,)), torch.randint(8, (4,))) + module.forward(tu.randint(4, high=10), tu.randint(4, high=8)) class ThresholdBackward2dIntModule(torch.nn.Module): @@ -173,7 +173,7 @@ class ThresholdBackward2dIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ThresholdBackward2dIntModule()) def ThresholdBackward2dIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (4, 5)), torch.randint(8, (4, 5))) + module.forward(tu.randint(4, 5, high=10), tu.randint(4, 5, high=8)) class ThresholdBackward3dIntModule(torch.nn.Module): @@ -192,7 +192,7 @@ class ThresholdBackward3dIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ThresholdBackward3dIntModule()) def ThresholdBackward3dIntModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, (4, 5, 6)), torch.randint(8, (4, 5, 6))) + module.forward(tu.randint(4, 5, 6, high=10), tu.randint(4, 5, 6, high=8)) class ThresholdBackward1dFloatModule(torch.nn.Module): @@ -268,7 +268,7 @@ class ThresholdBackward1dMixedModule(torch.nn.Module): @register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule()) def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils): - module.forward(torch.randn(4), torch.randint(10, (4,))) + module.forward(torch.randn(4), tu.randint(4, high=10)) class ThresholdBackward2dMixedModule(torch.nn.Module): @@ -287,7 +287,7 @@ class ThresholdBackward2dMixedModule(torch.nn.Module): @register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule()) def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils): - module.forward(torch.randint(20, (4, 5)), torch.randn(4, 5)) + module.forward(tu.randint(4, 5, high=20), torch.randn(4, 5)) class ThresholdBackward3dMixedModule(torch.nn.Module): @@ -306,4 +306,4 @@ class ThresholdBackward3dMixedModule(torch.nn.Module): @register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule()) def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils): - module.forward(torch.randn(4, 5, 6), torch.randint(10, (4, 5, 6))) + module.forward(torch.randn(4, 5, 6), tu.randint(4, 5, 6, high=10)) diff --git a/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/python/torch_mlir_e2e_test/test_suite/type_conversion.py index 0d455c3b4..2bd0e8b68 100644 --- a/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -57,7 +57,7 @@ class TypeConversionI32ToI64Module(torch.nn.Module): @register_test_case(module_factory=lambda: TypeConversionI32ToI64Module()) def TypeConversionI32ToI64Module_basic(module, tu: TestUtils): - module.forward(torch.randint(5, [2, 3]).type(torch.int32)) + module.forward(tu.randint(2, 3, high=5).type(torch.int32)) class TypeConversionI64ToI32Module(torch.nn.Module): @@ -73,7 +73,7 @@ class TypeConversionI64ToI32Module(torch.nn.Module): @register_test_case(module_factory=lambda: TypeConversionI64ToI32Module()) def TypeConversionI64ToI32Module_basic(module, tu: TestUtils): - module.forward(torch.randint(5, [2, 3])) + module.forward(tu.randint(2, 3, high=5)) class TypeConversionI1ToI32Module(torch.nn.Module): @@ -89,7 +89,7 @@ class TypeConversionI1ToI32Module(torch.nn.Module): @register_test_case(module_factory=lambda: TypeConversionI1ToI32Module()) def TypeConversionI1ToI32Module_basic(module, tu: TestUtils): - tensor = torch.randint(0, 2, (3, 4), dtype=torch.bool) + tensor = tu.randint(3, 4, low=0, high=2).to(torch.bool) module.forward(tensor) @@ -106,7 +106,7 @@ class TypeConversionI1ToI64Module(torch.nn.Module): @register_test_case(module_factory=lambda: TypeConversionI1ToI64Module()) def TypeConversionI1ToI64Module_basic(module, tu: TestUtils): - tensor = torch.randint(0, 2, (3, 4), dtype=torch.bool) + tensor = tu.randint(3, 4, low=0, high=2).to(torch.bool) module.forward(tensor) @@ -123,7 +123,7 @@ class TypeConversionI1ToF32Module(torch.nn.Module): @register_test_case(module_factory=lambda: TypeConversionI1ToF32Module()) def TypeConversionI1ToF32Module_basic(module, tu: TestUtils): - tensor = torch.randint(0, 2, (3, 4), dtype=torch.bool) + tensor = tu.randint(3, 4, low=0, high=2).to(torch.bool) module.forward(tensor) @@ -140,7 +140,7 @@ class TypeConversionI1ToF64Module(torch.nn.Module): @register_test_case(module_factory=lambda: TypeConversionI1ToF64Module()) def TypeConversionI1ToF64Module_basic(module, tu: TestUtils): - tensor = torch.randint(0, 2, (3, 4), dtype=torch.bool) + tensor = tu.randint(3, 4, low=0, high=2).to(torch.bool) module.forward(tensor) diff --git a/python/torch_mlir_e2e_test/test_suite/type_promotion.py b/python/torch_mlir_e2e_test/test_suite/type_promotion.py index a7a5491c5..94a5ba18f 100644 --- a/python/torch_mlir_e2e_test/test_suite/type_promotion.py +++ b/python/torch_mlir_e2e_test/test_suite/type_promotion.py @@ -30,8 +30,8 @@ class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module): module_factory=lambda: TypePromotionSameCategoryDifferentWidthModule()) def TypePromotionSameCategoryDifferentWidthModule_basic(module, tu: TestUtils): module.forward( - torch.randint(10, [4]).type(torch.int32), - torch.randint(10, [4])) + tu.randint(4, high=10).type(torch.int32), + tu.randint(4, high=10)) class TypePromotionDifferentCategoryModule(torch.nn.Module): @@ -51,7 +51,7 @@ class TypePromotionDifferentCategoryModule(torch.nn.Module): @register_test_case( module_factory=lambda: TypePromotionDifferentCategoryModule()) def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, [4]), torch.randn(4)) + module.forward(tu.randint(4, high=10), torch.randn(4)) class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module): @@ -91,7 +91,7 @@ class TypePromotionZeroRankHigherCategoryModule(torch.nn.Module): @register_test_case( module_factory=lambda: TypePromotionZeroRankHigherCategoryModule()) def TypePromotionZeroRankHigherCategoryModule_basic(module, tu: TestUtils): - module.forward(torch.randint(10, [4]), tu.rand()) + module.forward(tu.randint(4, high=10), tu.rand()) class TypePromotionAlphaWiderModule(torch.nn.Module): diff --git a/python/torch_mlir_e2e_test/torchscript/framework.py b/python/torch_mlir_e2e_test/torchscript/framework.py index fdaa46084..06a25d3e5 100644 --- a/python/torch_mlir_e2e_test/torchscript/framework.py +++ b/python/torch_mlir_e2e_test/torchscript/framework.py @@ -182,6 +182,9 @@ class TestUtils: def rand(self, *sizes, low=0.0, high=1.0): return torch.empty(sizes).uniform_(low, high) + def randint(self, *sizes, low=0, high=10): + return torch.randint(low, high, sizes) + def nans(self, *sizes): vals = torch.empty(sizes) vals[...] = torch.nan