torch-mlir/python/torch_mlir_e2e_test/reporting.py

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""
Utilities for reporting the results of the test framework.
"""
from typing import Any, List, Optional, Set
import collections
import io
import textwrap
import torch
from .framework import TestResult, TraceItem


class TensorSummary:
"""A summary of a tensor's contents."""
def __init__(self, tensor):
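        # Cast to float64 so that min/max/mean are well-defined even for integer
        # and low-precision inputs (torch.mean requires a floating-point tensor).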
self.min = torch.min(tensor.type(torch.float64))
self.max = torch.max(tensor.type(torch.float64))
self.mean = torch.mean(tensor.type(torch.float64))
self.shape = list(tensor.shape)
self.dtype = tensor.dtype
def __str__(self):
return f'Tensor with shape={self.shape}, dtype={self.dtype}, min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}'


class ErrorContext:
"""A chained list of error contexts.
This is useful for tracking errors across multiple levels of detail.
"""
def __init__(self, contexts: List[str]):
self.contexts = contexts
@staticmethod
def empty():
"""Create an empty error context.
Used as the top-level context.
"""
return ErrorContext([])
def chain(self, additional_context: str):
"""Chain an additional context onto the current error context.
"""
return ErrorContext(self.contexts + [additional_context])
def format_error(self, s: str):
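        # Renders one '@ <context>' line per chained context, followed by the
        # error message, e.g.:
        #   @ trace item #0 - call to "forward"
        #   @ output of call to "forward"
        #   ERROR: <details>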
return '@ ' + '\n@ '.join(self.contexts) + '\n' + 'ERROR: ' + s


class ValueReport:
"""A report for a single value processed by the program.
"""
def __init__(self, value, golden_value, context: ErrorContext):
self.value = value
self.golden_value = golden_value
self.context = context
self.failure_reasons = []
self._evaluate_outcome()
@property
def failed(self):
return len(self.failure_reasons) != 0
def error_str(self):
return '\n'.join(self.failure_reasons)
def _evaluate_outcome(self):
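        # Dispatch on the type of the golden value; each branch either records a
        # failure (and returns) or returns silently on success.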
value, golden = self.value, self.golden_value
if isinstance(golden, float):
if not isinstance(value, float):
return self._record_mismatch_type_failure('float', value)
            # Relative tolerance check against abs(golden), so the comparison
            # behaves sensibly for negative golden values and does not divide by zero.
            if abs(value - golden) > 1e-4 * abs(golden):
                return self._record_failure(
                    f'value ({value!r}) is not close to golden value ({golden!r})'
                )
return
if isinstance(golden, int):
if not isinstance(value, int):
return self._record_mismatch_type_failure('int', value)
if value != golden:
return self._record_failure(
f'value ({value!r}) is not equal to golden value ({golden!r})'
)
return
if isinstance(golden, str):
if not isinstance(value, str):
return self._record_mismatch_type_failure('str', value)
if value != golden:
return self._record_failure(
f'value ({value!r}) is not equal to golden value ({golden!r})'
)
return
if isinstance(golden, tuple):
if not isinstance(value, tuple):
return self._record_mismatch_type_failure('tuple', value)
if len(value) != len(golden):
                return self._record_failure(
                    f'length ({len(value)!r}) is not equal to golden length ({len(golden)!r})'
                )
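            # Recurse into each element, chaining the element index onto the
            # error context so failures point at the offending position.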
reports = [
ValueReport(v, g, self.context.chain(f'tuple element {i}'))
for i, (v, g) in enumerate(zip(value, golden))
]
for report in reports:
if report.failed:
self.failure_reasons.extend(report.failure_reasons)
return
if isinstance(golden, list):
if not isinstance(value, list):
return self._record_mismatch_type_failure('list', value)
if len(value) != len(golden):
                return self._record_failure(
                    f'length ({len(value)!r}) is not equal to golden length ({len(golden)!r})'
                )
reports = [
ValueReport(v, g, self.context.chain(f'list element {i}'))
for i, (v, g) in enumerate(zip(value, golden))
]
for report in reports:
if report.failed:
self.failure_reasons.extend(report.failure_reasons)
return
if isinstance(golden, dict):
if not isinstance(value, dict):
return self._record_mismatch_type_failure('dict', value)
gkeys = list(sorted(golden.keys()))
vkeys = list(sorted(value.keys()))
if gkeys != vkeys:
return self._record_failure(
f'dict keys ({vkeys!r}) are not equal to golden keys ({gkeys!r})'
)
reports = [
ValueReport(value[k], golden[k],
self.context.chain(f'dict element at key {k!r}'))
for k in gkeys
]
for report in reports:
if report.failed:
self.failure_reasons.extend(report.failure_reasons)
return
if isinstance(golden, torch.Tensor):
if not isinstance(value, torch.Tensor):
return self._record_mismatch_type_failure('torch.Tensor', value)
if value.shape != golden.shape:
return self._record_failure(
f'shape ({value.shape}) is not equal to golden shape ({golden.shape})'
)
if value.dtype != golden.dtype:
return self._record_failure(
f'dtype ({value.dtype}) is not equal to golden dtype ({golden.dtype})'
)
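            # Element-wise numeric comparison with fixed tolerances; NaNs in
            # matching positions compare as equal (equal_nan=True).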
if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True):
return self._record_failure(
f'value ({TensorSummary(value)}) is not close to golden value ({TensorSummary(golden)})'
)
return
return self._record_failure(
f'unexpected golden value of type `{golden.__class__.__name__}`')
def _record_failure(self, s: str):
self.failure_reasons.append(self.context.format_error(s))
def _record_mismatch_type_failure(self, expected: str, actual: Any):
self._record_failure(
f'expected a value of type `{expected}` but got `{actual.__class__.__name__}`'
)


class TraceItemReport:
"""A report for a single trace item."""
failure_reasons: List[str]
def __init__(self, item: TraceItem, golden_item: TraceItem,
context: ErrorContext):
self.item = item
self.golden_item = golden_item
self.context = context
self.failure_reasons = []
self._evaluate_outcome()
@property
def failed(self):
return len(self.failure_reasons) != 0
def error_str(self):
return '\n'.join(self.failure_reasons)
def _evaluate_outcome(self):
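        # Compare the invoked symbol, then each input, then the output against
        # the golden trace item, chaining descriptive context onto each check.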
if self.item.symbol != self.golden_item.symbol:
self.failure_reasons.append(
self.context.format_error(
f'not invoking the same symbol: got "{self.item.symbol}", expected "{self.golden_item.symbol}"'
))
if len(self.item.inputs) != len(self.golden_item.inputs):
self.failure_reasons.append(
self.context.format_error(
f'different number of inputs: got "{len(self.item.inputs)}", expected "{len(self.golden_item.inputs)}"'
))
for i, (input, golden_input) in enumerate(
zip(self.item.inputs, self.golden_item.inputs)):
value_report = ValueReport(
input, golden_input,
self.context.chain(
f'input #{i} of call to "{self.item.symbol}"'))
if value_report.failed:
self.failure_reasons.append(value_report.error_str())
value_report = ValueReport(
self.item.output, self.golden_item.output,
self.context.chain(f'output of call to "{self.item.symbol}"'))
if value_report.failed:
self.failure_reasons.append(value_report.error_str())


class SingleTestReport:
"""A report for a single test."""
item_reports: Optional[List[TraceItemReport]]
def __init__(self, result: TestResult, context: ErrorContext):
self.result = result
self.context = context
self.item_reports = None
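        # Item-by-item trace comparison only makes sense if the test compiled
        # and ran to completion.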
if result.compilation_error is None and result.runtime_error is None:
self.item_reports = []
for i, (item, golden_item) in enumerate(
zip(result.trace, result.golden_trace)):
self.item_reports.append(
TraceItemReport(
item, golden_item,
context.chain(
f'trace item #{i} - call to "{item.symbol}"')))
@property
def failed(self):
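        # A compilation or runtime error is an automatic failure; otherwise the
        # test fails if any trace item report failed.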
if self.result.compilation_error is not None:
return True
elif self.result.runtime_error is not None:
return True
return any(r.failed for r in self.item_reports)
def error_str(self):
assert self.failed
f = io.StringIO()
p = lambda *x: print(*x, file=f)
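        # Compilation and runtime errors short-circuit below; otherwise collect
        # the text of every failed item report into the buffer.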
if self.result.compilation_error is not None:
return 'Compilation error: ' + self.result.compilation_error
elif self.result.runtime_error is not None:
return 'Runtime error: ' + self.result.runtime_error
for report in self.item_reports:
if report.failed:
p(report.error_str())
return f.getvalue()


def report_results(results: List[TestResult],
expected_failures: Set[str],
verbose: bool = False,
config: str = ""):
"""Print a basic error report summarizing various TestResult's.
This report uses the PASS/FAIL/XPASS/XFAIL nomenclature of LLVM's
"lit" testing utility. See
https://llvm.org/docs/CommandGuide/lit.html#test-status-results
The `expected_failures` set should contain the names of tests
(according to their `unique_name`) which are expected to fail.
    For the overall report to be considered successful, these tests must
    actually fail (this catches cases where things unexpectedly start
    working).
If `verbose` is True, then provide an explanation of what failed.
    Returns True if the run resulted in any unexpected pass/fail behavior,
    False otherwise.
"""
results_by_outcome = collections.defaultdict(list)
for result in results:
report = SingleTestReport(result, ErrorContext.empty())
expected_failure = result.unique_name in expected_failures
if expected_failure:
if report.failed:
print(f'XFAIL - "{result.unique_name}"')
results_by_outcome['XFAIL'].append((result, report))
else:
print(f'XPASS - "{result.unique_name}"')
results_by_outcome['XPASS'].append((result, report))
else:
if not report.failed:
print(f'PASS - "{result.unique_name}"')
results_by_outcome['PASS'].append((result, report))
else:
print(f'FAIL - "{result.unique_name}"')
results_by_outcome['FAIL'].append((result, report))
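    # Human-readable descriptions for each outcome, used in the printed summaries.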
OUTCOME_MEANINGS = collections.OrderedDict()
OUTCOME_MEANINGS['PASS'] = 'Passed'
OUTCOME_MEANINGS['FAIL'] = 'Failed'
OUTCOME_MEANINGS['XFAIL'] = 'Expectedly Failed'
OUTCOME_MEANINGS['XPASS'] = 'Unexpectedly Passed'
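    # FAIL and XPASS are the unexpected outcomes; either one makes the run unsuccessful.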
had_unexpected_results = len(results_by_outcome['FAIL']) != 0 or len(
results_by_outcome['XPASS']) != 0
if had_unexpected_results:
print(f'\nUnexpected outcome summary: ({config})')
# For FAIL and XPASS (unexpected outcomes), print a summary.
for outcome, results in results_by_outcome.items():
# PASS and XFAIL are "good"/"successful" outcomes.
if outcome == 'PASS' or outcome == 'XFAIL':
continue
# If there is nothing to report, be quiet.
if len(results) == 0:
continue
print(f'\n****** {OUTCOME_MEANINGS[outcome]} tests - {len(results)} tests')
for result, report in results:
print(f' {outcome} - "{result.unique_name}"')
# If the test failed, print the error message.
if outcome == 'FAIL' and verbose:
print(textwrap.indent(report.error_str(), ' ' * 8))
# Print a summary for easy scanning.
print('\nSummary:')
for key in ['PASS', 'FAIL', 'XFAIL', 'XPASS']:
if results_by_outcome[key]:
print(f' {OUTCOME_MEANINGS[key]}: {len(results_by_outcome[key])}')
return had_unexpected_results