mirror of https://github.com/llvm/torch-mlir
Add vision models (resnet18 to start).
Also, - improve error reporting of e2e framework.pull/209/head
parent
390d39a96c
commit
8f96901943
|
@ -4,7 +4,8 @@
|
|||
|
||||
import argparse
|
||||
|
||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results
|
||||
from torch_mlir.torchscript.e2e_test.framework import run_tests
|
||||
from torch_mlir.torchscript.e2e_test.reporting import report_results
|
||||
from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
|
||||
|
||||
# Available test configs.
|
||||
|
@ -20,6 +21,7 @@ from torch_mlir.torchscript.e2e_test.configs import (
|
|||
# be run from a specific directory.
|
||||
# TODO: Find out best practices for python "main" files.
|
||||
import basic
|
||||
import vision_models
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||
|
@ -29,7 +31,7 @@ def main():
|
|||
help='''
|
||||
Meaning of options:
|
||||
"refbackend": run through npcomp's RefBackend.
|
||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic).
|
||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||
''')
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
|
||||
from torch_mlir.torchscript.e2e_test.framework import TestUtils
|
||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case
|
||||
from torch_mlir.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class Resnet18Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Reset seed to make model deterministic.
|
||||
torch.manual_seed(0)
|
||||
self.resnet = models.resnet18()
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, 3, -1, -1], torch.float32),
|
||||
])
|
||||
def forward(self, img):
|
||||
return self.resnet.forward(img)
|
||||
|
||||
@register_test_case(module_factory=lambda: Resnet18Module())
|
||||
def Resnet18Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 3, 224, 224))
|
|
@ -230,31 +230,3 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
|
|||
trace=trace,
|
||||
golden_trace=golden_trace))
|
||||
return results
|
||||
|
||||
|
||||
def report_results(results: List[TestResult]):
|
||||
"""Provide a basic error report summarizing various TestResult's."""
|
||||
for result in results:
|
||||
failed = False
|
||||
for item_num, (item, golden_item) in enumerate(
|
||||
zip(result.trace, result.golden_trace)):
|
||||
assert item.symbol == golden_item.symbol
|
||||
assert len(item.inputs) == len(golden_item.inputs)
|
||||
assert len(item.outputs) == len(golden_item.outputs)
|
||||
for input, golden_input in zip(item.inputs, golden_item.inputs):
|
||||
assert torch.allclose(input, golden_input)
|
||||
for output_num, (output, golden_output) in enumerate(
|
||||
zip(item.outputs, golden_item.outputs)):
|
||||
# TODO: Refine error message. Things to consider:
|
||||
# - Very large tensors -- don't spew, but give useful info
|
||||
# - Smaller tensors / primitives -- want to show exact values
|
||||
# - Machine parseable format?
|
||||
if not torch.allclose(output, golden_output):
|
||||
print(
|
||||
f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"'
|
||||
)
|
||||
failed = True
|
||||
if failed:
|
||||
print('FAILURE "{}"'.format(result.unique_name))
|
||||
else:
|
||||
print('SUCCESS "{}"'.format(result.unique_name))
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
"""
|
||||
Utilities for reporting the results of the test framework.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from .framework import TestResult
|
||||
|
||||
class _TensorStats:
|
||||
def __init__(self, tensor):
|
||||
self.min = torch.min(tensor)
|
||||
self.max = torch.max(tensor)
|
||||
self.mean = torch.mean(tensor)
|
||||
def __str__(self):
|
||||
return f'min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
|
||||
|
||||
|
||||
def _print_detailed_tensor_diff(tensor, golden_tensor):
|
||||
if tensor.size() != golden_tensor.size():
|
||||
print(
|
||||
f'Tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
|
||||
)
|
||||
return
|
||||
print('tensor stats : ', _TensorStats(tensor))
|
||||
print('golden tensor stats: ', _TensorStats(golden_tensor))
|
||||
|
||||
def report_results(results: List[TestResult]):
|
||||
"""Provide a basic error report summarizing various TestResult's."""
|
||||
any_failed = False
|
||||
for result in results:
|
||||
failed = False
|
||||
for item_num, (item, golden_item) in enumerate(
|
||||
zip(result.trace, result.golden_trace)):
|
||||
assert item.symbol == golden_item.symbol
|
||||
assert len(item.inputs) == len(golden_item.inputs)
|
||||
assert len(item.outputs) == len(golden_item.outputs)
|
||||
for input, golden_input in zip(item.inputs, golden_item.inputs):
|
||||
assert torch.allclose(input, golden_input)
|
||||
for output_num, (output, golden_output) in enumerate(
|
||||
zip(item.outputs, golden_item.outputs)):
|
||||
# TODO: Refine error message. Things to consider:
|
||||
# - Very large tensors -- don't spew, but give useful info
|
||||
# - Smaller tensors / primitives -- want to show exact values
|
||||
# - Machine parseable format?
|
||||
if not torch.allclose(output, golden_output):
|
||||
if not failed:
|
||||
print('FAILURE "{}"'.format(result.unique_name))
|
||||
failed = any_failed = True
|
||||
print(
|
||||
f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"'
|
||||
)
|
||||
_print_detailed_tensor_diff(output, golden_output)
|
||||
if not failed:
|
||||
print('SUCCESS "{}"'.format(result.unique_name))
|
|
@ -6,7 +6,8 @@
|
|||
|
||||
import torch
|
||||
|
||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils
|
||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
|
||||
from torch_mlir.torchscript.e2e_test.reporting import report_results
|
||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
|
||||
|
||||
|
@ -26,6 +27,12 @@ def MmModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||
|
||||
|
||||
# CHECK: SUCCESS "MmModule_basic2"
|
||||
@register_test_case(module_factory=lambda: MmModule())
|
||||
def MmModule_basic2(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||
|
||||
|
||||
def main():
|
||||
config = TorchScriptTestConfig()
|
||||
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||
|
|
|
@ -6,7 +6,8 @@
|
|||
|
||||
import torch
|
||||
|
||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils
|
||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
|
||||
from torch_mlir.torchscript.e2e_test.reporting import report_results
|
||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
|
||||
|
||||
|
@ -25,8 +26,11 @@ class MmModule(torch.nn.Module):
|
|||
|
||||
|
||||
# TODO: Refine error messages.
|
||||
# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward"
|
||||
# CHECK: FAILURE "MmModule_basic"
|
||||
# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward"
|
||||
# CHECK: tensor stats : min={{.*}}, max={{.*}}, mean={{.*}}
|
||||
# CHECK: golden tensor stats: min={{.*}}, max={{.*}}, mean={{.*}}
|
||||
# CHECK-NOT: ALL PASS
|
||||
@register_test_case(module_factory=lambda: MmModule())
|
||||
def MmModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||
|
|
Loading…
Reference in New Issue