mirror of https://github.com/llvm/torch-mlir
[e2e] fix stack e2e test typo (#1931)
parent
4912c3937d
commit
b967469906
|
@ -660,7 +660,7 @@ class TensorsStackModule(torch.nn.Module):
|
||||||
([-1, -1, -1], torch.float32, True),
|
([-1, -1, -1], torch.float32, True),
|
||||||
])
|
])
|
||||||
def forward(self, x, y, z):
|
def forward(self, x, y, z):
|
||||||
return torch.stack([x, y, z], 1)
|
return torch.stack([x, y, z], dim=1)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: TensorsStackModule())
|
@register_test_case(module_factory=lambda: TensorsStackModule())
|
||||||
|
@ -708,7 +708,7 @@ class TensorsStackPromoteDTypeModule(torch.nn.Module):
|
||||||
([-1, -1, -1], torch.int64, True),
|
([-1, -1, -1], torch.int64, True),
|
||||||
])
|
])
|
||||||
def forward(self, x, y, z):
|
def forward(self, x, y, z):
|
||||||
return torch.cat([x, y, z], dim=-2)
|
return torch.stack([x, y, z], dim=-2)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: TensorsStackPromoteDTypeModule())
|
@register_test_case(module_factory=lambda: TensorsStackPromoteDTypeModule())
|
||||||
|
|
Loading…
Reference in New Issue