diff --git a/frontends/pytorch/e2e_testing/torchscript/main.py b/frontends/pytorch/e2e_testing/torchscript/main.py index 214be1c76..993caa1a1 100644 --- a/frontends/pytorch/e2e_testing/torchscript/main.py +++ b/frontends/pytorch/e2e_testing/torchscript/main.py @@ -4,7 +4,8 @@ import argparse -from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results +from torch_mlir.torchscript.e2e_test.framework import run_tests +from torch_mlir.torchscript.e2e_test.reporting import report_results from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY # Available test configs. @@ -20,6 +21,7 @@ from torch_mlir.torchscript.e2e_test.configs import ( # be run from a specific directory. # TODO: Find out best practices for python "main" files. import basic +import vision_models def main(): parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') @@ -29,7 +31,7 @@ def main(): help=''' Meaning of options: "refbackend": run through npcomp's RefBackend. -"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic). +"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). ''') args = parser.parse_args() diff --git a/frontends/pytorch/e2e_testing/torchscript/vision_models.py b/frontends/pytorch/e2e_testing/torchscript/vision_models.py new file mode 100644 index 000000000..a6c2ab0c0 --- /dev/null +++ b/frontends/pytorch/e2e_testing/torchscript/vision_models.py @@ -0,0 +1,30 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +import torchvision.models as models + +from torch_mlir.torchscript.e2e_test.framework import TestUtils +from torch_mlir.torchscript.e2e_test.registry import register_test_case +from torch_mlir.torchscript.annotations import annotate_args, export + +# ============================================================================== + +class Resnet18Module(torch.nn.Module): + def __init__(self): + super().__init__() + # Reset seed to make model deterministic. + torch.manual_seed(0) + self.resnet = models.resnet18() + @export + @annotate_args([ + None, + ([-1, 3, -1, -1], torch.float32), + ]) + def forward(self, img): + return self.resnet.forward(img) + +@register_test_case(module_factory=lambda: Resnet18Module()) +def Resnet18Module_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 3, 224, 224)) diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py index b869ea6b0..62207a6a2 100644 --- a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py +++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py @@ -230,31 +230,3 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]: trace=trace, golden_trace=golden_trace)) return results - - -def report_results(results: List[TestResult]): - """Provide a basic error report summarizing various TestResult's.""" - for result in results: - failed = False - for item_num, (item, golden_item) in enumerate( - zip(result.trace, result.golden_trace)): - assert item.symbol == golden_item.symbol - assert len(item.inputs) == len(golden_item.inputs) - assert len(item.outputs) == len(golden_item.outputs) - for input, golden_input in zip(item.inputs, golden_item.inputs): - assert torch.allclose(input, golden_input) - for output_num, (output, golden_output) in enumerate( - zip(item.outputs, golden_item.outputs)): - # TODO: Refine error message. Things to consider: - # - Very large tensors -- don't spew, but give useful info - # - Smaller tensors / primitives -- want to show exact values - # - Machine parseable format? - if not torch.allclose(output, golden_output): - print( - f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"' - ) - failed = True - if failed: - print('FAILURE "{}"'.format(result.unique_name)) - else: - print('SUCCESS "{}"'.format(result.unique_name)) diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py new file mode 100644 index 000000000..1d7a6a5c3 --- /dev/null +++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py @@ -0,0 +1,60 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Utilities for reporting the results of the test framework. +""" + +from typing import List + +import torch + +from .framework import TestResult + +class _TensorStats: + def __init__(self, tensor): + self.min = torch.min(tensor) + self.max = torch.max(tensor) + self.mean = torch.mean(tensor) + def __str__(self): + return f'min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}' + + +def _print_detailed_tensor_diff(tensor, golden_tensor): + if tensor.size() != golden_tensor.size(): + print( + f'Tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}' + ) + return + print('tensor stats : ', _TensorStats(tensor)) + print('golden tensor stats: ', _TensorStats(golden_tensor)) + +def report_results(results: List[TestResult]): + """Provide a basic error report summarizing various TestResult's.""" + any_failed = False + for result in results: + failed = False + for item_num, (item, golden_item) in enumerate( + zip(result.trace, result.golden_trace)): + assert item.symbol == golden_item.symbol + assert len(item.inputs) == len(golden_item.inputs) + assert len(item.outputs) == len(golden_item.outputs) + for input, golden_input in zip(item.inputs, golden_item.inputs): + assert torch.allclose(input, golden_input) + for output_num, (output, golden_output) in enumerate( + zip(item.outputs, golden_item.outputs)): + # TODO: Refine error message. Things to consider: + # - Very large tensors -- don't spew, but give useful info + # - Smaller tensors / primitives -- want to show exact values + # - Machine parseable format? + if not torch.allclose(output, golden_output): + if not failed: + print('FAILURE "{}"'.format(result.unique_name)) + failed = any_failed = True + print( + f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"' + ) + _print_detailed_tensor_diff(output, golden_output) + if not failed: + print('SUCCESS "{}"'.format(result.unique_name)) diff --git a/frontends/pytorch/test/torchscript_e2e_test/basic.py b/frontends/pytorch/test/torchscript_e2e_test/basic.py index 502e805c3..8610b14be 100644 --- a/frontends/pytorch/test/torchscript_e2e_test/basic.py +++ b/frontends/pytorch/test/torchscript_e2e_test/basic.py @@ -6,7 +6,8 @@ import torch -from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils +from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils +from torch_mlir.torchscript.e2e_test.reporting import report_results from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig @@ -26,6 +27,12 @@ def MmModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 4), tu.rand(4, 4)) +# CHECK: SUCCESS "MmModule_basic2" +@register_test_case(module_factory=lambda: MmModule()) +def MmModule_basic2(module, tu: TestUtils): + module.forward(tu.rand(4, 4), tu.rand(4, 4)) + + def main(): config = TorchScriptTestConfig() results = run_tests(GLOBAL_TEST_REGISTRY, config) diff --git a/frontends/pytorch/test/torchscript_e2e_test/miscompile.py b/frontends/pytorch/test/torchscript_e2e_test/miscompile.py index e3bbc8f22..4344ce356 100644 --- a/frontends/pytorch/test/torchscript_e2e_test/miscompile.py +++ b/frontends/pytorch/test/torchscript_e2e_test/miscompile.py @@ -6,7 +6,8 @@ import torch -from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils +from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils +from torch_mlir.torchscript.e2e_test.reporting import report_results from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig @@ -25,8 +26,11 @@ class MmModule(torch.nn.Module): # TODO: Refine error messages. -# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward" # CHECK: FAILURE "MmModule_basic" +# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward" +# CHECK: tensor stats : min={{.*}}, max={{.*}}, mean={{.*}} +# CHECK: golden tensor stats: min={{.*}}, max={{.*}}, mean={{.*}} +# CHECK-NOT: ALL PASS @register_test_case(module_factory=lambda: MmModule()) def MmModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 4), tu.rand(4, 4))