Add runtime check

pull/328/head
Yi Zhang 2021-09-23 17:50:37 -04:00
parent c9cc4cb2e9
commit cd7053dfde
5 changed files with 73 additions and 7 deletions

View File

@ -36,10 +36,10 @@ def MmModule_basic(module, tu: TestUtils):
# are mixed with it, it fails with a mysterious-sounding low level ctypes error
# that exceeds my current ability to debug.
#
# @register_test_case(module_factory=lambda: MmModule())
# def MmModule_chained(module, tu: TestUtils):
# res = module.forward(tu.rand(4, 4), tu.rand(4, 4))
# module.forward(res, res)
@register_test_case(module_factory=lambda: MmModule())
def MmModule_chained(module, tu: TestUtils):
res = module.forward(tu.rand(4, 4), tu.rand(4, 4))
module.forward(res, res)
# ==============================================================================

View File

@ -18,7 +18,11 @@ _common_npcomp_lowering_xfails = {
'QuantizedMLP_basic',
}
XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails
XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails | {
# The first test in the e2e test batch would fail with SystemError: null
# argument to internal routine. Might be some issue with refbackend.
'MmModule_basic',
}
XFAIL_SETS['torchscript'] = {}

View File

@ -0,0 +1,42 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# RUN: %PYTHON %s | FileCheck %s
import torch
from npcomp_torchscript.e2e_test.framework import run_tests, TestUtils
from npcomp_torchscript.e2e_test.reporting import report_results
from npcomp_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
from npcomp_torchscript_e2e_test_configs import TorchScriptTestConfig
class MmModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, t):
# Input of `torch.tensor` only allows ints, floats, or bools while empty
# list defaults to tensor type
return torch.tensor([])
# CHECK: FAIL - "MmModule_basic"
# CHECK: Runtime error:
# Assume that the diagnostic from the TorchScript runtime will at least contain
# the offending "return torch.tensor([])".
# CHECK: return torch.tensor([])
@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils):
module.forward(torch.ones([]))
def main():
config = TorchScriptTestConfig()
results = run_tests(GLOBAL_TEST_REGISTRY, config)
report_results(results, set(), verbose=True)
if __name__ == '__main__':
main()

View File

@ -242,6 +242,10 @@ class TestResult(NamedTuple):
# If this is not None, then the `trace` and `golden_trace` fields are None,
# and vice-versa.
compilation_error: Optional[str]
# If runtime failed, a string describing the failure.
# If this is not None, then the `trace` and `golden_trace` fields are None,
# and vice-versa.
runtime_error: Optional[str]
# The trace produced by the backend.
trace: Optional[Trace]
# The golden trace which `trace` is expected to match.
@ -303,14 +307,26 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
results.append(
TestResult(unique_name=test.unique_name,
compilation_error=str(e),
runtime_error=None,
trace=None,
golden_trace=None))
continue
# TODO: Run in parallel.
try:
trace = config.run(compiled, golden_trace)
except Exception as e:
results.append(
TestResult(unique_name=test.unique_name,
compilation_error=None,
runtime_error=str(e),
trace=None,
golden_trace=None))
continue
results.append(
TestResult(unique_name=test.unique_name,
compilation_error=None,
runtime_error=None,
trace=trace,
golden_trace=golden_trace))
return results

View File

@ -222,7 +222,7 @@ class SingleTestReport:
self.result = result
self.context = context
self.item_reports = None
if result.compilation_error is None:
if result.compilation_error is None and result.runtime_error is None:
self.item_reports = []
for i, (item, golden_item) in enumerate(
zip(result.trace, result.golden_trace)):
@ -236,6 +236,8 @@ class SingleTestReport:
def failed(self):
if self.result.compilation_error is not None:
return True
elif self.result.runtime_error is not None:
return True
return any(r.failed for r in self.item_reports)
def error_str(self):
@ -244,6 +246,8 @@ class SingleTestReport:
p = lambda *x: print(*x, file=f)
if self.result.compilation_error is not None:
return 'Compilation error: ' + self.result.compilation_error
elif self.result.runtime_error is not None:
return 'Runtime error: ' + self.result.runtime_error
for report in self.item_reports:
if report.failed:
p(report.error_str())