Rewrite error reporting of e2e tests.

This now gives [much nicer output](https://gist.github.com/silvasean/f048e0f37b04542dae6469b86802bb3e).
Embarrassingly, we previously couldn't even report failures for two
different tests, and weren't able to report on compilation failures
(besides just crashing).
pull/217/head
Sean Silva 2021-05-19 17:36:00 -07:00
parent d66e8fe1f8
commit b7b7fd4959
7 changed files with 257 additions and 56 deletions

View File

@ -27,7 +27,7 @@ import vision_models
import mlp
import quantized_models
def main():
def _get_argparse():
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('--config',
choices=['native_torch', 'torchscript', 'refbackend'],
@ -41,7 +41,16 @@ Meaning of options:
parser.add_argument('--filter', default='.*', help='''
Regular expression specifying which tests to include in this run.
''')
args = parser.parse_args()
parser.add_argument('-v', '--verbose',
default=False,
action='store_true',
help='report test results with additional detail')
return parser
def main():
args = _get_argparse().parse_args()
# Find the selected config.
if args.config == 'refbackend':
config = RefBackendTestConfig()
elif args.config == 'native_torch':
@ -49,6 +58,7 @@ Regular expression specifying which tests to include in this run.
elif args.config == 'torchscript':
config = TorchScriptTestConfig()
# Find the selected tests, and emit a diagnostic if none are found.
tests = [
test for test in GLOBAL_TEST_REGISTRY
if re.match(args.filter, test.unique_name)
@ -61,8 +71,12 @@ Regular expression specifying which tests to include in this run.
for test in GLOBAL_TEST_REGISTRY:
print(test.unique_name)
sys.exit(1)
# Run the tests.
results = run_tests(tests, config)
report_results(results)
# Report the test results.
report_results(results, args.verbose)
if __name__ == '__main__':
main()

View File

@ -20,7 +20,7 @@ compiling or TorchScript'ing).
"""
import abc
from typing import Any, Callable, List, NamedTuple, TypeVar
from typing import Any, Callable, List, NamedTuple, Optional, TypeVar
import torch
@ -175,10 +175,14 @@ class TestResult(NamedTuple):
# those reasons are stronger because we cannot simply extend this
# class.
unique_name: str # Should match Test.unique_name for corresponding test.
# If compilation failed, a string describing the failure.
# If this is not None, then the `trace` and `golden_trace` fields are None,
# and vice-versa.
compilation_error: Optional[str]
# The trace produced by the backend.
trace: Trace
trace: Optional[Trace]
# The golden trace which `trace` is expected to match.
golden_trace: Trace
golden_trace: Optional[Trace]
class _Tracer:
@ -200,6 +204,9 @@ class _Tracer:
raw_outputs = getattr(self.module, name)(*args)
if isinstance(raw_outputs, torch.Tensor):
outputs = [raw_outputs]
else:
raise Exception(
"unimplemented: non-Tensor output from function")
self.trace.append(
TraceItem(symbol=name, inputs=args, outputs=outputs))
return raw_outputs
@ -222,11 +229,20 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
for test in tests:
golden_trace = _generate_golden_trace(test)
# TODO: Precompile everything in parallel.
compiled = config.compile(test.program_factory())
try:
compiled = config.compile(test.program_factory())
except Exception as e:
results.append(
TestResult(unique_name=test.unique_name,
compilation_error=str(e),
trace=None,
golden_trace=None))
continue
# TODO: Run in parallel.
trace = config.run(compiled, golden_trace)
results.append(
TestResult(unique_name=test.unique_name,
compilation_error=None,
trace=trace,
golden_trace=golden_trace))
return results

View File

@ -1,60 +1,183 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Utilities for reporting the results of the test framework.
"""
from typing import List
from typing import List, Optional
import io
import textwrap
import torch
from .framework import TestResult
from .framework import TestResult, TraceItem
class _TensorStats:
class TensorSummary:
"""A summary of a tensor's contents."""
def __init__(self, tensor):
self.min = torch.min(tensor)
self.max = torch.max(tensor)
self.mean = torch.mean(tensor)
def __str__(self):
return f'min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
return f'Tensor with min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
def _print_detailed_tensor_diff(tensor, golden_tensor):
if tensor.size() != golden_tensor.size():
print(
f'Tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
)
return
print('tensor stats : ', _TensorStats(tensor))
print('golden tensor stats: ', _TensorStats(golden_tensor))
class ErrorContext:
"""A chained list of error contexts.
def report_results(results: List[TestResult]):
"""Provide a basic error report summarizing various TestResult's."""
any_failed = False
This is useful for tracking errors across multiple levels of detail.
"""
def __init__(self, contexts: List[str]):
self.contexts = contexts
@staticmethod
def empty():
"""Create an empty error context.
Used as the top-level context.
"""
return ErrorContext([])
def chain(self, additional_context: str):
"""Chain an additional context onto the current error context.
"""
return ErrorContext(self.contexts + [additional_context])
def format_error(self, s: str):
return '@ ' + '\n@ '.join(self.contexts) + '\n' + 'ERROR: ' + s
class ValueReport:
"""A report for a single value processed by the program.
This is currently limited to tensors, but eventually will support
all legal TorchScript types.
"""
def __init__(self, value, golden_value, context: ErrorContext):
self.value = value
self.golden_value = golden_value
self.context = context
@property
def failed(self):
return not torch.allclose(self.value, self.golden_value)
def error_str(self):
assert self.failed
if self.value.size() != self.golden_value.size():
return self.context.format_error(
f'tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
)
f = io.StringIO()
p = lambda *x: print(*x, file=f)
p('values mismatch')
p('got : ', TensorSummary(self.value))
p('expected: ', TensorSummary(self.golden_value))
return self.context.format_error(f.getvalue())
class TraceItemReport:
"""A report for a single trace item."""
failure_reasons: List[str]
def __init__(self, item: TraceItem, golden_item: TraceItem,
context: ErrorContext):
self.item = item
self.golden_item = golden_item
self.context = context
self.failure_reasons = []
self._evaluate_outcome()
@property
def failed(self):
return len(self.failure_reasons) != 0
def error_str(self):
return '\n'.join(self.failure_reasons)
def _evaluate_outcome(self):
if self.item.symbol != self.golden_item.symbol:
self.failure_reasons.append(
self.context.format_error(
f'not invoking the same symbol: got "{self.item.symbol}", expected "{self.golden_item.symbol}"'
))
if len(self.item.inputs) != len(self.golden_item.inputs):
self.failure_reasons.append(
self.context.format_error(
f'different number of inputs: got "{len(self.item.inputs)}", expected "{len(self.golden_item.inputs)}"'
))
if len(self.item.outputs) != len(self.golden_item.outputs):
self.failure_reasons.append(
self.context.format_error(
f'different number of outputs: got "{len(self.item.outputs)}", expected "{len(self.golden_item.outputs)}"'
))
for i, (input, golden_input) in enumerate(
zip(self.item.inputs, self.golden_item.inputs)):
value_report = ValueReport(
input, golden_input,
self.context.chain(
f'input #{i} of call to "{self.item.symbol}"'))
if value_report.failed:
self.failure_reasons.append(value_report.error_str())
for i, (output, golden_output) in enumerate(
zip(self.item.outputs, self.golden_item.outputs)):
value_report = ValueReport(output, golden_output,
self.context.chain(f'output #{i}'))
if value_report.failed:
self.failure_reasons.append(value_report.error_str())
class SingleTestReport:
"""A report for a single test."""
item_reports: Optional[List[TraceItemReport]]
def __init__(self, result: TestResult, context: ErrorContext):
self.result = result
self.context = context
self.item_reports = None
if result.compilation_error is None:
self.item_reports = []
for i, (item, golden_item) in enumerate(
zip(result.trace, result.golden_trace)):
self.item_reports.append(
TraceItemReport(
item, golden_item,
context.chain(
f'trace item #{i} - call to "{item.symbol}"')))
@property
def failed(self):
if self.result.compilation_error is not None:
return True
return any(r.failed for r in self.item_reports)
def error_str(self):
assert self.failed
f = io.StringIO()
p = lambda *x: print(*x, file=f)
if self.result.compilation_error is not None:
return 'compilation error' + self.result.compilation_error
for report in self.item_reports:
if report.failed:
p(report.error_str())
return f.getvalue()
def report_results(results: List[TestResult], verbose: bool = False):
"""Provide a basic error report summarizing various TestResult's.
If `verbose` is True, then provide an explanation of what failed.
"""
for result in results:
failed = False
for item_num, (item, golden_item) in enumerate(
zip(result.trace, result.golden_trace)):
assert item.symbol == golden_item.symbol
assert len(item.inputs) == len(golden_item.inputs)
assert len(item.outputs) == len(golden_item.outputs)
for input, golden_input in zip(item.inputs, golden_item.inputs):
assert torch.allclose(input, golden_input)
for output_num, (output, golden_output) in enumerate(
zip(item.outputs, golden_item.outputs)):
# TODO: Refine error message. Things to consider:
# - Very large tensors -- don't spew, but give useful info
# - Smaller tensors / primitives -- want to show exact values
# - Machine parseable format?
if not torch.allclose(output, golden_output):
if not failed:
print('FAILURE "{}"'.format(result.unique_name))
failed = any_failed = True
print(
f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"'
)
_print_detailed_tensor_diff(output, golden_output)
if not failed:
print('SUCCESS "{}"'.format(result.unique_name))
report = SingleTestReport(result, ErrorContext.empty())
if not report.failed:
print(f'SUCCESS - "{result.unique_name}"')
else:
error_str = ''
if verbose:
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
print(f'FAILURE - "{result.unique_name}"' + error_str)

View File

@ -21,13 +21,13 @@ class MmModule(torch.nn.Module):
# TODO: Refine messages.
# CHECK: SUCCESS "MmModule_basic"
# CHECK: SUCCESS - "MmModule_basic"
@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4))
# CHECK: SUCCESS "MmModule_basic2"
# CHECK: SUCCESS - "MmModule_basic2"
@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic2(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4))

View File

@ -0,0 +1,45 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
# RUN: %PYTHON %s | FileCheck %s
import torch
from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
from torch_mlir.torchscript.e2e_test.reporting import report_results
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
class MmModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, t):
# Static type error that will fail TorchScript compilation -- function
# that returns tensor along one path and int along another.
if t.item() > 0:
return torch.tensor([])
else:
return 3
# CHECK: FAILURE - "MmModule_basic"
# CHECK: compilation error
# Assume that the diagnostic from the TorchScript compiler will at least contain
# the offending "return 3".
# CHECK: return 3
@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils):
module.forward(torch.ones([]))
def main():
config = TorchScriptTestConfig()
results = run_tests(GLOBAL_TEST_REGISTRY, config)
report_results(results, verbose=True)
if __name__ == '__main__':
main()

View File

@ -26,11 +26,12 @@ class MmModule(torch.nn.Module):
# TODO: Refine error messages.
# CHECK: FAILURE "MmModule_basic"
# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward"
# CHECK: tensor stats : min={{.*}}, max={{.*}}, mean={{.*}}
# CHECK: golden tensor stats: min={{.*}}, max={{.*}}, mean={{.*}}
# CHECK-NOT: ALL PASS
# CHECK: FAILURE - "MmModule_basic"
# CHECK: @ trace item #0 - call to "forward"
# CHECK: @ output #0
# CHECK: ERROR: values mismatch
# CHECK: got : Tensor with min={{.*}}, max={{.*}}, mean={{.*}}
# CHECK: expected: Tensor with min={{.*}}, max={{.*}}, mean={{.*}}
@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4))
@ -39,7 +40,7 @@ def MmModule_basic(module, tu: TestUtils):
def main():
config = TorchScriptTestConfig()
results = run_tests(GLOBAL_TEST_REGISTRY, config)
report_results(results)
report_results(results, verbose=True)
if __name__ == '__main__':

View File

@ -69,7 +69,9 @@ class VerifyBackendContractPass
RewritePatternSet patterns(context);
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
module.emitError()
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
// doesn't unnecessarily spew out the entire module.
emitError(module.getLoc())
<< "Module does not conform to npcomp's backend contract. See "
"dialect conversion legality information above.";
return signalPassFailure();