[e2e] fix stack e2e test typo (#1931)

pull/1937/head
Yuanqiang Liu 2023-03-15 00:32:44 +08:00 committed by GitHub
parent 4912c3937d
commit b967469906
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -660,7 +660,7 @@ class TensorsStackModule(torch.nn.Module):
([-1, -1, -1], torch.float32, True), ([-1, -1, -1], torch.float32, True),
]) ])
def forward(self, x, y, z): def forward(self, x, y, z):
return torch.stack([x, y, z], 1) return torch.stack([x, y, z], dim=1)
@register_test_case(module_factory=lambda: TensorsStackModule()) @register_test_case(module_factory=lambda: TensorsStackModule())
@ -708,7 +708,7 @@ class TensorsStackPromoteDTypeModule(torch.nn.Module):
([-1, -1, -1], torch.int64, True), ([-1, -1, -1], torch.int64, True),
]) ])
def forward(self, x, y, z): def forward(self, x, y, z):
return torch.cat([x, y, z], dim=-2) return torch.stack([x, y, z], dim=-2)
@register_test_case(module_factory=lambda: TensorsStackPromoteDTypeModule()) @register_test_case(module_factory=lambda: TensorsStackPromoteDTypeModule())