# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
|
|
|
|
|
|
|
|
import sys
|
|
|
|
import textwrap
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
# Prefix that _indent() hands to textwrap.indent when nesting printed values.
INDENT = " "
|
|
|
|
|
|
|
|
|
|
|
|
def _indent(value):
    """Return the string form of ``value`` indented with ``INDENT``.

    ``value`` is converted with ``str()`` and the result is passed through
    ``textwrap.indent`` using the module-level ``INDENT`` as the prefix.
    """
    text = str(value)
    return textwrap.indent(text, prefix=INDENT)
|
|
|
|
|
|
|
|
|
|
|
|
def compare_outputs(torch_func, jit_func, *args):
    """Run both callables on ``args`` and assert that their outputs agree.

    Args:
        torch_func: Callable invoked as ``torch_func(*args)``; its return
            value must provide a ``numpy()`` method.
        jit_func: Callable invoked as ``jit_func(*args)``; its return value
            is handed to NumPy unchanged.
        *args: Positional arguments forwarded as-is to both callables.

    Raises:
        AssertionError: Raised by ``np.testing.assert_allclose`` when the two
            results differ beyond ``rtol=1e-05`` / ``atol=1e-08``.
    """

    def _report(message):
        # Diagnostics go to stderr; only the separator line uses stdout.
        print(message, file=sys.stderr)

    separator = '-' * 80
    print(separator)

    # The inputs are rendered and reported before either callable runs.
    _report(f"Input args:\n{_indent(args)}")
    expected = torch_func(*args)
    actual = jit_func(*args)

    np.testing.assert_allclose(
        expected.numpy(), actual, rtol=1e-05, atol=1e-08)

    # Reaching this point means the comparison passed. The results are only
    # echoed here because np.testing already prints them when it fails.
    _report("SUCCESS")
    _report(f"PyTorch Result:\n{_indent(expected.numpy())}")
    _report(f"JIT Result:\n{_indent(actual)}")
|