Rewrite error reporting of e2e tests.

This now gives [much nicer output](https://gist.github.com/silvasean/f048e0f37b04542dae6469b86802bb3e).
Embarrassingly, we previously couldn't even report failures for two
different tests, and weren't able to report on compilation failures
(besides just crashing).
pull/217/head
Sean Silva 2021-05-19 17:36:00 -07:00
parent d66e8fe1f8
commit b7b7fd4959
7 changed files with 257 additions and 56 deletions

View File

@ -27,7 +27,7 @@ import vision_models
import mlp import mlp
import quantized_models import quantized_models
def main(): def _get_argparse():
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('--config', parser.add_argument('--config',
choices=['native_torch', 'torchscript', 'refbackend'], choices=['native_torch', 'torchscript', 'refbackend'],
@ -41,7 +41,16 @@ Meaning of options:
parser.add_argument('--filter', default='.*', help=''' parser.add_argument('--filter', default='.*', help='''
Regular expression specifying which tests to include in this run. Regular expression specifying which tests to include in this run.
''') ''')
args = parser.parse_args() parser.add_argument('-v', '--verbose',
default=False,
action='store_true',
help='report test results with additional detail')
return parser
def main():
args = _get_argparse().parse_args()
# Find the selected config.
if args.config == 'refbackend': if args.config == 'refbackend':
config = RefBackendTestConfig() config = RefBackendTestConfig()
elif args.config == 'native_torch': elif args.config == 'native_torch':
@ -49,6 +58,7 @@ Regular expression specifying which tests to include in this run.
elif args.config == 'torchscript': elif args.config == 'torchscript':
config = TorchScriptTestConfig() config = TorchScriptTestConfig()
# Find the selected tests, and emit a diagnostic if none are found.
tests = [ tests = [
test for test in GLOBAL_TEST_REGISTRY test for test in GLOBAL_TEST_REGISTRY
if re.match(args.filter, test.unique_name) if re.match(args.filter, test.unique_name)
@ -61,8 +71,12 @@ Regular expression specifying which tests to include in this run.
for test in GLOBAL_TEST_REGISTRY: for test in GLOBAL_TEST_REGISTRY:
print(test.unique_name) print(test.unique_name)
sys.exit(1) sys.exit(1)
# Run the tests.
results = run_tests(tests, config) results = run_tests(tests, config)
report_results(results)
# Report the test results.
report_results(results, args.verbose)
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -20,7 +20,7 @@ compiling or TorchScript'ing).
""" """
import abc import abc
from typing import Any, Callable, List, NamedTuple, TypeVar from typing import Any, Callable, List, NamedTuple, Optional, TypeVar
import torch import torch
@ -175,10 +175,14 @@ class TestResult(NamedTuple):
# those reasons are stronger because we cannot simply extend this # those reasons are stronger because we cannot simply extend this
# class. # class.
unique_name: str # Should match Test.unique_name for corresponding test. unique_name: str # Should match Test.unique_name for corresponding test.
# If compilation failed, a string describing the failure.
# If this is not None, then the `trace` and `golden_trace` fields are None,
# and vice-versa.
compilation_error: Optional[str]
# The trace produced by the backend. # The trace produced by the backend.
trace: Trace trace: Optional[Trace]
# The golden trace which `trace` is expected to match. # The golden trace which `trace` is expected to match.
golden_trace: Trace golden_trace: Optional[Trace]
class _Tracer: class _Tracer:
@ -200,6 +204,9 @@ class _Tracer:
raw_outputs = getattr(self.module, name)(*args) raw_outputs = getattr(self.module, name)(*args)
if isinstance(raw_outputs, torch.Tensor): if isinstance(raw_outputs, torch.Tensor):
outputs = [raw_outputs] outputs = [raw_outputs]
else:
raise Exception(
"unimplemented: non-Tensor output from function")
self.trace.append( self.trace.append(
TraceItem(symbol=name, inputs=args, outputs=outputs)) TraceItem(symbol=name, inputs=args, outputs=outputs))
return raw_outputs return raw_outputs
@ -222,11 +229,20 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
for test in tests: for test in tests:
golden_trace = _generate_golden_trace(test) golden_trace = _generate_golden_trace(test)
# TODO: Precompile everything in parallel. # TODO: Precompile everything in parallel.
compiled = config.compile(test.program_factory()) try:
compiled = config.compile(test.program_factory())
except Exception as e:
results.append(
TestResult(unique_name=test.unique_name,
compilation_error=str(e),
trace=None,
golden_trace=None))
continue
# TODO: Run in parallel. # TODO: Run in parallel.
trace = config.run(compiled, golden_trace) trace = config.run(compiled, golden_trace)
results.append( results.append(
TestResult(unique_name=test.unique_name, TestResult(unique_name=test.unique_name,
compilation_error=None,
trace=trace, trace=trace,
golden_trace=golden_trace)) golden_trace=golden_trace))
return results return results

View File

@ -1,60 +1,183 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information. # See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
""" """
Utilities for reporting the results of the test framework. Utilities for reporting the results of the test framework.
""" """
from typing import List from typing import List, Optional
import io
import textwrap
import torch import torch
from .framework import TestResult from .framework import TestResult, TraceItem
class _TensorStats:
class TensorSummary:
"""A summary of a tensor's contents."""
def __init__(self, tensor): def __init__(self, tensor):
self.min = torch.min(tensor) self.min = torch.min(tensor)
self.max = torch.max(tensor) self.max = torch.max(tensor)
self.mean = torch.mean(tensor) self.mean = torch.mean(tensor)
def __str__(self): def __str__(self):
return f'min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}' return f'Tensor with min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
def _print_detailed_tensor_diff(tensor, golden_tensor): class ErrorContext:
if tensor.size() != golden_tensor.size(): """A chained list of error contexts.
print(
f'Tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
)
return
print('tensor stats : ', _TensorStats(tensor))
print('golden tensor stats: ', _TensorStats(golden_tensor))
def report_results(results: List[TestResult]): This is useful for tracking errors across multiple levels of detail.
"""Provide a basic error report summarizing various TestResult's.""" """
any_failed = False def __init__(self, contexts: List[str]):
self.contexts = contexts
@staticmethod
def empty():
"""Create an empty error context.
Used as the top-level context.
"""
return ErrorContext([])
def chain(self, additional_context: str):
"""Chain an additional context onto the current error context.
"""
return ErrorContext(self.contexts + [additional_context])
def format_error(self, s: str):
return '@ ' + '\n@ '.join(self.contexts) + '\n' + 'ERROR: ' + s
class ValueReport:
"""A report for a single value processed by the program.
This is currently limited to tensors, but eventually will support
all legal TorchScript types.
"""
def __init__(self, value, golden_value, context: ErrorContext):
self.value = value
self.golden_value = golden_value
self.context = context
@property
def failed(self):
return not torch.allclose(self.value, self.golden_value)
def error_str(self):
assert self.failed
if self.value.size() != self.golden_value.size():
return self.context.format_error(
f'tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
)
f = io.StringIO()
p = lambda *x: print(*x, file=f)
p('values mismatch')
p('got : ', TensorSummary(self.value))
p('expected: ', TensorSummary(self.golden_value))
return self.context.format_error(f.getvalue())
class TraceItemReport:
"""A report for a single trace item."""
failure_reasons: List[str]
def __init__(self, item: TraceItem, golden_item: TraceItem,
context: ErrorContext):
self.item = item
self.golden_item = golden_item
self.context = context
self.failure_reasons = []
self._evaluate_outcome()
@property
def failed(self):
return len(self.failure_reasons) != 0
def error_str(self):
return '\n'.join(self.failure_reasons)
def _evaluate_outcome(self):
if self.item.symbol != self.golden_item.symbol:
self.failure_reasons.append(
self.context.format_error(
f'not invoking the same symbol: got "{self.item.symbol}", expected "{self.golden_item.symbol}"'
))
if len(self.item.inputs) != len(self.golden_item.inputs):
self.failure_reasons.append(
self.context.format_error(
f'different number of inputs: got "{len(self.item.inputs)}", expected "{len(self.golden_item.inputs)}"'
))
if len(self.item.outputs) != len(self.golden_item.outputs):
self.failure_reasons.append(
self.context.format_error(
f'different number of outputs: got "{len(self.item.outputs)}", expected "{len(self.golden_item.outputs)}"'
))
for i, (input, golden_input) in enumerate(
zip(self.item.inputs, self.golden_item.inputs)):
value_report = ValueReport(
input, golden_input,
self.context.chain(
f'input #{i} of call to "{self.item.symbol}"'))
if value_report.failed:
self.failure_reasons.append(value_report.error_str())
for i, (output, golden_output) in enumerate(
zip(self.item.outputs, self.golden_item.outputs)):
value_report = ValueReport(output, golden_output,
self.context.chain(f'output #{i}'))
if value_report.failed:
self.failure_reasons.append(value_report.error_str())
class SingleTestReport:
"""A report for a single test."""
item_reports: Optional[List[TraceItemReport]]
def __init__(self, result: TestResult, context: ErrorContext):
self.result = result
self.context = context
self.item_reports = None
if result.compilation_error is None:
self.item_reports = []
for i, (item, golden_item) in enumerate(
zip(result.trace, result.golden_trace)):
self.item_reports.append(
TraceItemReport(
item, golden_item,
context.chain(
f'trace item #{i} - call to "{item.symbol}"')))
@property
def failed(self):
if self.result.compilation_error is not None:
return True
return any(r.failed for r in self.item_reports)
def error_str(self):
assert self.failed
f = io.StringIO()
p = lambda *x: print(*x, file=f)
if self.result.compilation_error is not None:
return 'compilation error' + self.result.compilation_error
for report in self.item_reports:
if report.failed:
p(report.error_str())
return f.getvalue()
def report_results(results: List[TestResult], verbose: bool = False):
"""Provide a basic error report summarizing various TestResult's.
If `verbose` is True, then provide an explanation of what failed.
"""
for result in results: for result in results:
failed = False report = SingleTestReport(result, ErrorContext.empty())
for item_num, (item, golden_item) in enumerate( if not report.failed:
zip(result.trace, result.golden_trace)): print(f'SUCCESS - "{result.unique_name}"')
assert item.symbol == golden_item.symbol else:
assert len(item.inputs) == len(golden_item.inputs) error_str = ''
assert len(item.outputs) == len(golden_item.outputs) if verbose:
for input, golden_input in zip(item.inputs, golden_item.inputs): error_str = '\n' + textwrap.indent(report.error_str(), ' ')
assert torch.allclose(input, golden_input) print(f'FAILURE - "{result.unique_name}"' + error_str)
for output_num, (output, golden_output) in enumerate(
zip(item.outputs, golden_item.outputs)):
# TODO: Refine error message. Things to consider:
# - Very large tensors -- don't spew, but give useful info
# - Smaller tensors / primitives -- want to show exact values
# - Machine parseable format?
if not torch.allclose(output, golden_output):
if not failed:
print('FAILURE "{}"'.format(result.unique_name))
failed = any_failed = True
print(
f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"'
)
_print_detailed_tensor_diff(output, golden_output)
if not failed:
print('SUCCESS "{}"'.format(result.unique_name))

View File

@ -21,13 +21,13 @@ class MmModule(torch.nn.Module):
# TODO: Refine messages. # TODO: Refine messages.
# CHECK: SUCCESS "MmModule_basic" # CHECK: SUCCESS - "MmModule_basic"
@register_test_case(module_factory=lambda: MmModule()) @register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils): def MmModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4)) module.forward(tu.rand(4, 4), tu.rand(4, 4))
# CHECK: SUCCESS "MmModule_basic2" # CHECK: SUCCESS - "MmModule_basic2"
@register_test_case(module_factory=lambda: MmModule()) @register_test_case(module_factory=lambda: MmModule())
def MmModule_basic2(module, tu: TestUtils): def MmModule_basic2(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4)) module.forward(tu.rand(4, 4), tu.rand(4, 4))

View File

@ -0,0 +1,45 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
# RUN: %PYTHON %s | FileCheck %s
import torch
from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
from torch_mlir.torchscript.e2e_test.reporting import report_results
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
class MmModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, t):
# Static type error that will fail TorchScript compilation -- function
# that returns tensor along one path and int along another.
if t.item() > 0:
return torch.tensor([])
else:
return 3
# CHECK: FAILURE - "MmModule_basic"
# CHECK: compilation error
# Assume that the diagnostic from the TorchScript compiler will at least contain
# the offending "return 3".
# CHECK: return 3
@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils):
module.forward(torch.ones([]))
def main():
config = TorchScriptTestConfig()
results = run_tests(GLOBAL_TEST_REGISTRY, config)
report_results(results, verbose=True)
if __name__ == '__main__':
main()

View File

@ -26,11 +26,12 @@ class MmModule(torch.nn.Module):
# TODO: Refine error messages. # TODO: Refine error messages.
# CHECK: FAILURE "MmModule_basic" # CHECK: FAILURE - "MmModule_basic"
# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward" # CHECK: @ trace item #0 - call to "forward"
# CHECK: tensor stats : min={{.*}}, max={{.*}}, mean={{.*}} # CHECK: @ output #0
# CHECK: golden tensor stats: min={{.*}}, max={{.*}}, mean={{.*}} # CHECK: ERROR: values mismatch
# CHECK-NOT: ALL PASS # CHECK: got : Tensor with min={{.*}}, max={{.*}}, mean={{.*}}
# CHECK: expected: Tensor with min={{.*}}, max={{.*}}, mean={{.*}}
@register_test_case(module_factory=lambda: MmModule()) @register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils): def MmModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4)) module.forward(tu.rand(4, 4), tu.rand(4, 4))
@ -39,7 +40,7 @@ def MmModule_basic(module, tu: TestUtils):
def main(): def main():
config = TorchScriptTestConfig() config = TorchScriptTestConfig()
results = run_tests(GLOBAL_TEST_REGISTRY, config) results = run_tests(GLOBAL_TEST_REGISTRY, config)
report_results(results) report_results(results, verbose=True)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -69,7 +69,9 @@ class VerifyBackendContractPass
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
if (failed(applyFullConversion(module, target, std::move(patterns)))) { if (failed(applyFullConversion(module, target, std::move(patterns)))) {
module.emitError() // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
// doesn't unnecessarily spew out the entire module.
emitError(module.getLoc())
<< "Module does not conform to npcomp's backend contract. See " << "Module does not conform to npcomp's backend contract. See "
"dialect conversion legality information above."; "dialect conversion legality information above.";
return signalPassFailure(); return signalPassFailure();