mirror of https://github.com/llvm/torch-mlir
Fix uniform argument type
Also change the 3rd dimension to be smaller so that CI can pass without killing the process.pull/567/head
parent
2fefe68ffd
commit
6aa96f8c1e
|
@ -14,9 +14,9 @@ class UniformModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float64, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
a = torch.ops.aten.uniform_(x, 1.0, 10.0)
|
||||
|
@ -38,7 +38,7 @@ class UniformModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: UniformModule())
|
||||
def UniformModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(256, 512, 64).double(),
|
||||
tu.rand(512, 1024, 128).double(),
|
||||
tu.rand(512, 256, 1024).double())
|
||||
tu.rand(256, 512, 8).double(),
|
||||
tu.rand(512, 1024, 4).double(),
|
||||
tu.rand(512, 256, 4).double())
|
||||
|
||||
|
|
Loading…
Reference in New Issue