mirror of https://github.com/llvm/torch-mlir
Replace `torch.rand` and `torch.randn` in e2e tests with `tu.rand` (#1890)
Random tensors used in e2e tests should be created using the `TestUtils` object passed to the registered test case to ensure that the compiled module and the golden trace receive the same tensors as input. This commit changes all the cases of `torch.rand` and `torch.randn` to use the `TestUtils` instead.pull/1886/head
parent
eb74014dd8
commit
52dbb160fc
|
@ -728,7 +728,7 @@ class AddSizeIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: AddSizeIntModule())
|
||||
def AddSizeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3))
|
||||
module.forward(tu.rand(3, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -753,7 +753,7 @@ class AddSizeIntNegDimModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: AddSizeIntNegDimModule())
|
||||
def AddSizeIntNegDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3))
|
||||
module.forward(tu.rand(3, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -904,7 +904,7 @@ class SoftmaxIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SoftmaxIntModule())
|
||||
def SoftmaxIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4))
|
||||
module.forward(tu.rand(3, 2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -926,7 +926,7 @@ class _SoftmaxModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: _SoftmaxModule())
|
||||
def _SoftmaxModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4))
|
||||
module.forward(tu.rand(3, 2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -950,7 +950,7 @@ class SoftmaxIntNegDimModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SoftmaxIntNegDimModule())
|
||||
def SoftmaxIntNegDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4))
|
||||
module.forward(tu.rand(3, 2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -974,7 +974,7 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SoftmaxIntArgTypeF64Module())
|
||||
def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4).double())
|
||||
module.forward(tu.rand(3, 2, 4).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -996,7 +996,7 @@ class _LogSoftmaxModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: _LogSoftmaxModule())
|
||||
def _LogSoftmaxModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4))
|
||||
module.forward(tu.rand(3, 2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1265,7 +1265,7 @@ class LogSoftmaxIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: LogSoftmaxIntModule())
|
||||
def LogSoftmaxIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4).double())
|
||||
module.forward(tu.rand(3, 2, 4).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -3031,7 +3031,7 @@ class AtenEmbeddingBagSumExample(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: AtenEmbeddingBagSumExample())
|
||||
def AtenEmbeddingBagSumExample_basic(module, tu: TestUtils):
|
||||
weight = torch.rand(100, 10)
|
||||
weight = tu.rand(100, 10)
|
||||
indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
|
||||
offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15])
|
||||
module.forward(weight, indices, offsets)
|
||||
|
@ -3053,7 +3053,7 @@ class Aten_EmbeddingBagExample(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Aten_EmbeddingBagExample())
|
||||
def Aten_EmbeddingBagExample_basic(module, tu: TestUtils):
|
||||
weight = torch.rand(100, 10)
|
||||
weight = tu.rand(100, 10)
|
||||
indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
|
||||
offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15])
|
||||
module.forward(weight, indices, offsets)
|
||||
|
@ -3128,7 +3128,7 @@ class AtenToDeviceModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: AtenToDeviceModule())
|
||||
def AtenToDeviceModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(2, 4))
|
||||
module.forward(tu.rand(2, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ class TensorToFloatZeroRank(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TensorToFloatZeroRank())
|
||||
def TensorToFloatZeroRank_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand((), dtype=torch.float64))
|
||||
module.forward(tu.rand().to(torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -83,7 +83,7 @@ class TensorToFloat(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TensorToFloat())
|
||||
def TensorToFloat_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand((1, 1), dtype=torch.float64))
|
||||
module.forward(tu.rand(1, 1).to(torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
|
@ -1095,7 +1095,7 @@ class ZeroFloat32Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ZeroFloat32Module())
|
||||
def ZeroFloat32Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 2))
|
||||
module.forward(tu.rand(3, 2))
|
||||
|
||||
|
||||
class ZeroInt32Module(torch.nn.Module):
|
||||
|
|
|
@ -165,7 +165,7 @@ class Convolution2DModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Convolution2DModule())
|
||||
def Convolution2DModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
|
||||
|
||||
|
||||
class Convolution2DStaticModule(torch.nn.Module):
|
||||
|
@ -191,7 +191,7 @@ class Convolution2DStaticModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Convolution2DStaticModule())
|
||||
def Convolution2DStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
|
||||
|
||||
class Convolution2DStridedModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -216,7 +216,7 @@ class Convolution2DStridedModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Convolution2DStridedModule())
|
||||
def Convolution2DStridedModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
|
||||
|
||||
class _Convolution2DAllFalseModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -245,7 +245,7 @@ class _Convolution2DAllFalseModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: _Convolution2DAllFalseModule())
|
||||
def _Convolution2DAllFalseModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
|
||||
|
||||
class _Convolution2DBenchmarkModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -274,7 +274,7 @@ class _Convolution2DBenchmarkModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: _Convolution2DBenchmarkModule())
|
||||
def _Convolution2DBenchmarkModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
|
||||
|
||||
class _Convolution2DDeterministicModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -303,7 +303,7 @@ class _Convolution2DDeterministicModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: _Convolution2DDeterministicModule())
|
||||
def _Convolution2DDeterministicModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
|
||||
|
||||
class _Convolution2DCudnnModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -332,7 +332,7 @@ class _Convolution2DCudnnModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: _Convolution2DCudnnModule())
|
||||
def _Convolution2DCudnnModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
|
||||
|
||||
class _Convolution2DTF32Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -361,7 +361,7 @@ class _Convolution2DTF32Module(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: _Convolution2DTF32Module())
|
||||
def _Convolution2DTF32Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
|
||||
|
||||
class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -389,7 +389,7 @@ class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule())
|
||||
def _ConvolutionDeprecated2DAllFalseModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
|
||||
|
||||
class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -417,7 +417,7 @@ class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule())
|
||||
def _ConvolutionDeprecated2DBenchmarkModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
|
||||
|
||||
class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -445,7 +445,7 @@ class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule())
|
||||
def _ConvolutionDeprecated2DDeterministicModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
|
||||
|
||||
class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -473,7 +473,7 @@ class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
|
||||
def _ConvolutionDeprecated2DCudnnModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
|
||||
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))
|
||||
|
||||
class ConvolutionModule2DGroups(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -498,7 +498,7 @@ class ConvolutionModule2DGroups(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ConvolutionModule2DGroups())
|
||||
def ConvolutionModule2DGroups_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(1, 32, 4, 4), torch.randn(32, 8, 3, 3))
|
||||
module.forward(tu.rand(1, 32, 4, 4), tu.rand(32, 8, 3, 3))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -527,7 +527,7 @@ class ConvolutionModule2DTranspose(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ConvolutionModule2DTranspose())
|
||||
def ConvolutionModule2DTranspose_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3, 4, 4), torch.randn(3, 3, 2, 2))
|
||||
module.forward(tu.rand(3, 3, 4, 4), tu.rand(3, 3, 2, 2))
|
||||
|
||||
class ConvolutionModule2DTransposeStrided(torch.nn.Module):
|
||||
|
||||
|
@ -554,7 +554,7 @@ class ConvolutionModule2DTransposeStrided(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStrided())
|
||||
def ConvolutionModule2DTransposeStrided_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2))
|
||||
module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2))
|
||||
|
||||
class ConvolutionModule2DTransposeStridedStatic(torch.nn.Module):
|
||||
|
||||
|
@ -581,7 +581,7 @@ class ConvolutionModule2DTransposeStridedStatic(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStridedStatic())
|
||||
def ConvolutionModule2DTransposeStridedStatic_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2))
|
||||
module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2))
|
||||
|
||||
|
||||
class Conv_Transpose2dModule(torch.nn.Module):
|
||||
|
@ -608,7 +608,7 @@ class Conv_Transpose2dModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Conv_Transpose2dModule())
|
||||
def Conv_Transpose2dModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2))
|
||||
module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2))
|
||||
|
||||
|
||||
class UpSampleNearest2d(torch.nn.Module):
|
||||
|
|
|
@ -152,7 +152,7 @@ class ElementwiseAtenWhereSelfModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenWhereSelfModule())
|
||||
def ElementwiseAtenWhereSelfModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.zeros(1, 1, 5, 5, dtype=torch.bool), torch.rand(1, 12, 5, 5), torch.rand(()))
|
||||
module.forward(torch.zeros(1, 1, 5, 5, dtype=torch.bool), tu.rand(1, 12, 5, 5), tu.rand())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -1500,7 +1500,7 @@ class ElementwiseRemainderScalarModule_Float(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Float())
|
||||
def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(10, 3))
|
||||
module.forward(tu.rand(10, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -2313,7 +2313,7 @@ class ElementwiseAtenLogicalOrOpRandomFloatModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomFloatModule())
|
||||
def ElementwiseAtenLogicalOrOpRandomFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(2, 3, 3, 5), torch.rand(2, 3, 3, 5))
|
||||
module.forward(tu.rand(2, 3, 3, 5), tu.rand(2, 3, 3, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -2729,7 +2729,7 @@ class Fill_TensorFloat64WithFloat32(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat32())
|
||||
def Fill_TensorFloat64WithFloat32_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4))
|
||||
module.forward(tu.rand(3, 2, 4))
|
||||
|
||||
|
||||
class Fill_TensorFloat64WithFloat64(torch.nn.Module):
|
||||
|
@ -2748,7 +2748,7 @@ class Fill_TensorFloat64WithFloat64(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat64())
|
||||
def Fill_TensorFloat64WithFloat64_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4).to(torch.float64))
|
||||
module.forward(tu.rand(3, 2, 4).to(torch.float64))
|
||||
|
||||
|
||||
class Fill_TensorFloat64WithInt64(torch.nn.Module):
|
||||
|
@ -2767,7 +2767,7 @@ class Fill_TensorFloat64WithInt64(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithInt64())
|
||||
def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4).to(torch.float64))
|
||||
module.forward(tu.rand(3, 2, 4).to(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
|
|
@ -89,7 +89,7 @@ class HistogramBinningCalibrationByFeature(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: HistogramBinningCalibrationByFeature())
|
||||
def HBC_basic(module, tu: TestUtils):
|
||||
logits = torch.rand(NUM_LOGITS, dtype=torch.float)
|
||||
logits = tu.rand(NUM_LOGITS)
|
||||
segment_lengths: Tensor = tu.randint(NUM_LOGITS, high=2).to(torch.int)
|
||||
segment_offsets: Tensor = torch.cumsum(segment_lengths, 0)
|
||||
segment_offsets: Tensor = torch.cat(
|
||||
|
|
|
@ -28,7 +28,7 @@ class IndexSelectSingleIdxModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexSelectSingleIdxModule())
|
||||
def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.tensor([2]))
|
||||
module.forward(tu.rand(4, 5, 6), torch.tensor([2]))
|
||||
|
||||
|
||||
class IndexSelectTwoIdxModule(torch.nn.Module):
|
||||
|
@ -47,7 +47,7 @@ class IndexSelectTwoIdxModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexSelectTwoIdxModule())
|
||||
def IndexSelectTwoIdxModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.tensor([2, 4]))
|
||||
module.forward(tu.rand(4, 5, 6), torch.tensor([2, 4]))
|
||||
|
||||
|
||||
class IndexSelectWholeDimensionModule(torch.nn.Module):
|
||||
|
@ -66,7 +66,7 @@ class IndexSelectWholeDimensionModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexSelectWholeDimensionModule())
|
||||
def IndexSelectWholeDimensionModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.tensor([0, 1, 2, 3]))
|
||||
module.forward(tu.rand(4, 5, 6), torch.tensor([0, 1, 2, 3]))
|
||||
|
||||
|
||||
class IndexSelectWholeTensorModule(torch.nn.Module):
|
||||
|
@ -85,7 +85,7 @@ class IndexSelectWholeTensorModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexSelectWholeTensorModule())
|
||||
def IndexSelectWholeTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3), torch.tensor([0, 1, 2]))
|
||||
module.forward(tu.rand(3), torch.tensor([0, 1, 2]))
|
||||
|
||||
|
||||
class IndexSelectDynamicModule(torch.nn.Module):
|
||||
|
@ -104,7 +104,7 @@ class IndexSelectDynamicModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexSelectDynamicModule())
|
||||
def IndexSelectDynamicModulebasic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.tensor([0, 4]))
|
||||
module.forward(tu.rand(4, 5, 6), torch.tensor([0, 4]))
|
||||
|
||||
|
||||
class IndexSelectDynamicInputSizeModule(torch.nn.Module):
|
||||
|
@ -123,7 +123,7 @@ class IndexSelectDynamicInputSizeModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexSelectDynamicInputSizeModule())
|
||||
def IndexSelectDynamicInputSizeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.tensor([0, 2]))
|
||||
module.forward(tu.rand(4, 5, 6), torch.tensor([0, 2]))
|
||||
|
||||
|
||||
class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
|
||||
|
@ -142,4 +142,4 @@ class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: IndexSelectDynamicIndexSizeModule())
|
||||
def IndexSelectDynamicIndexSizeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.tensor([1, 2]))
|
||||
module.forward(tu.rand(4, 5, 6), torch.tensor([1, 2]))
|
||||
|
|
|
@ -189,7 +189,7 @@ class NllLossModule_backwardWeight(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: NllLossModule_backwardWeight())
|
||||
def NllLossModuleBackwardWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.rand(4), torch.tensor(3.))
|
||||
tu.rand(4), torch.tensor(3.))
|
||||
|
||||
|
||||
|
||||
|
@ -279,7 +279,7 @@ class NllLossModule_backwardMeanWeight(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: NllLossModule_backwardMeanWeight())
|
||||
def NllLossModuleBackwardMeanWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.rand(4), torch.tensor(3.))
|
||||
tu.rand(4), torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backwardSum(torch.nn.Module):
|
||||
|
@ -338,7 +338,7 @@ class NllLossModule_backwardSumWeight(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: NllLossModule_backwardSumWeight())
|
||||
def NllLossModuleBackwardSumWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
|
||||
torch.rand(4), torch.tensor(3.))
|
||||
tu.rand(4), torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backward1D(torch.nn.Module):
|
||||
|
@ -397,7 +397,7 @@ class NllLossModule_backward1DWeight(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: NllLossModule_backward1DWeight())
|
||||
def NllLossModuleBackward1DWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.rand(3), torch.tensor(3.))
|
||||
tu.rand(3), torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backward1DMean(torch.nn.Module):
|
||||
|
@ -456,7 +456,7 @@ class NllLossModule_backward1DMeanWeight(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: NllLossModule_backward1DMeanWeight())
|
||||
def NllLossModuleBackward1DMeanWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.rand(3), torch.tensor(3.))
|
||||
tu.rand(3), torch.tensor(3.))
|
||||
|
||||
|
||||
class NllLossModule_backward1DSum(torch.nn.Module):
|
||||
|
@ -515,4 +515,4 @@ class NllLossModule_backward1DSumWeight(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: NllLossModule_backward1DSumWeight())
|
||||
def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
|
||||
torch.rand(3), torch.tensor(3.))
|
||||
tu.rand(3), torch.tensor(3.))
|
||||
|
|
|
@ -586,7 +586,7 @@ class ReduceL1NormModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceL1NormModule())
|
||||
def ReduceL1NormModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -622,7 +622,7 @@ class ReduceL2NormModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceL2NormModule())
|
||||
def ReduceL2NormModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -640,7 +640,7 @@ class ReduceLN3NormModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceLN3NormModule())
|
||||
def ReduceLN3NormModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -658,7 +658,7 @@ class ReduceL3NormAllDimsModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceL3NormAllDimsModule())
|
||||
def ReduceL3NormAllDimsModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -676,7 +676,7 @@ class ReduceL3NormKeepDimModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceL3NormKeepDimModule())
|
||||
def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -733,7 +733,7 @@ class ReduceFrobeniusNormModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceFrobeniusNormModule())
|
||||
def ReduceFrobeniusNormModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
class ReduceFrobeniusNormKeepDimModule(torch.nn.Module):
|
||||
|
@ -750,7 +750,7 @@ class ReduceFrobeniusNormKeepDimModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ReduceFrobeniusNormKeepDimModule())
|
||||
def ReduceFrobeniusNormKeepDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(3, 4, 5))
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
|
@ -75,7 +75,7 @@ class SubFloatModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SubFloatModule())
|
||||
def SubFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(()).double(), torch.rand(()).double())
|
||||
module.forward(tu.rand().double(), tu.rand().double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -146,7 +146,7 @@ class DivFloatModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: DivFloatModule())
|
||||
def DivFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(()).double(), torch.rand(()).double())
|
||||
module.forward(tu.rand().double(), tu.rand().double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -175,7 +175,7 @@ class CeilFloatModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: CeilFloatModule())
|
||||
def CeilFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(()).double(), torch.rand(()).double())
|
||||
module.forward(tu.rand().double(), tu.rand().double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
|
|
@ -121,7 +121,7 @@ class GeFloatModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: GeFloatModule())
|
||||
def GeFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(()).double(), torch.randn(()).double())
|
||||
module.forward(tu.rand().double(), tu.rand().double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -144,7 +144,7 @@ class GeFloatIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: GeFloatIntModule())
|
||||
def GeFloatIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100))
|
||||
module.forward(tu.rand().double(), tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -167,7 +167,7 @@ class NeFloatIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: NeFloatIntModule())
|
||||
def NeFloatIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100))
|
||||
module.forward(tu.rand().double(), tu.randint(low=-100, high=100))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -190,4 +190,4 @@ class GtFloatIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: GtFloatIntModule())
|
||||
def GtFloatIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100))
|
||||
module.forward(tu.rand().double(), tu.randint(low=-100, high=100))
|
||||
|
|
|
@ -384,7 +384,7 @@ class SelectScatterModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SelectScatterModule())
|
||||
def SelectScattertModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(6, 8, 5), torch.rand(8, 5))
|
||||
module.forward(tu.rand(6, 8, 5), tu.rand(8, 5))
|
||||
|
||||
class SelectScatterStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -402,7 +402,7 @@ class SelectScatterStaticModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SelectScatterStaticModule())
|
||||
def SelectScattertStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(6, 8, 5), torch.rand(6, 5))
|
||||
module.forward(tu.rand(6, 8, 5), tu.rand(6, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
|
@ -99,7 +99,7 @@ class Threshold1dFloatModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Threshold1dFloatModule())
|
||||
def Threshold1dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4))
|
||||
module.forward(tu.rand(4))
|
||||
|
||||
|
||||
class Threshold2dFloatModule(torch.nn.Module):
|
||||
|
@ -117,7 +117,7 @@ class Threshold2dFloatModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Threshold2dFloatModule())
|
||||
def Threshold2dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5))
|
||||
module.forward(tu.rand(4, 5))
|
||||
|
||||
|
||||
class Threshold3dFloatModule(torch.nn.Module):
|
||||
|
@ -135,7 +135,7 @@ class Threshold3dFloatModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: Threshold3dFloatModule())
|
||||
def Threshold3dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6))
|
||||
module.forward(tu.rand(4, 5, 6))
|
||||
|
||||
|
||||
class ThresholdBackward1dIntModule(torch.nn.Module):
|
||||
|
@ -211,7 +211,7 @@ class ThresholdBackward1dFloatModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward1dFloatModule())
|
||||
def ThresholdBackward1dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4), torch.randn(4))
|
||||
module.forward(tu.rand(4), tu.rand(4))
|
||||
|
||||
|
||||
class ThresholdBackward2dFloatModule(torch.nn.Module):
|
||||
|
@ -230,7 +230,7 @@ class ThresholdBackward2dFloatModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward2dFloatModule())
|
||||
def ThresholdBackward2dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5), torch.randn(4, 5))
|
||||
module.forward(tu.rand(4, 5), tu.rand(4, 5))
|
||||
|
||||
|
||||
class ThresholdBackward3dFloatModule(torch.nn.Module):
|
||||
|
@ -249,7 +249,7 @@ class ThresholdBackward3dFloatModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward3dFloatModule())
|
||||
def ThresholdBackward3dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.randn(4, 5, 6))
|
||||
module.forward(tu.rand(4, 5, 6), tu.rand(4, 5, 6))
|
||||
|
||||
|
||||
class ThresholdBackward1dMixedModule(torch.nn.Module):
|
||||
|
@ -268,7 +268,7 @@ class ThresholdBackward1dMixedModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule())
|
||||
def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4), tu.randint(4, high=10))
|
||||
module.forward(tu.rand(4), tu.randint(4, high=10))
|
||||
|
||||
|
||||
class ThresholdBackward2dMixedModule(torch.nn.Module):
|
||||
|
@ -287,7 +287,7 @@ class ThresholdBackward2dMixedModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule())
|
||||
def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, 5, high=20), torch.randn(4, 5))
|
||||
module.forward(tu.randint(4, 5, high=20), tu.rand(4, 5))
|
||||
|
||||
|
||||
class ThresholdBackward3dMixedModule(torch.nn.Module):
|
||||
|
@ -306,4 +306,4 @@ class ThresholdBackward3dMixedModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule())
|
||||
def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), tu.randint(4, 5, 6, high=10))
|
||||
module.forward(tu.rand(4, 5, 6), tu.randint(4, 5, 6, high=10))
|
||||
|
|
|
@ -51,7 +51,7 @@ class TypePromotionDifferentCategoryModule(torch.nn.Module):
|
|||
@register_test_case(
|
||||
module_factory=lambda: TypePromotionDifferentCategoryModule())
|
||||
def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(4, high=10), torch.randn(4))
|
||||
module.forward(tu.randint(4, high=10), tu.rand(4))
|
||||
|
||||
|
||||
class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module):
|
||||
|
|
Loading…
Reference in New Issue