From cd7053dfde3946623fb229ca76ea65671d3c5d92 Mon Sep 17 00:00:00 2001
From: Yi Zhang
Date: Thu, 23 Sep 2021 17:50:37 -0400
Subject: [PATCH] Add runtime check

---
 e2e_testing/torchscript/basic.py              |  8 ++--
 e2e_testing/torchscript/xfail_sets.py         |  6 ++-
 .../torchscript_e2e_test/runtime_failure.py   | 42 +++++++++++++++++++
 .../npcomp_torchscript/e2e_test/framework.py  | 18 +++++++-
 .../npcomp_torchscript/e2e_test/reporting.py  |  6 ++-
 5 files changed, 73 insertions(+), 7 deletions(-)
 create mode 100644 python/test/torchscript_e2e_test/runtime_failure.py

diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index eb8dddef1..0a92fdb0e 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -36,10 +36,10 @@ def MmModule_basic(module, tu: TestUtils):
 # are mixed with it, it fails with a mysterious-sounding low level ctypes error
 # that exceeds my current ability to debug.
 #
-# @register_test_case(module_factory=lambda: MmModule())
-# def MmModule_chained(module, tu: TestUtils):
-#     res = module.forward(tu.rand(4, 4), tu.rand(4, 4))
-#     module.forward(res, res)
+@register_test_case(module_factory=lambda: MmModule())
+def MmModule_chained(module, tu: TestUtils):
+    res = module.forward(tu.rand(4, 4), tu.rand(4, 4))
+    module.forward(res, res)
 
 
 # ==============================================================================
diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index c00f49605..af1c34ff0 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -18,7 +18,11 @@ _common_npcomp_lowering_xfails = {
     'QuantizedMLP_basic',
 }
 
-XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails
+XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails | {
+    # The first test in the e2e test batch would fail with SystemError: null
+    # argument to internal routine. Might be some issue with refbackend.
+    'MmModule_basic',
+}
 
 XFAIL_SETS['torchscript'] = {}
 
diff --git a/python/test/torchscript_e2e_test/runtime_failure.py b/python/test/torchscript_e2e_test/runtime_failure.py
new file mode 100644
index 000000000..ed93b4a22
--- /dev/null
+++ b/python/test/torchscript_e2e_test/runtime_failure.py
@@ -0,0 +1,42 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+
+from npcomp_torchscript.e2e_test.framework import run_tests, TestUtils
+from npcomp_torchscript.e2e_test.reporting import report_results
+from npcomp_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
+from npcomp_torchscript_e2e_test_configs import TorchScriptTestConfig
+
+
+class MmModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, t):
+        # Input of `torch.tensor` only allows ints, floats, or bools while empty
+        # list defaults to tensor type
+        return torch.tensor([])
+
+
+# CHECK: FAIL - "MmModule_basic"
+# CHECK: Runtime error:
+# Assume that the diagnostic from the TorchScript runtime will at least contain
+# the offending "return torch.tensor([])".
+# CHECK: return torch.tensor([])
+@register_test_case(module_factory=lambda: MmModule())
+def MmModule_basic(module, tu: TestUtils):
+    module.forward(torch.ones([]))
+
+
+def main():
+    config = TorchScriptTestConfig()
+    results = run_tests(GLOBAL_TEST_REGISTRY, config)
+    report_results(results, set(), verbose=True)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/python/torch_support/npcomp_torchscript/e2e_test/framework.py b/python/torch_support/npcomp_torchscript/e2e_test/framework.py
index 1c5a25498..3cbc6c0a2 100644
--- a/python/torch_support/npcomp_torchscript/e2e_test/framework.py
+++ b/python/torch_support/npcomp_torchscript/e2e_test/framework.py
@@ -242,6 +242,10 @@ class TestResult(NamedTuple):
     # If this is not None, then the `trace` and `golden_trace` fields are None,
     # and vice-versa.
     compilation_error: Optional[str]
+    # If runtime failed, a string describing the failure.
+    # If this is not None, then the `trace` and `golden_trace` fields are None,
+    # and vice-versa.
+    runtime_error: Optional[str]
     # The trace produced by the backend.
     trace: Optional[Trace]
     # The golden trace which `trace` is expected to match.
@@ -303,14 +307,26 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
             results.append(
                 TestResult(unique_name=test.unique_name,
                            compilation_error=str(e),
+                           runtime_error=None,
                            trace=None,
                            golden_trace=None))
             continue
         # TODO: Run in parallel.
-        trace = config.run(compiled, golden_trace)
+        try:
+            trace = config.run(compiled, golden_trace)
+        except Exception as e:
+            results.append(
+                TestResult(unique_name=test.unique_name,
+                           compilation_error=None,
+                           runtime_error=str(e),
+                           trace=None,
+                           golden_trace=None))
+            continue
+
         results.append(
             TestResult(unique_name=test.unique_name,
                        compilation_error=None,
+                       runtime_error=None,
                        trace=trace,
                        golden_trace=golden_trace))
     return results
diff --git a/python/torch_support/npcomp_torchscript/e2e_test/reporting.py b/python/torch_support/npcomp_torchscript/e2e_test/reporting.py
index 9c0fa9369..2469cfd2d 100644
--- a/python/torch_support/npcomp_torchscript/e2e_test/reporting.py
+++ b/python/torch_support/npcomp_torchscript/e2e_test/reporting.py
@@ -222,7 +222,7 @@ class SingleTestReport:
         self.result = result
         self.context = context
         self.item_reports = None
-        if result.compilation_error is None:
+        if result.compilation_error is None and result.runtime_error is None:
             self.item_reports = []
             for i, (item, golden_item) in enumerate(
                     zip(result.trace, result.golden_trace)):
@@ -236,6 +236,8 @@ class SingleTestReport:
     def failed(self):
         if self.result.compilation_error is not None:
             return True
+        elif self.result.runtime_error is not None:
+            return True
         return any(r.failed for r in self.item_reports)
 
     def error_str(self):
@@ -244,6 +246,8 @@ class SingleTestReport:
         p = lambda *x: print(*x, file=f)
         if self.result.compilation_error is not None:
             return 'Compilation error: ' + self.result.compilation_error
+        elif self.result.runtime_error is not None:
+            return 'Runtime error: ' + self.result.runtime_error
         for report in self.item_reports:
             if report.failed:
                 p(report.error_str())