diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 47d7c0cb3..8326d11c8 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -95,7 +95,6 @@ TORCHDYNAMO_XFAIL_SET = { "StdDimNoneDimModule_basic", "StdUnbiasedModule_basic", "UniformModule_basic", - "UniformStaticModule_basic", # %1 = torch.operator "aten.scalar_tensor"(%float8.000000e00, %int6, %int0, %cpu, %none) : (!torch.float, !torch.int, !torch.int, !torch.Device, !torch.none) -> !torch.tensor "ElementwiseWhereScalarModule_basic", "ElementwiseWhereScalarOtherModule_basic", @@ -750,7 +749,6 @@ LTC_XFAIL_SET = { "TensorToInt_basic", "TensorsConcatModule_basic", "UniformModule_basic", - "UniformStaticModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "AtenEmbeddingBagSumExample_basic", diff --git a/python/torch_mlir_e2e_test/test_suite/rng.py b/python/torch_mlir_e2e_test/test_suite/rng.py index e9a8898f8..2fc1444ff 100644 --- a/python/torch_mlir_e2e_test/test_suite/rng.py +++ b/python/torch_mlir_e2e_test/test_suite/rng.py @@ -44,44 +44,6 @@ def UniformModule_basic(module, tu: TestUtils): # ============================================================================== -class UniformStaticModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([256, 512, 8], torch.float64, True), - ([512, 1024, 4], torch.float64, True), - ([512, 256, 4], torch.float64, True), - ]) - def forward(self, x, y, z): - a = torch.ops.aten.uniform_(x, 1.0, 10.0) - b = torch.ops.aten.uniform_(y, -20.0, -5.0) - c = torch.ops.aten.uniform_(z, -15.0, 3.0) - std = torch.cat([ - torch.flatten(torch.std(a)), - torch.flatten(torch.std(b)), - torch.flatten(torch.std(c)) - ]) - mean = torch.cat([ - torch.flatten(torch.mean(a)), - torch.flatten(torch.mean(b)), - torch.flatten(torch.mean(c)) - ]) - return std, mean - - -@register_test_case(module_factory=lambda: UniformStaticModule()) -def UniformStaticModule_basic(module, tu: TestUtils): - module.forward( - tu.rand(256, 512, 8).double(), - tu.rand(512, 1024, 4).double(), - tu.rand(512, 256, 4).double()) - -# ============================================================================== - class BernoulliModule(torch.nn.Module): def __init__(self): super().__init__() @@ -90,32 +52,18 @@ class BernoulliModule(torch.nn.Module): @annotate_args([ None, ([-1, -1, -1], torch.float64, True), - ([-1, -1, -1], torch.float64, True), - ([-1, -1, -1], torch.float64, True), ]) - def forward(self, x, y, z): + def forward(self, x): a = torch.bernoulli(x) - b = torch.bernoulli(y) - c = torch.bernoulli(z) - mean = torch.cat([ - torch.flatten(torch.mean(a)), - torch.flatten(torch.mean(b)), - torch.flatten(torch.mean(c)) - ]) - std = torch.cat([ - torch.flatten(torch.std(a)), - torch.flatten(torch.std(b)), - torch.flatten(torch.std(c)) - ]) + mean = torch.mean(a) + std = torch.std(a) return mean, std @register_test_case(module_factory=lambda: BernoulliModule()) def BernoulliModule_basic(module, tu: TestUtils): module.forward( - tu.rand(512, 1024, 8).double(), - tu.rand(1024, 2048, 4).double(), - tu.rand(1024, 256, 4).double()) + tu.rand(512, 512, 16).double()) # ============================================================================== @@ -166,21 +114,17 @@ class BernoulliFloatModule(torch.nn.Module): None, ([-1, -1, -1], torch.float64, True), ([-1, -1, -1], torch.float64, True), - ([-1, -1, -1], torch.float64, True), ]) - def forward(self, x, y, z): + def forward(self, x, y): a = torch.ops.aten.bernoulli_(x, 0.4) b = torch.ops.aten.bernoulli_(y, 0.7) - c = torch.ops.aten.bernoulli_(z, 0.5) mean = torch.cat([ torch.flatten(torch.mean(a)), torch.flatten(torch.mean(b)), - torch.flatten(torch.mean(c)) ]) std = torch.cat([ torch.flatten(torch.std(a)), torch.flatten(torch.std(b)), - torch.flatten(torch.std(c)) ]) return mean, std @@ -188,9 +132,8 @@ class BernoulliFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: BernoulliFloatModule()) def BernoulliFloatModule_basic(module, tu: TestUtils): module.forward( - tu.rand(512, 1024, 8).double(), - tu.rand(1024, 2048, 4).double(), - tu.rand(1024, 512, 4).double()) + tu.rand(512, 512, 10).double(), + tu.rand(512, 512, 10).double()) # ============================================================================== @@ -203,37 +146,19 @@ class BernoulliTensorModule(torch.nn.Module): None, ([-1, -1, -1], torch.float64, True), ([-1, -1, -1], torch.float64, True), - ([-1, -1, -1], torch.float64, True), - ([-1, -1, -1], torch.float64, True), - ([-1, -1, -1], torch.float64, True), - ([-1, -1, -1], torch.float64, True), ]) - def forward(self, x, px, y, py, z, pz): + def forward(self, x, px): a = torch.ops.aten.bernoulli_(x, px) - b = torch.ops.aten.bernoulli_(y, py) - c = torch.ops.aten.bernoulli_(z, pz) - mean = torch.cat([ - torch.flatten(torch.mean(a)), - torch.flatten(torch.mean(b)), - torch.flatten(torch.mean(c)) - ]) - std = torch.cat([ - torch.flatten(torch.std(a)), - torch.flatten(torch.std(b)), - torch.flatten(torch.std(c)) - ]) + mean = torch.mean(a) + std = torch.std(a) return mean, std @register_test_case(module_factory=lambda: BernoulliTensorModule()) def BernoulliTensorModule_basic(module, tu: TestUtils): module.forward( - tu.rand(1024, 1024, 16).double(), - tu.rand(1024, 1024, 16).double(), - tu.rand(1024, 2048, 8).double(), - tu.rand(1024, 2048, 8).double(), - tu.rand(1024, 512, 8).double(), - tu.rand(1024, 512, 8).double()) + tu.rand(512, 512, 2).double(), + tu.rand(512, 512, 2).double()) # ==============================================================================