Fix uniform argument type

Also change the 3rd dimension to be smaller so that CI can pass without
killing the process.
pull/567/head
Yi Zhang 2022-02-08 10:42:48 -05:00 committed by Prashant Kumar
parent 2fefe68ffd
commit 6aa96f8c1e
1 changed files with 6 additions and 6 deletions

View File

@@ -14,9 +14,9 @@ class UniformModule(torch.nn.Module):
@export
@annotate_args([
None,
([-1, -1], torch.float64, True),
([-1, -1], torch.float64, True),
([-1, -1], torch.float64, True),
([-1, -1, -1], torch.float64, True),
([-1, -1, -1], torch.float64, True),
([-1, -1, -1], torch.float64, True),
])
def forward(self, x, y, z):
a = torch.ops.aten.uniform_(x, 1.0, 10.0)
@@ -38,7 +38,7 @@ class UniformModule(torch.nn.Module):
@register_test_case(module_factory=lambda: UniformModule())
def UniformModule_basic(module, tu: TestUtils):
    # Defect fixed: the stripped diff markers had merged the pre- and
    # post-commit argument lists, producing a six-argument call to the
    # three-parameter forward(x, y, z). This is the post-commit state:
    # the 3rd dimension is kept small (8/4/4 instead of 64/128/1024) so
    # the CI run stays within memory/time limits, per the commit message.
    module.forward(
        tu.rand(256, 512, 8).double(),
        tu.rand(512, 1024, 4).double(),
        tu.rand(512, 256, 4).double())