Workaround to make CI pass

pull/835/head snapshot-20220509.441
Yi Zhang 2022-05-09 10:31:14 -04:00
parent 2e6a9c084e
commit 5a6210b35b
1 changed files with 6 additions and 1 deletions

View File

@ -140,7 +140,12 @@ class CeilFloatModule(torch.nn.Module):
])
def forward(self, lhs, rhs):
sub = float(lhs) - float(rhs)
return torch.ops.aten.ceil(float(sub))
# Cast the result to int to make e2e test baseline result to be an int.
# Without the cast, baseline result is a Tensor which is unexpected see
# https://github.com/llvm/torch-mlir/issues/842
# TODO: Investigate the root cause of baseline returning a Tensor
# without the int cast and remove the cast.
return int(torch.ops.aten.ceil(float(sub)))
@register_test_case(module_factory=lambda: CeilFloatModule())