mirror of https://github.com/llvm/torch-mlir
Add runtime check
parent c9cc4cb2e9
commit cd7053dfde
@@ -36,10 +36,10 @@ def MmModule_basic(module, tu: TestUtils):
# are mixed with it, it fails with a mysterious-sounding low level ctypes error
# that exceeds my current ability to debug.
#
# @register_test_case(module_factory=lambda: MmModule())
# def MmModule_chained(module, tu: TestUtils):
#     res = module.forward(tu.rand(4, 4), tu.rand(4, 4))
#     module.forward(res, res)
@register_test_case(module_factory=lambda: MmModule())
def MmModule_chained(module, tu: TestUtils):
    res = module.forward(tu.rand(4, 4), tu.rand(4, 4))
    module.forward(res, res)


# ==============================================================================
@@ -18,7 +18,11 @@ _common_npcomp_lowering_xfails = {
    'QuantizedMLP_basic',
}

XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails
XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails | {
    # The first test in the e2e test batch would fail with SystemError: null
    # argument to internal routine. Might be some issue with refbackend.
    'MmModule_basic',
}

XFAIL_SETS['torchscript'] = {}
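The xfail sets are ordinary Python sets keyed by config name, so per-backend expected failures are just a set union over the common ones. A minimal sketch of how a driver might pick the right set and hand it to the reporter; the lookup helper is hypothetical, and treating `report_results`' second argument as the expected-failure set is an assumption suggested by the `report_results(results, set(), verbose=True)` call in the new lit test:

# Sketch only; mirrors the XFAIL_SETS shape above with a hypothetical lookup.
_common_npcomp_lowering_xfails = {
    'QuantizedMLP_basic',
}

XFAIL_SETS = {
    # Set union: common xfails plus refbackend-specific ones.
    'refbackend': _common_npcomp_lowering_xfails | {'MmModule_basic'},
    'torchscript': set(),
}

def xfail_set_for(config_name):
    # Hypothetical helper: unknown configs get no expected failures.
    return XFAIL_SETS.get(config_name, set())

assert xfail_set_for('refbackend') == {'QuantizedMLP_basic', 'MmModule_basic'}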
@@ -0,0 +1,42 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# RUN: %PYTHON %s | FileCheck %s

import torch

from npcomp_torchscript.e2e_test.framework import run_tests, TestUtils
from npcomp_torchscript.e2e_test.reporting import report_results
from npcomp_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
from npcomp_torchscript_e2e_test_configs import TorchScriptTestConfig


class MmModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, t):
        # Input of `torch.tensor` only allows ints, floats, or bools, while an
        # empty list defaults to tensor type.
        return torch.tensor([])


# CHECK: FAIL - "MmModule_basic"
# CHECK: Runtime error:
# Assume that the diagnostic from the TorchScript runtime will at least contain
# the offending "return torch.tensor([])".
# CHECK: return torch.tensor([])
@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils):
    module.forward(torch.ones([]))


def main():
    config = TorchScriptTestConfig()
    results = run_tests(GLOBAL_TEST_REGISTRY, config)
    report_results(results, set(), verbose=True)


if __name__ == '__main__':
    main()
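The new file doubles as a lit test: the RUN line pipes the script's output through FileCheck, and the CHECK lines pin down that the report contains a runtime failure quoting the offending `return torch.tensor([])`. For comparison, a test that runs cleanly is registered the same way; the module and test names below are hypothetical, and only the decorator and `TestUtils` usage follow the framework as shown in this patch:

import torch

from npcomp_torchscript.e2e_test.framework import TestUtils
from npcomp_torchscript.e2e_test.registry import register_test_case


class AddModule(torch.nn.Module):  # hypothetical module, not in this patch
    def forward(self, lhs, rhs):
        return lhs + rhs


@register_test_case(module_factory=lambda: AddModule())
def AddModule_basic(module, tu: TestUtils):
    # The calls made here are what the framework records and later compares
    # against the golden trace, so a clean run reports PASS instead of FAIL.
    module.forward(tu.rand(3, 4), tu.rand(3, 4))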
@@ -242,6 +242,10 @@ class TestResult(NamedTuple):
    # If this is not None, then the `trace` and `golden_trace` fields are None,
    # and vice-versa.
    compilation_error: Optional[str]
    # If runtime failed, a string describing the failure.
    # If this is not None, then the `trace` and `golden_trace` fields are None,
    # and vice-versa.
    runtime_error: Optional[str]
    # The trace produced by the backend.
    trace: Optional[Trace]
    # The golden trace which `trace` is expected to match.
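With the new field, a `TestResult` carries exactly one of three payloads: a compilation error, a runtime error, or a pair of traces. A sketch of the tuple as it plausibly reads after this patch; `unique_name` and `golden_trace` are taken from the surrounding hunks, and `Trace` is stubbed here since its definition is not part of the diff:

from typing import List, NamedTuple, Optional

Trace = List  # stand-in for the framework's Trace type, not its real definition


class TestResult(NamedTuple):
    unique_name: str
    # If compilation failed, a string describing the failure.
    # If this is not None, then the `trace` and `golden_trace` fields are None,
    # and vice-versa.
    compilation_error: Optional[str]
    # If runtime failed, a string describing the failure.
    # If this is not None, then the `trace` and `golden_trace` fields are None,
    # and vice-versa.
    runtime_error: Optional[str]
    # The trace produced by the backend.
    trace: Optional[Trace]
    # The golden trace which `trace` is expected to match.
    golden_trace: Optional[Trace]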
@@ -303,14 +307,26 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
            results.append(
                TestResult(unique_name=test.unique_name,
                           compilation_error=str(e),
                           runtime_error=None,
                           trace=None,
                           golden_trace=None))
            continue
        # TODO: Run in parallel.
        trace = config.run(compiled, golden_trace)
        try:
            trace = config.run(compiled, golden_trace)
        except Exception as e:
            results.append(
                TestResult(unique_name=test.unique_name,
                           compilation_error=None,
                           runtime_error=str(e),
                           trace=None,
                           golden_trace=None))
            continue

        results.append(
            TestResult(unique_name=test.unique_name,
                       compilation_error=None,
                       runtime_error=None,
                       trace=trace,
                       golden_trace=golden_trace))
    return results
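The effect of the new try/except is that a backend crash during `config.run` no longer takes down the whole batch; it becomes a `runtime_error` result and the loop moves on to the next test. A self-contained sketch of that path with a hypothetical config whose `run` always raises (the real configs live elsewhere in the tree):

class AlwaysFailsAtRuntimeConfig:
    # Hypothetical config used only to illustrate the new error path.
    def compile(self, program):
        return program  # pretend compilation succeeds

    def run(self, artifact, trace):
        raise RuntimeError('null argument to internal routine')


config = AlwaysFailsAtRuntimeConfig()
compiled = config.compile(object())
try:
    trace = config.run(compiled, [])
except Exception as e:
    # run_tests would record this as TestResult(..., compilation_error=None,
    # runtime_error=str(e), trace=None, golden_trace=None) and continue.
    print('Runtime error:', e)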
@@ -222,7 +222,7 @@ class SingleTestReport:
        self.result = result
        self.context = context
        self.item_reports = None
        if result.compilation_error is None:
        if result.compilation_error is None and result.runtime_error is None:
            self.item_reports = []
            for i, (item, golden_item) in enumerate(
                    zip(result.trace, result.golden_trace)):
@@ -236,6 +236,8 @@ class SingleTestReport:
    def failed(self):
        if self.result.compilation_error is not None:
            return True
        elif self.result.runtime_error is not None:
            return True
        return any(r.failed for r in self.item_reports)

    def error_str(self):
@@ -244,6 +246,8 @@ class SingleTestReport:
        p = lambda *x: print(*x, file=f)
        if self.result.compilation_error is not None:
            return 'Compilation error: ' + self.result.compilation_error
        elif self.result.runtime_error is not None:
            return 'Runtime error: ' + self.result.runtime_error
        for report in self.item_reports:
            if report.failed:
                p(report.error_str())
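On the reporting side, a runtime error is handled just like a compilation error: it marks the test as failed and is surfaced by `error_str`, which is what the `# CHECK: Runtime error:` line in the new lit test matches. A standalone sketch of that precedence, pulled out of `SingleTestReport` for illustration only:

from typing import List, Optional


def describe_failure(compilation_error: Optional[str],
                     runtime_error: Optional[str],
                     item_errors: List[str]) -> str:
    # Compilation errors win, then runtime errors, then per-item mismatches.
    if compilation_error is not None:
        return 'Compilation error: ' + compilation_error
    if runtime_error is not None:
        return 'Runtime error: ' + runtime_error
    return '\n'.join(item_errors)


print(describe_failure(None, 'null argument to internal routine', []))
# Prints: Runtime error: null argument to internal routine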