mirror of https://github.com/llvm/torch-mlir
Reduce memory usage of e2e tests by reducing input sizes (#1653)
There are a few e2e tests that take several very large tensors as input, which leads to the e2e test suite leaking too much memory. Running things locally resulted in a total memory usage of 12.5 GB when running the suite sequentially on the refbackend. Many of the tests that take large tensors don't actually need such large tensors to pass, and some that take several large tensors as input are just doing the same thing multiple times. This commit reduces the size of some of the tensors and removes repetitive parts of tests to reduce the memory usage to a total of 3 GB.pull/1658/head
parent
4d49c44967
commit
a8cbfff95b
|
@ -95,7 +95,6 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"StdDimNoneDimModule_basic",
|
||||
"StdUnbiasedModule_basic",
|
||||
"UniformModule_basic",
|
||||
"UniformStaticModule_basic",
|
||||
# %1 = torch.operator "aten.scalar_tensor"(%float8.000000e00, %int6, %int0, %cpu, %none) : (!torch.float, !torch.int, !torch.int, !torch.Device, !torch.none) -> !torch.tensor
|
||||
"ElementwiseWhereScalarModule_basic",
|
||||
"ElementwiseWhereScalarOtherModule_basic",
|
||||
|
@ -750,7 +749,6 @@ LTC_XFAIL_SET = {
|
|||
"TensorToInt_basic",
|
||||
"TensorsConcatModule_basic",
|
||||
"UniformModule_basic",
|
||||
"UniformStaticModule_basic",
|
||||
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"AtenEmbeddingBagSumExample_basic",
|
||||
|
|
|
@ -44,44 +44,6 @@ def UniformModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class UniformStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([256, 512, 8], torch.float64, True),
|
||||
([512, 1024, 4], torch.float64, True),
|
||||
([512, 256, 4], torch.float64, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
a = torch.ops.aten.uniform_(x, 1.0, 10.0)
|
||||
b = torch.ops.aten.uniform_(y, -20.0, -5.0)
|
||||
c = torch.ops.aten.uniform_(z, -15.0, 3.0)
|
||||
std = torch.cat([
|
||||
torch.flatten(torch.std(a)),
|
||||
torch.flatten(torch.std(b)),
|
||||
torch.flatten(torch.std(c))
|
||||
])
|
||||
mean = torch.cat([
|
||||
torch.flatten(torch.mean(a)),
|
||||
torch.flatten(torch.mean(b)),
|
||||
torch.flatten(torch.mean(c))
|
||||
])
|
||||
return std, mean
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: UniformStaticModule())
|
||||
def UniformStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(256, 512, 8).double(),
|
||||
tu.rand(512, 1024, 4).double(),
|
||||
tu.rand(512, 256, 4).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class BernoulliModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -90,32 +52,18 @@ class BernoulliModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
def forward(self, x):
|
||||
a = torch.bernoulli(x)
|
||||
b = torch.bernoulli(y)
|
||||
c = torch.bernoulli(z)
|
||||
mean = torch.cat([
|
||||
torch.flatten(torch.mean(a)),
|
||||
torch.flatten(torch.mean(b)),
|
||||
torch.flatten(torch.mean(c))
|
||||
])
|
||||
std = torch.cat([
|
||||
torch.flatten(torch.std(a)),
|
||||
torch.flatten(torch.std(b)),
|
||||
torch.flatten(torch.std(c))
|
||||
])
|
||||
mean = torch.mean(a)
|
||||
std = torch.std(a)
|
||||
return mean, std
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BernoulliModule())
|
||||
def BernoulliModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(512, 1024, 8).double(),
|
||||
tu.rand(1024, 2048, 4).double(),
|
||||
tu.rand(1024, 256, 4).double())
|
||||
tu.rand(512, 512, 16).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -166,21 +114,17 @@ class BernoulliFloatModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
def forward(self, x, y):
|
||||
a = torch.ops.aten.bernoulli_(x, 0.4)
|
||||
b = torch.ops.aten.bernoulli_(y, 0.7)
|
||||
c = torch.ops.aten.bernoulli_(z, 0.5)
|
||||
mean = torch.cat([
|
||||
torch.flatten(torch.mean(a)),
|
||||
torch.flatten(torch.mean(b)),
|
||||
torch.flatten(torch.mean(c))
|
||||
])
|
||||
std = torch.cat([
|
||||
torch.flatten(torch.std(a)),
|
||||
torch.flatten(torch.std(b)),
|
||||
torch.flatten(torch.std(c))
|
||||
])
|
||||
return mean, std
|
||||
|
||||
|
@ -188,9 +132,8 @@ class BernoulliFloatModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: BernoulliFloatModule())
|
||||
def BernoulliFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(512, 1024, 8).double(),
|
||||
tu.rand(1024, 2048, 4).double(),
|
||||
tu.rand(1024, 512, 4).double())
|
||||
tu.rand(512, 512, 10).double(),
|
||||
tu.rand(512, 512, 10).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -203,37 +146,19 @@ class BernoulliTensorModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, x, px, y, py, z, pz):
|
||||
def forward(self, x, px):
|
||||
a = torch.ops.aten.bernoulli_(x, px)
|
||||
b = torch.ops.aten.bernoulli_(y, py)
|
||||
c = torch.ops.aten.bernoulli_(z, pz)
|
||||
mean = torch.cat([
|
||||
torch.flatten(torch.mean(a)),
|
||||
torch.flatten(torch.mean(b)),
|
||||
torch.flatten(torch.mean(c))
|
||||
])
|
||||
std = torch.cat([
|
||||
torch.flatten(torch.std(a)),
|
||||
torch.flatten(torch.std(b)),
|
||||
torch.flatten(torch.std(c))
|
||||
])
|
||||
mean = torch.mean(a)
|
||||
std = torch.std(a)
|
||||
return mean, std
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BernoulliTensorModule())
|
||||
def BernoulliTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(1024, 1024, 16).double(),
|
||||
tu.rand(1024, 1024, 16).double(),
|
||||
tu.rand(1024, 2048, 8).double(),
|
||||
tu.rand(1024, 2048, 8).double(),
|
||||
tu.rand(1024, 512, 8).double(),
|
||||
tu.rand(1024, 512, 8).double())
|
||||
tu.rand(512, 512, 2).double(),
|
||||
tu.rand(512, 512, 2).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
Loading…
Reference in New Issue