mirror of https://github.com/llvm/torch-mlir
Rewrite error reporting of e2e tests.

This now gives [much nicer output](https://gist.github.com/silvasean/f048e0f37b04542dae6469b86802bb3e). Embarrassingly, we previously couldn't even report failures for two different tests, and we had no way to report compilation failures at all (the harness simply crashed).
parent d66e8fe1f8
commit b7b7fd4959
@@ -27,7 +27,7 @@ import vision_models
 import mlp
 import quantized_models
 
-def main():
+def _get_argparse():
     parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
     parser.add_argument('--config',
                         choices=['native_torch', 'torchscript', 'refbackend'],
@@ -41,7 +41,16 @@ Meaning of options:
     parser.add_argument('--filter', default='.*', help='''
 Regular expression specifying which tests to include in this run.
 ''')
-    args = parser.parse_args()
+    parser.add_argument('-v', '--verbose',
+                        default=False,
+                        action='store_true',
+                        help='report test results with additional detail')
+    return parser
+
+
+def main():
+    args = _get_argparse().parse_args()
 
     # Find the selected config.
     if args.config == 'refbackend':
         config = RefBackendTestConfig()
     elif args.config == 'native_torch':
@@ -49,6 +58,7 @@ Regular expression specifying which tests to include in this run.
     elif args.config == 'torchscript':
         config = TorchScriptTestConfig()
 
+    # Find the selected tests, and emit a diagnostic if none are found.
     tests = [
         test for test in GLOBAL_TEST_REGISTRY
         if re.match(args.filter, test.unique_name)
@@ -61,8 +71,12 @@ Regular expression specifying which tests to include in this run.
         for test in GLOBAL_TEST_REGISTRY:
             print(test.unique_name)
         sys.exit(1)
 
     # Run the tests.
     results = run_tests(tests, config)
-    report_results(results)
+
+    # Report the test results.
+    report_results(results, args.verbose)
 
 
 if __name__ == '__main__':
     main()

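With `_get_argparse()` factored out and the new `-v/--verbose` flag in place, the flag can be exercised directly. A minimal sketch (the argument vector below is hypothetical; the flags come from the parser built above):

```python
# Hypothetical invocation, e.g.: python main.py --config refbackend --filter 'Mm.*' -v
args = _get_argparse().parse_args(['--config', 'refbackend', '--filter', 'Mm.*', '-v'])
assert args.config == 'refbackend'
assert args.filter == 'Mm.*'
assert args.verbose  # store_true flag, defaults to False
```
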
@@ -20,7 +20,7 @@ compiling or TorchScript'ing).
 """
 
 import abc
-from typing import Any, Callable, List, NamedTuple, TypeVar
+from typing import Any, Callable, List, NamedTuple, Optional, TypeVar
 
 import torch
 
@@ -175,10 +175,14 @@ class TestResult(NamedTuple):
     # those reasons are stronger because we cannot simply extend this
     # class.
     unique_name: str # Should match Test.unique_name for corresponding test.
+    # If compilation failed, a string describing the failure.
+    # If this is not None, then the `trace` and `golden_trace` fields are None,
+    # and vice-versa.
+    compilation_error: Optional[str]
     # The trace produced by the backend.
-    trace: Trace
+    trace: Optional[Trace]
     # The golden trace which `trace` is expected to match.
-    golden_trace: Trace
+    golden_trace: Optional[Trace]
 
 
 class _Tracer:
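The new `compilation_error` field makes `TestResult` a tagged union of sorts: exactly one of the error string or the trace pair is populated. A minimal sketch of consuming it (the `describe` helper is hypothetical):

```python
def describe(result: TestResult) -> str:
    """Hypothetical helper showing the either/or invariant."""
    if result.compilation_error is not None:
        # `trace` and `golden_trace` are None on this branch.
        return f'{result.unique_name}: failed to compile'
    # `compilation_error` is None on this branch.
    return f'{result.unique_name}: ran {len(result.trace)} calls'
```
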
@@ -200,6 +204,9 @@ class _Tracer:
         raw_outputs = getattr(self.module, name)(*args)
         if isinstance(raw_outputs, torch.Tensor):
             outputs = [raw_outputs]
+        else:
+            raise Exception(
+                "unimplemented: non-Tensor output from function")
         self.trace.append(
             TraceItem(symbol=name, inputs=args, outputs=outputs))
         return raw_outputs
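For illustration, a hedged sketch of what the tightened check rejects (the module is hypothetical; the tracer intercepts method calls as the surrounding code shows):

```python
import torch

class ReturnsTuple(torch.nn.Module):
    def forward(self):
        # A tuple is not a torch.Tensor, so recording a call to this method
        # now raises "unimplemented: non-Tensor output from function" instead
        # of silently appending a malformed TraceItem.
        return torch.ones(1), torch.ones(1)
```
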
@@ -222,11 +229,20 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
     for test in tests:
         golden_trace = _generate_golden_trace(test)
         # TODO: Precompile everything in parallel.
-        compiled = config.compile(test.program_factory())
+        try:
+            compiled = config.compile(test.program_factory())
+        except Exception as e:
+            results.append(
+                TestResult(unique_name=test.unique_name,
+                           compilation_error=str(e),
+                           trace=None,
+                           golden_trace=None))
+            continue
         # TODO: Run in parallel.
         trace = config.run(compiled, golden_trace)
         results.append(
             TestResult(unique_name=test.unique_name,
+                       compilation_error=None,
                        trace=trace,
                        golden_trace=golden_trace))
     return results

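The `try`/`except` above means any exception escaping `config.compile` is recorded as a per-test failure instead of aborting the whole run. A rough sketch of a config exercising that path (hypothetical class, assuming the `TestConfig` interface of `compile` and `run` used above):

```python
import torch

class AlwaysFailsConfig(TestConfig):  # TestConfig from the framework above
    """Hypothetical config whose compile step always raises."""
    def compile(self, program: torch.nn.Module):
        raise Exception('synthetic compilation failure')

    def run(self, artifact, trace):
        raise AssertionError('unreachable: compile always fails first')
```

Running `run_tests(tests, AlwaysFailsConfig())` then yields one `TestResult` per test with `compilation_error` set and both traces `None`.
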
@@ -1,60 +1,183 @@
 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 """
 Utilities for reporting the results of the test framework.
 """
 
-from typing import List
+from typing import List, Optional
+
+import io
+import textwrap
 
 import torch
 
-from .framework import TestResult
+from .framework import TestResult, TraceItem
 
 
-class _TensorStats:
+class TensorSummary:
+    """A summary of a tensor's contents."""
     def __init__(self, tensor):
         self.min = torch.min(tensor)
         self.max = torch.max(tensor)
         self.mean = torch.mean(tensor)
 
     def __str__(self):
-        return f'min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
+        return f'Tensor with min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
 
 
-def _print_detailed_tensor_diff(tensor, golden_tensor):
-    if tensor.size() != golden_tensor.size():
-        print(
-            f'Tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
-        )
-        return
-    print('tensor stats : ', _TensorStats(tensor))
-    print('golden tensor stats: ', _TensorStats(golden_tensor))
+class ErrorContext:
+    """A chained list of error contexts.
+
+    This is useful for tracking errors across multiple levels of detail.
+    """
+    def __init__(self, contexts: List[str]):
+        self.contexts = contexts
+
+    @staticmethod
+    def empty():
+        """Create an empty error context.
+
+        Used as the top-level context.
+        """
+        return ErrorContext([])
+
+    def chain(self, additional_context: str):
+        """Chain an additional context onto the current error context."""
+        return ErrorContext(self.contexts + [additional_context])
+
+    def format_error(self, s: str):
+        return '@ ' + '\n@ '.join(self.contexts) + '\n' + 'ERROR: ' + s
+
+
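A quick illustration of how contexts chain and render, derived directly from `format_error` above:

```python
ctx = ErrorContext.empty().chain('trace item #0 - call to "forward"').chain('output #0')
print(ctx.format_error('values mismatch'))
# @ trace item #0 - call to "forward"
# @ output #0
# ERROR: values mismatch
```
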
+class ValueReport:
+    """A report for a single value processed by the program.
+
+    This is currently limited to tensors, but eventually will support
+    all legal TorchScript types.
+    """
+    def __init__(self, value, golden_value, context: ErrorContext):
+        self.value = value
+        self.golden_value = golden_value
+        self.context = context
+
+    @property
+    def failed(self):
+        return not torch.allclose(self.value, self.golden_value)
+
+    def error_str(self):
+        assert self.failed
+        if self.value.size() != self.golden_value.size():
+            return self.context.format_error(
+                f'tensor shape mismatch: got {self.value.size()!r}, expected {self.golden_value.size()!r}'
+            )
+        f = io.StringIO()
+        p = lambda *x: print(*x, file=f)
+        p('values mismatch')
+        p('got : ', TensorSummary(self.value))
+        p('expected: ', TensorSummary(self.golden_value))
+        return self.context.format_error(f.getvalue())
+
+
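Exercising `ValueReport` directly on two mismatched tensors (a sketch; the exact summary numbers depend on the values):

```python
import torch

report = ValueReport(torch.ones(2, 2), torch.zeros(2, 2),
                     ErrorContext.empty().chain('output #0'))
assert report.failed  # torch.allclose(ones, zeros) is False
print(report.error_str())
# @ output #0
# ERROR: values mismatch
# got :  Tensor with min=..., max=..., mean=...
# expected:  Tensor with min=..., max=..., mean=...
```
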
+class TraceItemReport:
+    """A report for a single trace item."""
+    failure_reasons: List[str]
+
+    def __init__(self, item: TraceItem, golden_item: TraceItem,
+                 context: ErrorContext):
+        self.item = item
+        self.golden_item = golden_item
+        self.context = context
+        self.failure_reasons = []
+        self._evaluate_outcome()
+
+    @property
+    def failed(self):
+        return len(self.failure_reasons) != 0
+
+    def error_str(self):
+        return '\n'.join(self.failure_reasons)
+
+    def _evaluate_outcome(self):
+        if self.item.symbol != self.golden_item.symbol:
+            self.failure_reasons.append(
+                self.context.format_error(
+                    f'not invoking the same symbol: got "{self.item.symbol}", expected "{self.golden_item.symbol}"'
+                ))
+        if len(self.item.inputs) != len(self.golden_item.inputs):
+            self.failure_reasons.append(
+                self.context.format_error(
+                    f'different number of inputs: got "{len(self.item.inputs)}", expected "{len(self.golden_item.inputs)}"'
+                ))
+        if len(self.item.outputs) != len(self.golden_item.outputs):
+            self.failure_reasons.append(
+                self.context.format_error(
+                    f'different number of outputs: got "{len(self.item.outputs)}", expected "{len(self.golden_item.outputs)}"'
+                ))
+        for i, (input, golden_input) in enumerate(
+                zip(self.item.inputs, self.golden_item.inputs)):
+            value_report = ValueReport(
+                input, golden_input,
+                self.context.chain(
+                    f'input #{i} of call to "{self.item.symbol}"'))
+            if value_report.failed:
+                self.failure_reasons.append(value_report.error_str())
+        for i, (output, golden_output) in enumerate(
+                zip(self.item.outputs, self.golden_item.outputs)):
+            value_report = ValueReport(output, golden_output,
+                                       self.context.chain(f'output #{i}'))
+            if value_report.failed:
+                self.failure_reasons.append(value_report.error_str())
+
+
+class SingleTestReport:
+    """A report for a single test."""
+    item_reports: Optional[List[TraceItemReport]]
+
+    def __init__(self, result: TestResult, context: ErrorContext):
+        self.result = result
+        self.context = context
+        self.item_reports = None
+        if result.compilation_error is None:
+            self.item_reports = []
+            for i, (item, golden_item) in enumerate(
+                    zip(result.trace, result.golden_trace)):
+                self.item_reports.append(
+                    TraceItemReport(
+                        item, golden_item,
+                        context.chain(
+                            f'trace item #{i} - call to "{item.symbol}"')))
+
+    @property
+    def failed(self):
+        if self.result.compilation_error is not None:
+            return True
+        return any(r.failed for r in self.item_reports)
+
+    def error_str(self):
+        assert self.failed
+        f = io.StringIO()
+        p = lambda *x: print(*x, file=f)
+        if self.result.compilation_error is not None:
+            return 'compilation error: ' + self.result.compilation_error
+        for report in self.item_reports:
+            if report.failed:
+                p(report.error_str())
+        return f.getvalue()
+
+
-def report_results(results: List[TestResult]):
-    """Provide a basic error report summarizing various TestResult's."""
-    any_failed = False
+def report_results(results: List[TestResult], verbose: bool = False):
+    """Provide a basic error report summarizing various TestResult's.
+
+    If `verbose` is True, then provide an explanation of what failed.
+    """
     for result in results:
-        failed = False
-        for item_num, (item, golden_item) in enumerate(
-                zip(result.trace, result.golden_trace)):
-            assert item.symbol == golden_item.symbol
-            assert len(item.inputs) == len(golden_item.inputs)
-            assert len(item.outputs) == len(golden_item.outputs)
-            for input, golden_input in zip(item.inputs, golden_item.inputs):
-                assert torch.allclose(input, golden_input)
-            for output_num, (output, golden_output) in enumerate(
-                    zip(item.outputs, golden_item.outputs)):
-                # TODO: Refine error message. Things to consider:
-                # - Very large tensors -- don't spew, but give useful info
-                # - Smaller tensors / primitives -- want to show exact values
-                # - Machine parseable format?
-                if not torch.allclose(output, golden_output):
-                    if not failed:
-                        print('FAILURE "{}"'.format(result.unique_name))
-                        failed = any_failed = True
-                    print(
-                        f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"'
-                    )
-                    _print_detailed_tensor_diff(output, golden_output)
-        if not failed:
-            print('SUCCESS "{}"'.format(result.unique_name))
+        report = SingleTestReport(result, ErrorContext.empty())
+        if not report.failed:
+            print(f'SUCCESS - "{result.unique_name}"')
+        else:
+            error_str = ''
+            if verbose:
+                error_str = '\n' + textwrap.indent(report.error_str(), '    ')
+            print(f'FAILURE - "{result.unique_name}"' + error_str)

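Putting it together: a sketch of driving the new reporting path with a synthetic compile-failure result, mirroring what `run_tests` now produces (the name and message below are hypothetical):

```python
result = TestResult(unique_name='ExampleModule_basic',  # hypothetical name
                    compilation_error='static type error: ...',
                    trace=None,
                    golden_trace=None)
report_results([result], verbose=True)
# FAILURE - "ExampleModule_basic"
#     compilation error: static type error: ...
```
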
@@ -21,13 +21,13 @@ class MmModule(torch.nn.Module):
 
 
 # TODO: Refine messages.
-# CHECK: SUCCESS "MmModule_basic"
+# CHECK: SUCCESS - "MmModule_basic"
 @register_test_case(module_factory=lambda: MmModule())
 def MmModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 4), tu.rand(4, 4))
 
 
-# CHECK: SUCCESS "MmModule_basic2"
+# CHECK: SUCCESS - "MmModule_basic2"
 @register_test_case(module_factory=lambda: MmModule())
 def MmModule_basic2(module, tu: TestUtils):
     module.forward(tu.rand(4, 4), tu.rand(4, 4))

@@ -0,0 +1,45 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See frontends/pytorch/LICENSE for license information.
+
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+
+from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
+from torch_mlir.torchscript.e2e_test.reporting import report_results
+from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
+from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
+
+
+class MmModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, t):
+        # Static type error that will fail TorchScript compilation -- function
+        # that returns tensor along one path and int along another.
+        if t.item() > 0:
+            return torch.tensor([])
+        else:
+            return 3
+
+
+# CHECK: FAILURE - "MmModule_basic"
+# CHECK: compilation error
+# Assume that the diagnostic from the TorchScript compiler will at least contain
+# the offending "return 3".
+# CHECK: return 3
+@register_test_case(module_factory=lambda: MmModule())
+def MmModule_basic(module, tu: TestUtils):
+    module.forward(torch.ones([]))
+
+
+def main():
+    config = TorchScriptTestConfig()
+    results = run_tests(GLOBAL_TEST_REGISTRY, config)
+    report_results(results, verbose=True)
+
+
+if __name__ == '__main__':
+    main()

@@ -26,11 +26,12 @@ class MmModule(torch.nn.Module):
 
 
 # TODO: Refine error messages.
-# CHECK: FAILURE "MmModule_basic"
-# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward"
-# CHECK: tensor stats : min={{.*}}, max={{.*}}, mean={{.*}}
-# CHECK: golden tensor stats: min={{.*}}, max={{.*}}, mean={{.*}}
-# CHECK-NOT: ALL PASS
+# CHECK: FAILURE - "MmModule_basic"
+# CHECK: @ trace item #0 - call to "forward"
+# CHECK: @ output #0
+# CHECK: ERROR: values mismatch
+# CHECK: got : Tensor with min={{.*}}, max={{.*}}, mean={{.*}}
+# CHECK: expected: Tensor with min={{.*}}, max={{.*}}, mean={{.*}}
 @register_test_case(module_factory=lambda: MmModule())
 def MmModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 4), tu.rand(4, 4))
@@ -39,7 +40,7 @@ def MmModule_basic(module, tu: TestUtils):
 def main():
     config = TorchScriptTestConfig()
     results = run_tests(GLOBAL_TEST_REGISTRY, config)
-    report_results(results)
+    report_results(results, verbose=True)
 
 
 if __name__ == '__main__':

@@ -69,7 +69,9 @@ class VerifyBackendContractPass
 
   RewritePatternSet patterns(context);
   if (failed(applyFullConversion(module, target, std::move(patterns)))) {
-    module.emitError()
+    // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
+    // doesn't unnecessarily spew out the entire module.
+    emitError(module.getLoc())
         << "Module does not conform to npcomp's backend contract. See "
            "dialect conversion legality information above.";
     return signalPassFailure();