Reduce memory usage of e2e tests by reducing input sizes (#1653)

A few e2e tests take several very large tensors as input, which
leads to the e2e test suite using too much memory. Running the
suite sequentially on the refbackend locally resulted in a total
memory usage of 12.5 GB.
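For reference, one way to get a comparable number on Linux is to check the process's peak resident set size after a sequential run. This is only an illustrative sketch using Python's standard `resource` module, not necessarily how the figure above was obtained:

import resource

def peak_rss_gib() -> float:
    # On Linux, ru_maxrss is reported in kilobytes.
    kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return kb / (1024 ** 2)

# ... run the e2e suite sequentially in this process, then:
print(f"peak RSS so far: {peak_rss_gib():.1f} GiB")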

Many of the tests that take large tensors don't actually need
tensors that large to pass, and some that take several large
tensors as input simply repeat the same computation multiple
times. This commit reduces the size of some of these tensors and
removes the repetitive parts of the tests, bringing the total
memory usage down to 3 GB.
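As a rough, illustrative calculation (not part of the change): each float64 element takes 8 bytes, so the input shapes alone account for a large share of the footprint. The `tensor_mib` helper below is hypothetical and only counts the explicit inputs of BernoulliTensorModule_basic before and after this commit, ignoring intermediates and the rest of the suite:

def tensor_mib(*shape, bytes_per_elem=8):
    # Size in MiB of a dense tensor with the given shape (float64 = 8 bytes per element).
    n = 1
    for d in shape:
        n *= d
    return n * bytes_per_elem / 2**20

# Before: six double tensors passed to BernoulliTensorModule_basic.
before = (2 * tensor_mib(1024, 1024, 16)
          + 2 * tensor_mib(1024, 2048, 8)
          + 2 * tensor_mib(1024, 512, 8))
# After: two much smaller double tensors.
after = 2 * tensor_mib(512, 512, 2)
print(f"inputs before: {before:.0f} MiB, after: {after:.0f} MiB")
# inputs before: 576 MiB, after: 8 MiB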
Ramiro Leal-Cavazos 2022-11-29 10:03:36 -08:00 committed by GitHub
parent 4d49c44967
commit a8cbfff95b
2 changed files with 12 additions and 89 deletions


@@ -95,7 +95,6 @@ TORCHDYNAMO_XFAIL_SET = {
     "StdDimNoneDimModule_basic",
     "StdUnbiasedModule_basic",
     "UniformModule_basic",
-    "UniformStaticModule_basic",
     # %1 = torch.operator "aten.scalar_tensor"(%float8.000000e00, %int6, %int0, %cpu, %none) : (!torch.float, !torch.int, !torch.int, !torch.Device, !torch.none) -> !torch.tensor
     "ElementwiseWhereScalarModule_basic",
     "ElementwiseWhereScalarOtherModule_basic",
@@ -750,7 +749,6 @@ LTC_XFAIL_SET = {
     "TensorToInt_basic",
     "TensorsConcatModule_basic",
     "UniformModule_basic",
-    "UniformStaticModule_basic",
     "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
     "ViewCollapseDynamicWithAtenSizeIntModule_basic",
     "AtenEmbeddingBagSumExample_basic",


@@ -44,44 +44,6 @@ def UniformModule_basic(module, tu: TestUtils):
 # ==============================================================================
-class UniformStaticModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-    @export
-    @annotate_args([
-        None,
-        ([256, 512, 8], torch.float64, True),
-        ([512, 1024, 4], torch.float64, True),
-        ([512, 256, 4], torch.float64, True),
-    ])
-    def forward(self, x, y, z):
-        a = torch.ops.aten.uniform_(x, 1.0, 10.0)
-        b = torch.ops.aten.uniform_(y, -20.0, -5.0)
-        c = torch.ops.aten.uniform_(z, -15.0, 3.0)
-        std = torch.cat([
-            torch.flatten(torch.std(a)),
-            torch.flatten(torch.std(b)),
-            torch.flatten(torch.std(c))
-        ])
-        mean = torch.cat([
-            torch.flatten(torch.mean(a)),
-            torch.flatten(torch.mean(b)),
-            torch.flatten(torch.mean(c))
-        ])
-        return std, mean
-@register_test_case(module_factory=lambda: UniformStaticModule())
-def UniformStaticModule_basic(module, tu: TestUtils):
-    module.forward(
-        tu.rand(256, 512, 8).double(),
-        tu.rand(512, 1024, 4).double(),
-        tu.rand(512, 256, 4).double())
-# ==============================================================================
 class BernoulliModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -90,32 +52,18 @@ class BernoulliModule(torch.nn.Module):
     @annotate_args([
         None,
         ([-1, -1, -1], torch.float64, True),
-        ([-1, -1, -1], torch.float64, True),
-        ([-1, -1, -1], torch.float64, True),
     ])
-    def forward(self, x, y, z):
+    def forward(self, x):
         a = torch.bernoulli(x)
-        b = torch.bernoulli(y)
-        c = torch.bernoulli(z)
-        mean = torch.cat([
-            torch.flatten(torch.mean(a)),
-            torch.flatten(torch.mean(b)),
-            torch.flatten(torch.mean(c))
-        ])
-        std = torch.cat([
-            torch.flatten(torch.std(a)),
-            torch.flatten(torch.std(b)),
-            torch.flatten(torch.std(c))
-        ])
+        mean = torch.mean(a)
+        std = torch.std(a)
         return mean, std
 @register_test_case(module_factory=lambda: BernoulliModule())
 def BernoulliModule_basic(module, tu: TestUtils):
     module.forward(
-        tu.rand(512, 1024, 8).double(),
-        tu.rand(1024, 2048, 4).double(),
-        tu.rand(1024, 256, 4).double())
+        tu.rand(512, 512, 16).double())
 # ==============================================================================
@@ -166,21 +114,17 @@ class BernoulliFloatModule(torch.nn.Module):
         None,
         ([-1, -1, -1], torch.float64, True),
         ([-1, -1, -1], torch.float64, True),
-        ([-1, -1, -1], torch.float64, True),
     ])
-    def forward(self, x, y, z):
+    def forward(self, x, y):
         a = torch.ops.aten.bernoulli_(x, 0.4)
         b = torch.ops.aten.bernoulli_(y, 0.7)
-        c = torch.ops.aten.bernoulli_(z, 0.5)
         mean = torch.cat([
             torch.flatten(torch.mean(a)),
             torch.flatten(torch.mean(b)),
-            torch.flatten(torch.mean(c))
         ])
         std = torch.cat([
             torch.flatten(torch.std(a)),
             torch.flatten(torch.std(b)),
-            torch.flatten(torch.std(c))
         ])
         return mean, std
@@ -188,9 +132,8 @@ class BernoulliFloatModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: BernoulliFloatModule())
 def BernoulliFloatModule_basic(module, tu: TestUtils):
     module.forward(
-        tu.rand(512, 1024, 8).double(),
-        tu.rand(1024, 2048, 4).double(),
-        tu.rand(1024, 512, 4).double())
+        tu.rand(512, 512, 10).double(),
+        tu.rand(512, 512, 10).double())
 # ==============================================================================
@@ -203,37 +146,19 @@ class BernoulliTensorModule(torch.nn.Module):
         None,
         ([-1, -1, -1], torch.float64, True),
         ([-1, -1, -1], torch.float64, True),
-        ([-1, -1, -1], torch.float64, True),
-        ([-1, -1, -1], torch.float64, True),
-        ([-1, -1, -1], torch.float64, True),
-        ([-1, -1, -1], torch.float64, True),
     ])
-    def forward(self, x, px, y, py, z, pz):
+    def forward(self, x, px):
         a = torch.ops.aten.bernoulli_(x, px)
-        b = torch.ops.aten.bernoulli_(y, py)
-        c = torch.ops.aten.bernoulli_(z, pz)
-        mean = torch.cat([
-            torch.flatten(torch.mean(a)),
-            torch.flatten(torch.mean(b)),
-            torch.flatten(torch.mean(c))
-        ])
-        std = torch.cat([
-            torch.flatten(torch.std(a)),
-            torch.flatten(torch.std(b)),
-            torch.flatten(torch.std(c))
-        ])
+        mean = torch.mean(a)
+        std = torch.std(a)
         return mean, std
 @register_test_case(module_factory=lambda: BernoulliTensorModule())
 def BernoulliTensorModule_basic(module, tu: TestUtils):
     module.forward(
-        tu.rand(1024, 1024, 16).double(),
-        tu.rand(1024, 1024, 16).double(),
-        tu.rand(1024, 2048, 8).double(),
-        tu.rand(1024, 2048, 8).double(),
-        tu.rand(1024, 512, 8).double(),
-        tu.rand(1024, 512, 8).double())
+        tu.rand(512, 512, 2).double(),
+        tu.rand(512, 512, 2).double())
 # ==============================================================================