mirror of https://github.com/llvm/torch-mlir
Rewrite error reporting of e2e tests.
This now gives [much nicer output](https://gist.github.com/silvasean/f048e0f37b04542dae6469b86802bb3e). Embarrassingly, we previously couldn't even report failures for two different tests, and weren't able to report on compilation failures (besides just crashing).

pull/217/head

parent d66e8fe1f8
commit b7b7fd4959
@@ -27,7 +27,7 @@ import vision_models
 import mlp
 import quantized_models

-def main():
+def _get_argparse():
     parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
     parser.add_argument('--config',
                         choices=['native_torch', 'torchscript', 'refbackend'],
@@ -41,7 +41,16 @@ Meaning of options:
     parser.add_argument('--filter', default='.*', help='''
 Regular expression specifying which tests to include in this run.
 ''')
-    args = parser.parse_args()
+    parser.add_argument('-v', '--verbose',
+                        default=False,
+                        action='store_true',
+                        help='report test results with additional detail')
+    return parser
+
+def main():
+    args = _get_argparse().parse_args()
+
+    # Find the selected config.
     if args.config == 'refbackend':
         config = RefBackendTestConfig()
     elif args.config == 'native_torch':
@@ -49,6 +58,7 @@ Regular expression specifying which tests to include in this run.
     elif args.config == 'torchscript':
         config = TorchScriptTestConfig()

+    # Find the selected tests, and emit a diagnostic if none are found.
     tests = [
         test for test in GLOBAL_TEST_REGISTRY
         if re.match(args.filter, test.unique_name)
@@ -61,8 +71,12 @@ Regular expression specifying which tests to include in this run.
         for test in GLOBAL_TEST_REGISTRY:
             print(test.unique_name)
         sys.exit(1)

+    # Run the tests.
     results = run_tests(tests, config)
-    report_results(results)
+
+    # Report the test results.
+    report_results(results, args.verbose)

 if __name__ == '__main__':
     main()
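
A minimal, self-contained sketch of how the new `--verbose` flag is parsed. The parser below only mirrors the flags visible in the hunks above; the real script also wires the result into test selection and reporting:

```python
import argparse

def _get_argparse():
    # Stand-in that mirrors the options shown in the diff above.
    parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
    parser.add_argument('--config',
                        choices=['native_torch', 'torchscript', 'refbackend'])
    parser.add_argument('--filter', default='.*')
    parser.add_argument('-v', '--verbose',
                        default=False,
                        action='store_true',
                        help='report test results with additional detail')
    return parser

# `--verbose` defaults to False, so existing invocations keep the terse
# one-line-per-test output; `-v` opts into the detailed failure report.
args = _get_argparse().parse_args(['--config', 'torchscript', '-v'])
print(args.config, args.verbose)  # torchscript True
```
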
@@ -20,7 +20,7 @@ compiling or TorchScript'ing).
 """

 import abc
-from typing import Any, Callable, List, NamedTuple, TypeVar
+from typing import Any, Callable, List, NamedTuple, Optional, TypeVar

 import torch

@@ -175,10 +175,14 @@ class TestResult(NamedTuple):
     # those reasons are stronger because we cannot simply extend this
     # class.
     unique_name: str  # Should match Test.unique_name for corresponding test.
+    # If compilation failed, a string describing the failure.
+    # If this is not None, then the `trace` and `golden_trace` fields are None,
+    # and vice-versa.
+    compilation_error: Optional[str]
     # The trace produced by the backend.
-    trace: Trace
+    trace: Optional[Trace]
     # The golden trace which `trace` is expected to match.
-    golden_trace: Trace
+    golden_trace: Optional[Trace]


 class _Tracer:
@@ -200,6 +204,9 @@ class _Tracer:
         raw_outputs = getattr(self.module, name)(*args)
         if isinstance(raw_outputs, torch.Tensor):
             outputs = [raw_outputs]
+        else:
+            raise Exception(
+                "unimplemented: non-Tensor output from function")
         self.trace.append(
             TraceItem(symbol=name, inputs=args, outputs=outputs))
         return raw_outputs
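
As a standalone illustration of the guard added to `_Tracer` above (same check, lifted out of the class; only single-Tensor returns are traceable for now):

```python
import torch

def _wrap_outputs(raw_outputs):
    # Mirrors the hunk above: a lone Tensor is wrapped into a one-element
    # list; anything else is rejected with an explicit error instead of
    # leaving `outputs` unbound as the old code did.
    if isinstance(raw_outputs, torch.Tensor):
        return [raw_outputs]
    raise Exception("unimplemented: non-Tensor output from function")

print(_wrap_outputs(torch.ones(2)))      # [tensor([1., 1.])]
# _wrap_outputs((torch.ones(2), 3))      # raises "unimplemented: ..."
```
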
@@ -222,11 +229,20 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
     for test in tests:
         golden_trace = _generate_golden_trace(test)
         # TODO: Precompile everything in parallel.
-        compiled = config.compile(test.program_factory())
+        try:
+            compiled = config.compile(test.program_factory())
+        except Exception as e:
+            results.append(
+                TestResult(unique_name=test.unique_name,
+                           compilation_error=str(e),
+                           trace=None,
+                           golden_trace=None))
+            continue
         # TODO: Run in parallel.
         trace = config.run(compiled, golden_trace)
         results.append(
             TestResult(unique_name=test.unique_name,
+                       compilation_error=None,
                        trace=trace,
                        golden_trace=golden_trace))
     return results
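
A small sketch of the invariant the new `compilation_error` field introduces in `run_tests`: every result carries either an error string or a pair of traces, never both. The `TestResult` below is a local stand-in with the same field names, not an import from the framework:

```python
from typing import Any, NamedTuple, Optional

class TestResult(NamedTuple):
    # Local stand-in mirroring the fields added in the hunks above.
    unique_name: str
    compilation_error: Optional[str]
    trace: Optional[Any]
    golden_trace: Optional[Any]

# Shape 1: config.compile() raised, mirroring the new `except` branch.
compile_failure = TestResult(unique_name='SomeTest_basic',
                             compilation_error='compiler diagnostic text',
                             trace=None,
                             golden_trace=None)

# Shape 2: compilation succeeded; the traces are populated instead.
runtime_result = TestResult(unique_name='SomeTest_basic',
                            compilation_error=None,
                            trace=['... trace items ...'],
                            golden_trace=['... trace items ...'])

for r in (compile_failure, runtime_result):
    # Exactly one of the two shapes holds for every result.
    assert (r.compilation_error is None) != (r.trace is None)
```
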
@@ -1,60 +1,183 @@
 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 """
 Utilities for reporting the results of the test framework.
 """

-from typing import List
+from typing import List, Optional

+import io
+import textwrap
+
 import torch

-from .framework import TestResult
+from .framework import TestResult, TraceItem

-class _TensorStats:
+
+class TensorSummary:
+    """A summary of a tensor's contents."""
     def __init__(self, tensor):
         self.min = torch.min(tensor)
         self.max = torch.max(tensor)
         self.mean = torch.mean(tensor)

     def __str__(self):
-        return f'min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
+        return f'Tensor with min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'


-def _print_detailed_tensor_diff(tensor, golden_tensor):
-    if tensor.size() != golden_tensor.size():
-        print(
-            f'Tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
-        )
-        return
-    print('tensor stats : ', _TensorStats(tensor))
-    print('golden tensor stats: ', _TensorStats(golden_tensor))
-
-
-def report_results(results: List[TestResult]):
-    """Provide a basic error report summarizing various TestResult's."""
-    any_failed = False
-    for result in results:
-        failed = False
-        for item_num, (item, golden_item) in enumerate(
-                zip(result.trace, result.golden_trace)):
-            assert item.symbol == golden_item.symbol
-            assert len(item.inputs) == len(golden_item.inputs)
-            assert len(item.outputs) == len(golden_item.outputs)
-            for input, golden_input in zip(item.inputs, golden_item.inputs):
-                assert torch.allclose(input, golden_input)
-            for output_num, (output, golden_output) in enumerate(
-                    zip(item.outputs, golden_item.outputs)):
-                # TODO: Refine error message. Things to consider:
-                # - Very large tensors -- don't spew, but give useful info
-                # - Smaller tensors / primitives -- want to show exact values
-                # - Machine parseable format?
-                if not torch.allclose(output, golden_output):
-                    if not failed:
-                        print('FAILURE "{}"'.format(result.unique_name))
-                        failed = any_failed = True
-                    print(
-                        f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"'
-                    )
-                    _print_detailed_tensor_diff(output, golden_output)
-        if not failed:
-            print('SUCCESS "{}"'.format(result.unique_name))
+class ErrorContext:
+    """A chained list of error contexts.
+
+    This is useful for tracking errors across multiple levels of detail.
+    """
+    def __init__(self, contexts: List[str]):
+        self.contexts = contexts
+
+    @staticmethod
+    def empty():
+        """Create an empty error context.
+
+        Used as the top-level context.
+        """
+        return ErrorContext([])
+
+    def chain(self, additional_context: str):
+        """Chain an additional context onto the current error context.
+        """
+        return ErrorContext(self.contexts + [additional_context])
+
+    def format_error(self, s: str):
+        return '@ ' + '\n@ '.join(self.contexts) + '\n' + 'ERROR: ' + s
+
+
+class ValueReport:
+    """A report for a single value processed by the program.
+
+    This is currently limited to tensors, but eventually will support
+    all legal TorchScript types.
+    """
+    def __init__(self, value, golden_value, context: ErrorContext):
+        self.value = value
+        self.golden_value = golden_value
+        self.context = context
+
+    @property
+    def failed(self):
+        return not torch.allclose(self.value, self.golden_value)
+
+    def error_str(self):
+        assert self.failed
+        if self.value.size() != self.golden_value.size():
+            return self.context.format_error(
+                f'tensor shape mismatch: got {self.value.size()!r}, expected {self.golden_value.size()!r}'
+            )
+        f = io.StringIO()
+        p = lambda *x: print(*x, file=f)
+        p('values mismatch')
+        p('got     : ', TensorSummary(self.value))
+        p('expected: ', TensorSummary(self.golden_value))
+        return self.context.format_error(f.getvalue())
+
+
+class TraceItemReport:
+    """A report for a single trace item."""
+    failure_reasons: List[str]
+
+    def __init__(self, item: TraceItem, golden_item: TraceItem,
+                 context: ErrorContext):
+        self.item = item
+        self.golden_item = golden_item
+        self.context = context
+        self.failure_reasons = []
+        self._evaluate_outcome()
+
+    @property
+    def failed(self):
+        return len(self.failure_reasons) != 0
+
+    def error_str(self):
+        return '\n'.join(self.failure_reasons)
+
+    def _evaluate_outcome(self):
+        if self.item.symbol != self.golden_item.symbol:
+            self.failure_reasons.append(
+                self.context.format_error(
+                    f'not invoking the same symbol: got "{self.item.symbol}", expected "{self.golden_item.symbol}"'
+                ))
+        if len(self.item.inputs) != len(self.golden_item.inputs):
+            self.failure_reasons.append(
+                self.context.format_error(
+                    f'different number of inputs: got "{len(self.item.inputs)}", expected "{len(self.golden_item.inputs)}"'
+                ))
+        if len(self.item.outputs) != len(self.golden_item.outputs):
+            self.failure_reasons.append(
+                self.context.format_error(
+                    f'different number of outputs: got "{len(self.item.outputs)}", expected "{len(self.golden_item.outputs)}"'
+                ))
+        for i, (input, golden_input) in enumerate(
+                zip(self.item.inputs, self.golden_item.inputs)):
+            value_report = ValueReport(
+                input, golden_input,
+                self.context.chain(
+                    f'input #{i} of call to "{self.item.symbol}"'))
+            if value_report.failed:
+                self.failure_reasons.append(value_report.error_str())
+        for i, (output, golden_output) in enumerate(
+                zip(self.item.outputs, self.golden_item.outputs)):
+            value_report = ValueReport(output, golden_output,
+                                       self.context.chain(f'output #{i}'))
+            if value_report.failed:
+                self.failure_reasons.append(value_report.error_str())
+
+
+class SingleTestReport:
+    """A report for a single test."""
+    item_reports: Optional[List[TraceItemReport]]
+
+    def __init__(self, result: TestResult, context: ErrorContext):
+        self.result = result
+        self.context = context
+        self.item_reports = None
+        if result.compilation_error is None:
+            self.item_reports = []
+            for i, (item, golden_item) in enumerate(
+                    zip(result.trace, result.golden_trace)):
+                self.item_reports.append(
+                    TraceItemReport(
+                        item, golden_item,
+                        context.chain(
+                            f'trace item #{i} - call to "{item.symbol}"')))
+
+    @property
+    def failed(self):
+        if self.result.compilation_error is not None:
+            return True
+        return any(r.failed for r in self.item_reports)
+
+    def error_str(self):
+        assert self.failed
+        f = io.StringIO()
+        p = lambda *x: print(*x, file=f)
+        if self.result.compilation_error is not None:
+            return 'compilation error: ' + self.result.compilation_error
+        for report in self.item_reports:
+            if report.failed:
+                p(report.error_str())
+        return f.getvalue()
+
+
+def report_results(results: List[TestResult], verbose: bool = False):
+    """Provide a basic error report summarizing various TestResult's.
+
+    If `verbose` is True, then provide an explanation of what failed.
+    """
+    for result in results:
+        report = SingleTestReport(result, ErrorContext.empty())
+        if not report.failed:
+            print(f'SUCCESS - "{result.unique_name}"')
+        else:
+            error_str = ''
+            if verbose:
+                error_str = '\n' + textwrap.indent(report.error_str(), '    ')
+            print(f'FAILURE - "{result.unique_name}"' + error_str)
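
To make the new report format concrete, here is a trimmed, standalone copy of the `ErrorContext` chaining above, showing how the nested "@ ..." lines of a verbose failure are built (the example values are made up; compare the FileCheck lines in the test hunk further down):

```python
class ErrorContext:
    # Trimmed copy of the class above -- just enough to show the format.
    def __init__(self, contexts):
        self.contexts = contexts

    def chain(self, additional_context):
        return ErrorContext(self.contexts + [additional_context])

    def format_error(self, s):
        return '@ ' + '\n@ '.join(self.contexts) + '\n' + 'ERROR: ' + s

context = (ErrorContext([])
           .chain('trace item #0 - call to "forward"')
           .chain('output #0'))
print(context.format_error('values mismatch'))
# @ trace item #0 - call to "forward"
# @ output #0
# ERROR: values mismatch
```
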
@@ -21,13 +21,13 @@ class MmModule(torch.nn.Module):


 # TODO: Refine messages.
-# CHECK: SUCCESS "MmModule_basic"
+# CHECK: SUCCESS - "MmModule_basic"
 @register_test_case(module_factory=lambda: MmModule())
 def MmModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 4), tu.rand(4, 4))


-# CHECK: SUCCESS "MmModule_basic2"
+# CHECK: SUCCESS - "MmModule_basic2"
 @register_test_case(module_factory=lambda: MmModule())
 def MmModule_basic2(module, tu: TestUtils):
     module.forward(tu.rand(4, 4), tu.rand(4, 4))
@@ -0,0 +1,45 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See frontends/pytorch/LICENSE for license information.
+
+# RUN: %PYTHON %s | FileCheck %s
+
+import torch
+
+from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
+from torch_mlir.torchscript.e2e_test.reporting import report_results
+from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
+from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
+
+
+class MmModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, t):
+        # Static type error that will fail TorchScript compilation -- function
+        # that returns tensor along one path and int along another.
+        if t.item() > 0:
+            return torch.tensor([])
+        else:
+            return 3
+
+
+# CHECK: FAILURE - "MmModule_basic"
+# CHECK: compilation error
+# Assume that the diagnostic from the TorchScript compiler will at least contain
+# the offending "return 3".
+# CHECK: return 3
+@register_test_case(module_factory=lambda: MmModule())
+def MmModule_basic(module, tu: TestUtils):
+    module.forward(torch.ones([]))
+
+
+def main():
+    config = TorchScriptTestConfig()
+    results = run_tests(GLOBAL_TEST_REGISTRY, config)
+    report_results(results, verbose=True)
+
+
+if __name__ == '__main__':
+    main()
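
The new test above relies on TorchScript rejecting a function whose return type differs across branches. A rough standalone repro of that class of compilation failure, independent of the test harness (the module name here is made up):

```python
import torch

class Divergent(torch.nn.Module):
    def forward(self, t):
        # Tensor on one path, int on the other -- TorchScript requires a
        # single static return type, so scripting this module fails.
        if t.item() > 0:
            return torch.tensor([])
        else:
            return 3

try:
    torch.jit.script(Divergent())
except Exception as e:
    # The compiler's diagnostic points at the offending return, which is
    # what the "CHECK: return 3" line above keys on.
    print('compilation error:', type(e).__name__)
```
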
@@ -26,11 +26,12 @@ class MmModule(torch.nn.Module):


 # TODO: Refine error messages.
-# CHECK: FAILURE "MmModule_basic"
-# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward"
-# CHECK: tensor stats : min={{.*}}, max={{.*}}, mean={{.*}}
-# CHECK: golden tensor stats: min={{.*}}, max={{.*}}, mean={{.*}}
-# CHECK-NOT: ALL PASS
+# CHECK: FAILURE - "MmModule_basic"
+# CHECK: @ trace item #0 - call to "forward"
+# CHECK: @ output #0
+# CHECK: ERROR: values mismatch
+# CHECK: got     : Tensor with min={{.*}}, max={{.*}}, mean={{.*}}
+# CHECK: expected: Tensor with min={{.*}}, max={{.*}}, mean={{.*}}
 @register_test_case(module_factory=lambda: MmModule())
 def MmModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 4), tu.rand(4, 4))
@@ -39,7 +40,7 @@ def MmModule_basic(module, tu: TestUtils):
 def main():
     config = TorchScriptTestConfig()
     results = run_tests(GLOBAL_TEST_REGISTRY, config)
-    report_results(results)
+    report_results(results, verbose=True)


 if __name__ == '__main__':
@@ -69,7 +69,9 @@ class VerifyBackendContractPass

   RewritePatternSet patterns(context);
   if (failed(applyFullConversion(module, target, std::move(patterns)))) {
-    module.emitError()
+    // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
+    // doesn't unnecessarily spew out the entire module.
+    emitError(module.getLoc())
         << "Module does not conform to npcomp's backend contract. See "
            "dialect conversion legality information above.";
     return signalPassFailure();