mirror of https://github.com/llvm/torch-mlir
[e2e] fix stack e2e test typo (#1931)
parent
4912c3937d
commit
b967469906
|
@ -660,7 +660,7 @@ class TensorsStackModule(torch.nn.Module):
|
|||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
return torch.stack([x, y, z], 1)
|
||||
return torch.stack([x, y, z], dim=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorsStackModule())
|
||||
|
@ -708,7 +708,7 @@ class TensorsStackPromoteDTypeModule(torch.nn.Module):
|
|||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
return torch.cat([x, y, z], dim=-2)
|
||||
return torch.stack([x, y, z], dim=-2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorsStackPromoteDTypeModule())
|
||||
|
|
Loading…
Reference in New Issue