diff --git a/frontends/pytorch/e2e_testing/torchscript/main.py b/frontends/pytorch/e2e_testing/torchscript/main.py
index 30f7f967e..77c34e662 100644
--- a/frontends/pytorch/e2e_testing/torchscript/main.py
+++ b/frontends/pytorch/e2e_testing/torchscript/main.py
@@ -27,7 +27,7 @@ import vision_models
 import mlp
 import quantized_models
 
-def main():
+def _get_argparse():
     parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
     parser.add_argument('--config',
                         choices=['native_torch', 'torchscript', 'refbackend'],
@@ -41,7 +41,16 @@ Meaning of options:
     parser.add_argument('--filter', default='.*', help='''
Regular expression specifying which tests to include in this run.
''')
-    args = parser.parse_args()
+    parser.add_argument('-v', '--verbose',
+                        default=False,
+                        action='store_true',
+                        help='report test results with additional detail')
+    return parser
+
+def main():
+    args = _get_argparse().parse_args()
+
+    # Find the selected config.
     if args.config == 'refbackend':
         config = RefBackendTestConfig()
     elif args.config == 'native_torch':
@@ -49,6 +58,7 @@ Regular expression specifying which tests to include in this run.
     elif args.config == 'torchscript':
         config = TorchScriptTestConfig()
 
+    # Find the selected tests, and emit a diagnostic if none are found.
     tests = [
         test for test in GLOBAL_TEST_REGISTRY
         if re.match(args.filter, test.unique_name)
@@ -61,8 +71,12 @@ Regular expression specifying which tests to include in this run.
         for test in GLOBAL_TEST_REGISTRY:
             print(test.unique_name)
         sys.exit(1)
+
+    # Run the tests.
     results = run_tests(tests, config)
-    report_results(results)
+
+    # Report the test results.
+    report_results(results, args.verbose)
 
 if __name__ == '__main__':
     main()
diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py
index 62207a6a2..be0b31f46 100644
--- a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py
+++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py
@@ -20,7 +20,7 @@ compiling or TorchScript'ing).
 """
 
 import abc
-from typing import Any, Callable, List, NamedTuple, TypeVar
+from typing import Any, Callable, List, NamedTuple, Optional, TypeVar
 
 import torch
 
@@ -175,10 +175,14 @@ class TestResult(NamedTuple):
     # those reasons are stronger because we cannot simply extend this
     # class.
     unique_name: str  # Should match Test.unique_name for corresponding test.
+    # If compilation failed, a string describing the failure.
+    # If this is not None, then the `trace` and `golden_trace` fields are None,
+    # and vice-versa.
+    compilation_error: Optional[str]
     # The trace produced by the backend.
-    trace: Trace
+    trace: Optional[Trace]
    # The golden trace which `trace` is expected to match.
-    golden_trace: Trace
+    golden_trace: Optional[Trace]
 
 
 class _Tracer:
@@ -200,6 +204,9 @@ class _Tracer:
             raw_outputs = getattr(self.module, name)(*args)
             if isinstance(raw_outputs, torch.Tensor):
                 outputs = [raw_outputs]
+            else:
+                raise Exception(
+                    "unimplemented: non-Tensor output from function")
             self.trace.append(
                 TraceItem(symbol=name, inputs=args, outputs=outputs))
             return raw_outputs
@@ -222,11 +229,20 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
     for test in tests:
         golden_trace = _generate_golden_trace(test)
         # TODO: Precompile everything in parallel.
-        compiled = config.compile(test.program_factory())
+        try:
+            compiled = config.compile(test.program_factory())
+        except Exception as e:
+            results.append(
+                TestResult(unique_name=test.unique_name,
+                           compilation_error=str(e),
+                           trace=None,
+                           golden_trace=None))
+            continue
         # TODO: Run in parallel.
         trace = config.run(compiled, golden_trace)
         results.append(
             TestResult(unique_name=test.unique_name,
+                       compilation_error=None,
                        trace=trace,
                        golden_trace=golden_trace))
     return results
diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py
index 1d7a6a5c3..e739e3432 100644
--- a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py
+++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py
@@ -1,60 +1,183 @@
 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
 """
 Utilities for reporting the results of the test framework.
 """
 
-from typing import List
+from typing import List, Optional
+
+import io
+import textwrap
 
 import torch
 
-from .framework import TestResult
+from .framework import TestResult, TraceItem
 
-class _TensorStats:
+
+class TensorSummary:
+    """A summary of a tensor's contents."""
     def __init__(self, tensor):
         self.min = torch.min(tensor)
         self.max = torch.max(tensor)
         self.mean = torch.mean(tensor)
+
     def __str__(self):
-        return f'min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
+        return f'Tensor with min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
 
 
-def _print_detailed_tensor_diff(tensor, golden_tensor):
-    if tensor.size() != golden_tensor.size():
-        print(
-            f'Tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
-        )
-        return
-    print('tensor stats : ', _TensorStats(tensor))
-    print('golden tensor stats: ', _TensorStats(golden_tensor))
+class ErrorContext:
+    """A chained list of error contexts.
 
-def report_results(results: List[TestResult]):
-    """Provide a basic error report summarizing various TestResult's."""
-    any_failed = False
+    This is useful for tracking errors across multiple levels of detail.
+    """
+    def __init__(self, contexts: List[str]):
+        self.contexts = contexts
+
+    @staticmethod
+    def empty():
+        """Create an empty error context.
+
+        Used as the top-level context.
+        """
+        return ErrorContext([])
+
+    def chain(self, additional_context: str):
+        """Chain an additional context onto the current error context.
+        """
+        return ErrorContext(self.contexts + [additional_context])
+
+    def format_error(self, s: str):
+        return '@ ' + '\n@ '.join(self.contexts) + '\n' + 'ERROR: ' + s
+
+
+class ValueReport:
+    """A report for a single value processed by the program.
+
+    This is currently limited to tensors, but eventually will support
+    all legal TorchScript types.
+ """ + def __init__(self, value, golden_value, context: ErrorContext): + self.value = value + self.golden_value = golden_value + self.context = context + + @property + def failed(self): + return not torch.allclose(self.value, self.golden_value) + + def error_str(self): + assert self.failed + if self.value.size() != self.golden_value.size(): + return self.context.format_error( + f'tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}' + ) + f = io.StringIO() + p = lambda *x: print(*x, file=f) + p('values mismatch') + p('got : ', TensorSummary(self.value)) + p('expected: ', TensorSummary(self.golden_value)) + return self.context.format_error(f.getvalue()) + + +class TraceItemReport: + """A report for a single trace item.""" + failure_reasons: List[str] + + def __init__(self, item: TraceItem, golden_item: TraceItem, + context: ErrorContext): + self.item = item + self.golden_item = golden_item + self.context = context + self.failure_reasons = [] + self._evaluate_outcome() + + @property + def failed(self): + return len(self.failure_reasons) != 0 + + def error_str(self): + return '\n'.join(self.failure_reasons) + + def _evaluate_outcome(self): + if self.item.symbol != self.golden_item.symbol: + self.failure_reasons.append( + self.context.format_error( + f'not invoking the same symbol: got "{self.item.symbol}", expected "{self.golden_item.symbol}"' + )) + if len(self.item.inputs) != len(self.golden_item.inputs): + self.failure_reasons.append( + self.context.format_error( + f'different number of inputs: got "{len(self.item.inputs)}", expected "{len(self.golden_item.inputs)}"' + )) + if len(self.item.outputs) != len(self.golden_item.outputs): + self.failure_reasons.append( + self.context.format_error( + f'different number of outputs: got "{len(self.item.outputs)}", expected "{len(self.golden_item.outputs)}"' + )) + for i, (input, golden_input) in enumerate( + zip(self.item.inputs, self.golden_item.inputs)): + value_report = ValueReport( + input, golden_input, + self.context.chain( + f'input #{i} of call to "{self.item.symbol}"')) + if value_report.failed: + self.failure_reasons.append(value_report.error_str()) + for i, (output, golden_output) in enumerate( + zip(self.item.outputs, self.golden_item.outputs)): + value_report = ValueReport(output, golden_output, + self.context.chain(f'output #{i}')) + if value_report.failed: + self.failure_reasons.append(value_report.error_str()) + + +class SingleTestReport: + """A report for a single test.""" + item_reports: Optional[List[TraceItemReport]] + + def __init__(self, result: TestResult, context: ErrorContext): + self.result = result + self.context = context + self.item_reports = None + if result.compilation_error is None: + self.item_reports = [] + for i, (item, golden_item) in enumerate( + zip(result.trace, result.golden_trace)): + self.item_reports.append( + TraceItemReport( + item, golden_item, + context.chain( + f'trace item #{i} - call to "{item.symbol}"'))) + + @property + def failed(self): + if self.result.compilation_error is not None: + return True + return any(r.failed for r in self.item_reports) + + def error_str(self): + assert self.failed + f = io.StringIO() + p = lambda *x: print(*x, file=f) + if self.result.compilation_error is not None: + return 'compilation error' + self.result.compilation_error + for report in self.item_reports: + if report.failed: + p(report.error_str()) + return f.getvalue() + + +def report_results(results: List[TestResult], verbose: bool = False): + """Provide a basic error report 
summarizing various TestResult's. + + If `verbose` is True, then provide an explanation of what failed. + """ for result in results: - failed = False - for item_num, (item, golden_item) in enumerate( - zip(result.trace, result.golden_trace)): - assert item.symbol == golden_item.symbol - assert len(item.inputs) == len(golden_item.inputs) - assert len(item.outputs) == len(golden_item.outputs) - for input, golden_input in zip(item.inputs, golden_item.inputs): - assert torch.allclose(input, golden_input) - for output_num, (output, golden_output) in enumerate( - zip(item.outputs, golden_item.outputs)): - # TODO: Refine error message. Things to consider: - # - Very large tensors -- don't spew, but give useful info - # - Smaller tensors / primitives -- want to show exact values - # - Machine parseable format? - if not torch.allclose(output, golden_output): - if not failed: - print('FAILURE "{}"'.format(result.unique_name)) - failed = any_failed = True - print( - f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"' - ) - _print_detailed_tensor_diff(output, golden_output) - if not failed: - print('SUCCESS "{}"'.format(result.unique_name)) + report = SingleTestReport(result, ErrorContext.empty()) + if not report.failed: + print(f'SUCCESS - "{result.unique_name}"') + else: + error_str = '' + if verbose: + error_str = '\n' + textwrap.indent(report.error_str(), ' ') + print(f'FAILURE - "{result.unique_name}"' + error_str) diff --git a/frontends/pytorch/test/torchscript_e2e_test/basic.py b/frontends/pytorch/test/torchscript_e2e_test/basic.py index 8610b14be..762d6e7f1 100644 --- a/frontends/pytorch/test/torchscript_e2e_test/basic.py +++ b/frontends/pytorch/test/torchscript_e2e_test/basic.py @@ -21,13 +21,13 @@ class MmModule(torch.nn.Module): # TODO: Refine messages. -# CHECK: SUCCESS "MmModule_basic" +# CHECK: SUCCESS - "MmModule_basic" @register_test_case(module_factory=lambda: MmModule()) def MmModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 4), tu.rand(4, 4)) -# CHECK: SUCCESS "MmModule_basic2" +# CHECK: SUCCESS - "MmModule_basic2" @register_test_case(module_factory=lambda: MmModule()) def MmModule_basic2(module, tu: TestUtils): module.forward(tu.rand(4, 4), tu.rand(4, 4)) diff --git a/frontends/pytorch/test/torchscript_e2e_test/compilation_failure.py b/frontends/pytorch/test/torchscript_e2e_test/compilation_failure.py new file mode 100644 index 000000000..fc4041cc0 --- /dev/null +++ b/frontends/pytorch/test/torchscript_e2e_test/compilation_failure.py @@ -0,0 +1,45 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +# RUN: %PYTHON %s | FileCheck %s + +import torch + +from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils +from torch_mlir.torchscript.e2e_test.reporting import report_results +from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY +from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig + + +class MmModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, t): + # Static type error that will fail TorchScript compilation -- function + # that returns tensor along one path and int along another. 
+        if t.item() > 0:
+            return torch.tensor([])
+        else:
+            return 3
+
+
+# CHECK: FAILURE - "MmModule_basic"
+# CHECK: compilation error
+# Assume that the diagnostic from the TorchScript compiler will at least contain
+# the offending "return 3".
+# CHECK: return 3
+@register_test_case(module_factory=lambda: MmModule())
+def MmModule_basic(module, tu: TestUtils):
+    module.forward(torch.ones([]))
+
+
+def main():
+    config = TorchScriptTestConfig()
+    results = run_tests(GLOBAL_TEST_REGISTRY, config)
+    report_results(results, verbose=True)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/frontends/pytorch/test/torchscript_e2e_test/miscompile.py b/frontends/pytorch/test/torchscript_e2e_test/miscompile.py
index 4344ce356..a3f2488cf 100644
--- a/frontends/pytorch/test/torchscript_e2e_test/miscompile.py
+++ b/frontends/pytorch/test/torchscript_e2e_test/miscompile.py
@@ -26,11 +26,12 @@ class MmModule(torch.nn.Module):
 
 
 # TODO: Refine error messages.
-# CHECK: FAILURE "MmModule_basic"
-# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward"
-# CHECK: tensor stats : min={{.*}}, max={{.*}}, mean={{.*}}
-# CHECK: golden tensor stats: min={{.*}}, max={{.*}}, mean={{.*}}
-# CHECK-NOT: ALL PASS
+# CHECK: FAILURE - "MmModule_basic"
+# CHECK: @ trace item #0 - call to "forward"
+# CHECK: @ output #0
+# CHECK: ERROR: values mismatch
+# CHECK: got : Tensor with min={{.*}}, max={{.*}}, mean={{.*}}
+# CHECK: expected: Tensor with min={{.*}}, max={{.*}}, mean={{.*}}
 @register_test_case(module_factory=lambda: MmModule())
 def MmModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 4), tu.rand(4, 4))
@@ -39,7 +40,7 @@ def MmModule_basic(module, tu: TestUtils):
 def main():
     config = TorchScriptTestConfig()
     results = run_tests(GLOBAL_TEST_REGISTRY, config)
-    report_results(results)
+    report_results(results, verbose=True)
 
 
 if __name__ == '__main__':
diff --git a/lib/Backend/Common/VerifyBackendContract.cpp b/lib/Backend/Common/VerifyBackendContract.cpp
index e03ca00d6..b8dbb8582 100644
--- a/lib/Backend/Common/VerifyBackendContract.cpp
+++ b/lib/Backend/Common/VerifyBackendContract.cpp
@@ -69,7 +69,9 @@ class VerifyBackendContractPass
     RewritePatternSet patterns(context);
     if (failed(applyFullConversion(module, target, std::move(patterns)))) {
-      module.emitError()
+      // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
+      // doesn't unnecessarily spew out the entire module.
+      emitError(module.getLoc())
          << "Module does not conform to npcomp's backend contract. See "
             "dialect conversion legality information above.";
       return signalPassFailure();