import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class UniformModule(torch.nn.Module):
    """E2e test module for `aten.uniform_` with dynamic input shapes.

    Fills three float64 tensors in place from three different uniform
    ranges and returns scalar statistics (std, mean) of each fill, so the
    test harness can compare distribution properties rather than raw
    random values.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
    ])
    def forward(self, x, y, z):
        # In-place uniform fill, each input over a distinct [low, high) range.
        filled_x = torch.ops.aten.uniform_(x, 1.0, 10.0)
        filled_y = torch.ops.aten.uniform_(y, -20.0, -5.0)
        filled_z = torch.ops.aten.uniform_(z, -15.0, 3.0)
        # Collapse each sample set to a 1-element tensor and concatenate, so
        # the outputs are small deterministic-shape tensors of statistics.
        std = torch.cat([
            torch.flatten(torch.std(filled_x)),
            torch.flatten(torch.std(filled_y)),
            torch.flatten(torch.std(filled_z)),
        ])
        mean = torch.cat([
            torch.flatten(torch.mean(filled_x)),
            torch.flatten(torch.mean(filled_y)),
            torch.flatten(torch.mean(filled_z)),
        ])
        return std, mean
@register_test_case(module_factory=lambda: UniformModule())
def UniformModule_basic(module, tu: TestUtils):
    """Drive UniformModule with large double inputs for stable statistics."""
    # Big sample counts keep the estimated mean/std close to the true values.
    module.forward(tu.rand(256, 512, 8).double(),
                   tu.rand(512, 1024, 4).double(),
                   tu.rand(512, 256, 4).double())
# ==============================================================================
class UniformStaticModule(torch.nn.Module):
    """E2e test module for `aten.uniform_` with fully static input shapes.

    Same computation as UniformModule, but the argument annotations pin
    concrete shapes so the static-shape compilation path is exercised.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([256, 512, 8], torch.float64, True),
        ([512, 1024, 4], torch.float64, True),
        ([512, 256, 4], torch.float64, True),
    ])
    def forward(self, x, y, z):
        # In-place uniform fill, each input over a distinct [low, high) range.
        filled_x = torch.ops.aten.uniform_(x, 1.0, 10.0)
        filled_y = torch.ops.aten.uniform_(y, -20.0, -5.0)
        filled_z = torch.ops.aten.uniform_(z, -15.0, 3.0)
        # Return concatenated scalar statistics instead of the raw samples.
        std = torch.cat([
            torch.flatten(torch.std(filled_x)),
            torch.flatten(torch.std(filled_y)),
            torch.flatten(torch.std(filled_z)),
        ])
        mean = torch.cat([
            torch.flatten(torch.mean(filled_x)),
            torch.flatten(torch.mean(filled_y)),
            torch.flatten(torch.mean(filled_z)),
        ])
        return std, mean
@register_test_case(module_factory=lambda: UniformStaticModule())
def UniformStaticModule_basic(module, tu: TestUtils):
    """Drive UniformStaticModule with inputs matching its static shapes."""
    # Shapes must agree exactly with the annotate_args declarations above.
    module.forward(tu.rand(256, 512, 8).double(),
                   tu.rand(512, 1024, 4).double(),
                   tu.rand(512, 256, 4).double())
# ==============================================================================
class BernoulliModule(torch.nn.Module):
    """E2e test module for `torch.bernoulli` with tensor probabilities.

    Samples three Bernoulli tensors whose per-element probabilities come
    from the inputs, then returns scalar statistics (mean, std) so the
    harness can check distribution properties.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
    ])
    def forward(self, x, y, z):
        # Each input element is used as its own success probability.
        draws_x = torch.bernoulli(x)
        draws_y = torch.bernoulli(y)
        draws_z = torch.bernoulli(z)
        # Reduce to concatenated scalar statistics for comparison.
        mean = torch.cat([
            torch.flatten(torch.mean(draws_x)),
            torch.flatten(torch.mean(draws_y)),
            torch.flatten(torch.mean(draws_z)),
        ])
        std = torch.cat([
            torch.flatten(torch.std(draws_x)),
            torch.flatten(torch.std(draws_y)),
            torch.flatten(torch.std(draws_z)),
        ])
        return mean, std
@register_test_case(module_factory=lambda: BernoulliModule())
def BernoulliModule_basic(module, tu: TestUtils):
    """Drive BernoulliModule with large uniform-[0,1) probability tensors."""
    # tu.rand produces values in [0, 1), which are valid probabilities.
    module.forward(tu.rand(512, 1024, 8).double(),
                   tu.rand(1024, 2048, 4).double(),
                   tu.rand(1024, 256, 4).double())
# ==============================================================================
class BernoulliZerosModule(torch.nn.Module):
    """Degenerate Bernoulli case: probability 0 everywhere.

    With an all-zeros probability tensor the draw is deterministic (all
    zeros), so the raw samples themselves can be compared exactly.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float64, True),
    ])
    def forward(self, x):
        return torch.bernoulli(x)
@register_test_case(module_factory=lambda: BernoulliZerosModule())
def BernoulliZerosModule_basic(module, tu: TestUtils):
    """All-zero probabilities make the Bernoulli output deterministic."""
    zeros = torch.zeros(4, 8).double()
    module.forward(zeros)
# ==============================================================================
class BernoulliOnesModule(torch.nn.Module):
    """Degenerate Bernoulli case: probability 1 everywhere.

    With an all-ones probability tensor the draw is deterministic (all
    ones), so the raw samples themselves can be compared exactly.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float64, True),
    ])
    def forward(self, x):
        return torch.bernoulli(x)
@register_test_case(module_factory=lambda: BernoulliOnesModule())
def BernoulliOnesModule_basic(module, tu: TestUtils):
    """All-one probabilities make the Bernoulli output deterministic."""
    ones = torch.ones(4, 8).double()
    module.forward(ones)
# ==============================================================================
class BernoulliFloatModule(torch.nn.Module):
    """E2e test module for in-place `aten.bernoulli_` with scalar probabilities.

    Overwrites three float64 tensors in place with Bernoulli draws at
    fixed probabilities, and returns scalar statistics (mean, std) so the
    harness can check the distributions rather than raw random values.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
    ])
    def forward(self, x, y, z):
        # In-place Bernoulli fill, one fixed success probability per input.
        draws_x = torch.ops.aten.bernoulli_(x, 0.4)
        draws_y = torch.ops.aten.bernoulli_(y, 0.7)
        draws_z = torch.ops.aten.bernoulli_(z, 0.5)
        # Reduce to concatenated scalar statistics for comparison.
        mean = torch.cat([
            torch.flatten(torch.mean(draws_x)),
            torch.flatten(torch.mean(draws_y)),
            torch.flatten(torch.mean(draws_z)),
        ])
        std = torch.cat([
            torch.flatten(torch.std(draws_x)),
            torch.flatten(torch.std(draws_y)),
            torch.flatten(torch.std(draws_z)),
        ])
        return mean, std
@register_test_case(module_factory=lambda: BernoulliFloatModule())
def BernoulliFloatModule_basic(module, tu: TestUtils):
    """Drive BernoulliFloatModule with large inputs for stable statistics."""
    # Input values are irrelevant: bernoulli_ overwrites them in place.
    module.forward(tu.rand(512, 1024, 8).double(),
                   tu.rand(1024, 2048, 4).double(),
                   tu.rand(1024, 512, 4).double())
# ==============================================================================
class BernoulliTensorModule(torch.nn.Module):
    """E2e test module for in-place `aten.bernoulli_` with tensor probabilities.

    Each data tensor is paired with a same-shape probability tensor; the
    data is overwritten in place with the draws. Returns scalar statistics
    (mean, std) so the harness can compare distribution properties.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
    ])
    def forward(self, x, px, y, py, z, pz):
        # In-place Bernoulli fill with element-wise probability tensors.
        draws_x = torch.ops.aten.bernoulli_(x, px)
        draws_y = torch.ops.aten.bernoulli_(y, py)
        draws_z = torch.ops.aten.bernoulli_(z, pz)
        # Reduce to concatenated scalar statistics for comparison.
        mean = torch.cat([
            torch.flatten(torch.mean(draws_x)),
            torch.flatten(torch.mean(draws_y)),
            torch.flatten(torch.mean(draws_z)),
        ])
        std = torch.cat([
            torch.flatten(torch.std(draws_x)),
            torch.flatten(torch.std(draws_y)),
            torch.flatten(torch.std(draws_z)),
        ])
        return mean, std
@register_test_case(module_factory=lambda: BernoulliTensorModule())
def BernoulliTensorModule_basic(module, tu: TestUtils):
    """Drive BernoulliTensorModule with matched data/probability pairs."""
    # Each data tensor is followed by its same-shape probability tensor;
    # tu.rand values lie in [0, 1), so they are valid probabilities.
    module.forward(tu.rand(1024, 1024, 16).double(),
                   tu.rand(1024, 1024, 16).double(),
                   tu.rand(1024, 2048, 8).double(),
                   tu.rand(1024, 2048, 8).double(),
                   tu.rand(1024, 512, 8).double(),
                   tu.rand(1024, 512, 8).double())