# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
Utilities for reporting the results of the test framework.
"""

from typing import Any, List, Optional, Set

import collections
import io
import textwrap

import torch

from .framework import TestResult, TraceItem


class TensorSummary:
    """A summary of a tensor's contents."""
    def __init__(self, tensor):
        self.min = torch.min(tensor.type(torch.float64))
        self.max = torch.max(tensor.type(torch.float64))
        self.mean = torch.mean(tensor.type(torch.float64))
        self.shape = list(tensor.shape)
        self.dtype = tensor.dtype

    def __str__(self):
        return f'Tensor with shape={self.shape}, dtype={self.dtype}, min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}'
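
# Illustrative usage sketch (comment only, not executed): a TensorSummary
# renders a compact one-line description of a tensor's shape, dtype, and
# min/max/mean statistics.
#
#   summary = TensorSummary(torch.ones(2, 3))
#   print(summary)
#   # Tensor with shape=[2, 3], dtype=torch.float32, min=..., max=..., mean=...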


class ErrorContext:
    """A chained list of error contexts.

    This is useful for tracking errors across multiple levels of detail.
    """
    def __init__(self, contexts: List[str]):
        self.contexts = contexts

    @staticmethod
    def empty():
        """Create an empty error context.

        Used as the top-level context.
        """
        return ErrorContext([])

    def chain(self, additional_context: str):
        """Chain an additional context onto the current error context."""
        return ErrorContext(self.contexts + [additional_context])

    def format_error(self, s: str):
        return '@ ' + '\n@ '.join(self.contexts) + '\n' + 'ERROR: ' + s
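
# Illustrative usage sketch (comment only): contexts are chained as an error
# propagates from coarse to fine granularity, and format_error() prefixes the
# message with one "@ " line per level.
#
#   ctx = ErrorContext.empty().chain('trace item #0 - call to "forward"')
#   ctx = ctx.chain('output of call to "forward"')
#   print(ctx.format_error('values differ'))
#   # @ trace item #0 - call to "forward"
#   # @ output of call to "forward"
#   # ERROR: values differ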


class ValueReport:
    """A report for a single value processed by the program."""
    def __init__(self, value, golden_value, context: ErrorContext):
        self.value = value
        self.golden_value = golden_value
        self.context = context
        self.failure_reasons = []
        self._evaluate_outcome()

    @property
    def failed(self):
        return len(self.failure_reasons) != 0

    def error_str(self):
        return '\n'.join(self.failure_reasons)

    def _evaluate_outcome(self):
        value, golden = self.value, self.golden_value
        if isinstance(golden, float):
            if not isinstance(value, float):
                return self._record_mismatch_type_failure('float', value)
            # Relative-tolerance check, written as a multiplication so a zero
            # or negative golden value cannot cause a division by zero or a
            # spurious pass.
            if abs(value - golden) > abs(golden) * 1e-4:
                return self._record_failure(
                    f'value ({value!r}) is not close to golden value ({golden!r})'
                )
            return
        if isinstance(golden, int):
            if not isinstance(value, int):
                return self._record_mismatch_type_failure('int', value)
            if value != golden:
                return self._record_failure(
                    f'value ({value!r}) is not equal to golden value ({golden!r})'
                )
            return
        if isinstance(golden, str):
            if not isinstance(value, str):
                return self._record_mismatch_type_failure('str', value)
            if value != golden:
                return self._record_failure(
                    f'value ({value!r}) is not equal to golden value ({golden!r})'
                )
            return
        if isinstance(golden, tuple):
            if not isinstance(value, tuple):
                return self._record_mismatch_type_failure('tuple', value)
            if len(value) != len(golden):
                return self._record_failure(
                    f'tuple length ({len(value)!r}) is not equal to golden length ({len(golden)!r})'
                )
            reports = [
                ValueReport(v, g, self.context.chain(f'tuple element {i}'))
                for i, (v, g) in enumerate(zip(value, golden))
            ]
            for report in reports:
                if report.failed:
                    self.failure_reasons.extend(report.failure_reasons)
            return
        if isinstance(golden, list):
            if not isinstance(value, list):
                return self._record_mismatch_type_failure('list', value)
            if len(value) != len(golden):
                return self._record_failure(
                    f'list length ({len(value)!r}) is not equal to golden length ({len(golden)!r})'
                )
            reports = [
                ValueReport(v, g, self.context.chain(f'list element {i}'))
                for i, (v, g) in enumerate(zip(value, golden))
            ]
            for report in reports:
                if report.failed:
                    self.failure_reasons.extend(report.failure_reasons)
            return
        if isinstance(golden, dict):
            if not isinstance(value, dict):
                return self._record_mismatch_type_failure('dict', value)
            gkeys = list(sorted(golden.keys()))
            vkeys = list(sorted(value.keys()))
            if gkeys != vkeys:
                return self._record_failure(
                    f'dict keys ({vkeys!r}) are not equal to golden keys ({gkeys!r})'
                )
            reports = [
                ValueReport(value[k], golden[k],
                            self.context.chain(f'dict element at key {k!r}'))
                for k in gkeys
            ]
            for report in reports:
                if report.failed:
                    self.failure_reasons.extend(report.failure_reasons)
            return
        if isinstance(golden, torch.Tensor):
            if not isinstance(value, torch.Tensor):
                return self._record_mismatch_type_failure('torch.Tensor', value)

            if value.shape != golden.shape:
                return self._record_failure(
                    f'shape ({value.shape}) is not equal to golden shape ({golden.shape})'
                )
            if value.dtype != golden.dtype:
                return self._record_failure(
                    f'dtype ({value.dtype}) is not equal to golden dtype ({golden.dtype})'
                )
            if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True):
                return self._record_failure(
                    f'value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})'
                )
            return
        return self._record_failure(
            f'unexpected golden value of type `{golden.__class__.__name__}`')

    def _record_failure(self, s: str):
        self.failure_reasons.append(self.context.format_error(s))

    def _record_mismatch_type_failure(self, expected: str, actual: Any):
        self._record_failure(
            f'expected a value of type `{expected}` but got `{actual.__class__.__name__}`'
        )
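
# Illustrative usage sketch (comment only): comparing a computed tensor against
# a golden tensor and printing the accumulated failure reasons, if any.
#
#   report = ValueReport(torch.tensor([1.0, 2.0]),
#                        torch.tensor([1.0, 3.0]),
#                        ErrorContext.empty())
#   if report.failed:
#       print(report.error_str())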


class TraceItemReport:
    """A report for a single trace item."""
    failure_reasons: List[str]

    def __init__(self, item: TraceItem, golden_item: TraceItem,
                 context: ErrorContext):
        self.item = item
        self.golden_item = golden_item
        self.context = context
        self.failure_reasons = []
        self._evaluate_outcome()

    @property
    def failed(self):
        return len(self.failure_reasons) != 0

    def error_str(self):
        return '\n'.join(self.failure_reasons)

    def _evaluate_outcome(self):
        if self.item.symbol != self.golden_item.symbol:
            self.failure_reasons.append(
                self.context.format_error(
                    f'not invoking the same symbol: got "{self.item.symbol}", expected "{self.golden_item.symbol}"'
                ))
        if len(self.item.inputs) != len(self.golden_item.inputs):
            self.failure_reasons.append(
                self.context.format_error(
                    f'different number of inputs: got "{len(self.item.inputs)}", expected "{len(self.golden_item.inputs)}"'
                ))
        for i, (input, golden_input) in enumerate(
                zip(self.item.inputs, self.golden_item.inputs)):
            value_report = ValueReport(
                input, golden_input,
                self.context.chain(
                    f'input #{i} of call to "{self.item.symbol}"'))
            if value_report.failed:
                self.failure_reasons.append(value_report.error_str())
        value_report = ValueReport(
            self.item.output, self.golden_item.output,
            self.context.chain(f'output of call to "{self.item.symbol}"'))
        if value_report.failed:
            self.failure_reasons.append(value_report.error_str())
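
# Illustrative usage sketch (comment only). It assumes TraceItem can be
# constructed with keyword arguments for the `symbol`, `inputs`, and `output`
# fields used above; see .framework for the authoritative definition.
#
#   item = TraceItem(symbol='forward', inputs=[torch.ones(3)],
#                    output=torch.ones(3))
#   golden = TraceItem(symbol='forward', inputs=[torch.ones(3)],
#                      output=torch.zeros(3))
#   report = TraceItemReport(item, golden, ErrorContext.empty())
#   if report.failed:
#       print(report.error_str())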


class SingleTestReport:
    """A report for a single test."""
    item_reports: Optional[List[TraceItemReport]]

    def __init__(self, result: TestResult, context: ErrorContext):
        self.result = result
        self.context = context
        self.item_reports = None
        if result.compilation_error is None and result.runtime_error is None:
            self.item_reports = []
            for i, (item, golden_item) in enumerate(
                    zip(result.trace, result.golden_trace)):
                self.item_reports.append(
                    TraceItemReport(
                        item, golden_item,
                        context.chain(
                            f'trace item #{i} - call to "{item.symbol}"')))

    @property
    def failed(self):
        if self.result.compilation_error is not None:
            return True
        elif self.result.runtime_error is not None:
            return True
        return any(r.failed for r in self.item_reports)

    def error_str(self):
        assert self.failed
        f = io.StringIO()
        p = lambda *x: print(*x, file=f)
        if self.result.compilation_error is not None:
            return 'Compilation error: ' + self.result.compilation_error
        elif self.result.runtime_error is not None:
            return 'Runtime error: ' + self.result.runtime_error
        for report in self.item_reports:
            if report.failed:
                p(report.error_str())
        return f.getvalue()
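
# Illustrative usage sketch (comment only). It assumes TestResult can be
# constructed with keyword arguments for the fields used above (`unique_name`,
# `compilation_error`, `runtime_error`, `trace`, `golden_trace`); see
# .framework for the authoritative definition.
#
#   result = TestResult(unique_name='MyModule_basic',
#                       compilation_error='<error details>',
#                       runtime_error=None, trace=None, golden_trace=None)
#   report = SingleTestReport(result, ErrorContext.empty())
#   assert report.failed
#   print(report.error_str())  # Compilation error: <error details>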


def report_results(results: List[TestResult],
                   expected_failures: Set[str],
                   verbose: bool = False,
                   config: str = ""):
    """Print a basic error report summarizing the given TestResults.

    This report uses the PASS/FAIL/XPASS/XFAIL nomenclature of LLVM's
    "lit" testing utility. See
    https://llvm.org/docs/CommandGuide/lit.html#test-status-results

    The `expected_failures` set should contain the names of tests
    (according to their `unique_name`) which are expected to fail.
    For the report to be considered passing overall, these tests must
    actually fail (this catches cases where things suddenly start working).

    If `verbose` is True, also print an explanation of what failed.

    Returns True if the run resulted in any unexpected pass/fail behavior,
    otherwise False.
    """
    results_by_outcome = collections.defaultdict(list)
    for result in results:
        report = SingleTestReport(result, ErrorContext.empty())
        expected_failure = result.unique_name in expected_failures
        if expected_failure:
            if report.failed:
                print(f'XFAIL - "{result.unique_name}"')
                results_by_outcome['XFAIL'].append((result, report))
            else:
                print(f'XPASS - "{result.unique_name}"')
                results_by_outcome['XPASS'].append((result, report))
        else:
            if not report.failed:
                print(f'PASS - "{result.unique_name}"')
                results_by_outcome['PASS'].append((result, report))
            else:
                print(f'FAIL - "{result.unique_name}"')
                results_by_outcome['FAIL'].append((result, report))

    OUTCOME_MEANINGS = collections.OrderedDict()
    OUTCOME_MEANINGS['PASS'] = 'Passed'
    OUTCOME_MEANINGS['FAIL'] = 'Failed'
    OUTCOME_MEANINGS['XFAIL'] = 'Expectedly Failed'
    OUTCOME_MEANINGS['XPASS'] = 'Unexpectedly Passed'

    had_unexpected_results = len(results_by_outcome['FAIL']) != 0 or len(
        results_by_outcome['XPASS']) != 0

    if had_unexpected_results:
        print(f'\nUnexpected outcome summary: ({config})')

    # For FAIL and XPASS (unexpected outcomes), print a summary.
    for outcome, outcome_results in results_by_outcome.items():
        # PASS and XFAIL are "good"/"successful" outcomes.
        if outcome == 'PASS' or outcome == 'XFAIL':
            continue
        # If there is nothing to report, be quiet.
        if len(outcome_results) == 0:
            continue
        print(f'\n****** {OUTCOME_MEANINGS[outcome]} tests - {len(outcome_results)} tests')
        for result, report in outcome_results:
            print(f'    {outcome} - "{result.unique_name}"')
            # If the test failed, print the error message.
            if outcome == 'FAIL' and verbose:
                print(textwrap.indent(report.error_str(), ' ' * 8))

    # Print a summary for easy scanning.
    print('\nSummary:')

    for key in ['PASS', 'FAIL', 'XFAIL', 'XPASS']:
        if results_by_outcome[key]:
            print(f'    {OUTCOME_MEANINGS[key]}: {len(results_by_outcome[key])}')
    return had_unexpected_results
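
# Illustrative usage sketch (comment only), e.g. at the end of a test driver.
# `run_tests`, `ALL_TESTS`, and `XFAIL_SET` are hypothetical placeholders for
# whatever produces the List[TestResult] and expected-failure set in a real
# harness.
#
#   import sys
#
#   results = run_tests(ALL_TESTS)
#   had_unexpected = report_results(results, XFAIL_SET, verbose=True,
#                                   config='example-config')
#   sys.exit(1 if had_unexpected else 0)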