Add vision models (resnet18 to start).

Also,
- improve error reporting of e2e framework.
pull/209/head
Sean Silva 2021-04-20 14:45:43 -07:00
parent 390d39a96c
commit 8f96901943
6 changed files with 108 additions and 33 deletions
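
For orientation, here is a minimal sketch (not part of this commit) of how a test-runner script wires together the pieces this commit touches; every name below appears in the diffs that follow. The key change is that report_results now lives in a dedicated reporting module instead of framework.

# Sketch of a runner using the new module layout. Test modules (e.g. basic,
# vision_models) must be imported first so their @register_test_case
# decorators populate GLOBAL_TEST_REGISTRY.
from torch_mlir.torchscript.e2e_test.framework import run_tests
from torch_mlir.torchscript.e2e_test.reporting import report_results
from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig

config = TorchScriptTestConfig()
results = run_tests(GLOBAL_TEST_REGISTRY, config)
report_results(results)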


@@ -4,7 +4,8 @@
 import argparse
-from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results
+from torch_mlir.torchscript.e2e_test.framework import run_tests
+from torch_mlir.torchscript.e2e_test.reporting import report_results
 from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
 # Available test configs.
@@ -20,6 +21,7 @@ from torch_mlir.torchscript.e2e_test.configs import (
 # be run from a specific directory.
 # TODO: Find out best practices for python "main" files.
 import basic
+import vision_models
 def main():
     parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
@@ -29,7 +31,7 @@ def main():
                         help='''
 Meaning of options:
 "refbackend": run through npcomp's RefBackend.
-"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic).
+"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
 "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
 ''')
     args = parser.parse_args()
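
As a rough illustration of how the parsed config string could map to a config object: TorchScriptTestConfig is the only config class visible in this diff, so the helper name and the other branches below are assumptions, not the commit's actual code.

# Hypothetical dispatch sketch; only TorchScriptTestConfig appears in this
# diff, so the remaining configs are left as placeholders (assumptions).
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig

def config_from_name(name: str):
    if name == 'torchscript':
        return TorchScriptTestConfig()
    # 'refbackend' and 'native_torch' would map to their respective config
    # classes, which are not shown in this diff.
    raise ValueError(f'unsupported config: {name}')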


@@ -0,0 +1,30 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
import torchvision.models as models

from torch_mlir.torchscript.e2e_test.framework import TestUtils
from torch_mlir.torchscript.e2e_test.registry import register_test_case
from torch_mlir.torchscript.annotations import annotate_args, export

# ==============================================================================


class Resnet18Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Reset seed to make model deterministic.
        torch.manual_seed(0)
        self.resnet = models.resnet18()

    @export
    @annotate_args([
        None,
        ([-1, 3, -1, -1], torch.float32),
    ])
    def forward(self, img):
        return self.resnet.forward(img)


@register_test_case(module_factory=lambda: Resnet18Module())
def Resnet18Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 3, 224, 224))
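
The ([-1, 3, -1, -1], torch.float32) annotation declares a float32 NCHW input with a fixed channel count of 3, where -1 presumably marks dynamic batch, height, and width. As a quick eager-mode sanity check (an illustration, not part of the commit), the same 1x3x224x224 input can be fed to torchvision's resnet18 directly:

# Standalone check: torchvision's resnet18 accepts a 1x3x224x224 float tensor
# and returns 1000 ImageNet class logits.
import torch
import torchvision.models as models

torch.manual_seed(0)
resnet = models.resnet18()
resnet.eval()
with torch.no_grad():
    logits = resnet(torch.rand(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])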


@@ -230,31 +230,3 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
                        trace=trace,
                        golden_trace=golden_trace))
     return results
-
-
-def report_results(results: List[TestResult]):
-    """Provide a basic error report summarizing various TestResult's."""
-    for result in results:
-        failed = False
-        for item_num, (item, golden_item) in enumerate(
-                zip(result.trace, result.golden_trace)):
-            assert item.symbol == golden_item.symbol
-            assert len(item.inputs) == len(golden_item.inputs)
-            assert len(item.outputs) == len(golden_item.outputs)
-            for input, golden_input in zip(item.inputs, golden_item.inputs):
-                assert torch.allclose(input, golden_input)
-            for output_num, (output, golden_output) in enumerate(
-                    zip(item.outputs, golden_item.outputs)):
-                # TODO: Refine error message. Things to consider:
-                # - Very large tensors -- don't spew, but give useful info
-                # - Smaller tensors / primitives -- want to show exact values
-                # - Machine parseable format?
-                if not torch.allclose(output, golden_output):
-                    print(
-                        f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"'
-                    )
-                    failed = True
-        if failed:
-            print('FAILURE "{}"'.format(result.unique_name))
-        else:
-            print('SUCCESS "{}"'.format(result.unique_name))


@@ -0,0 +1,60 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Utilities for reporting the results of the test framework.
"""

from typing import List

import torch

from .framework import TestResult


class _TensorStats:
    def __init__(self, tensor):
        self.min = torch.min(tensor)
        self.max = torch.max(tensor)
        self.mean = torch.mean(tensor)

    def __str__(self):
        return f'min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'


def _print_detailed_tensor_diff(tensor, golden_tensor):
    if tensor.size() != golden_tensor.size():
        print(
            f'Tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
        )
        return
    print('tensor stats : ', _TensorStats(tensor))
    print('golden tensor stats: ', _TensorStats(golden_tensor))


def report_results(results: List[TestResult]):
    """Provide a basic error report summarizing various TestResult's."""
    any_failed = False
    for result in results:
        failed = False
        for item_num, (item, golden_item) in enumerate(
                zip(result.trace, result.golden_trace)):
            assert item.symbol == golden_item.symbol
            assert len(item.inputs) == len(golden_item.inputs)
            assert len(item.outputs) == len(golden_item.outputs)
            for input, golden_input in zip(item.inputs, golden_item.inputs):
                assert torch.allclose(input, golden_input)
            for output_num, (output, golden_output) in enumerate(
                    zip(item.outputs, golden_item.outputs)):
                # TODO: Refine error message. Things to consider:
                # - Very large tensors -- don't spew, but give useful info
                # - Smaller tensors / primitives -- want to show exact values
                # - Machine parseable format?
                if not torch.allclose(output, golden_output):
                    if not failed:
                        print('FAILURE "{}"'.format(result.unique_name))
                    failed = any_failed = True
                    print(
                        f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"'
                    )
                    _print_detailed_tensor_diff(output, golden_output)
        if not failed:
            print('SUCCESS "{}"'.format(result.unique_name))
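
To see what a single stats line from _print_detailed_tensor_diff looks like, here is a tiny standalone sketch (not part of the commit) reusing the same format specifiers as _TensorStats:

# Illustration of the stats-line format used above.
import torch

t = torch.rand(4, 4)
print(f'min={torch.min(t):+0.4}, max={torch.max(t):+0.4}, mean={torch.mean(t):+0.4f}')
# prints something like: min=+0.01938, max=+0.9688, mean=+0.5123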


@@ -6,7 +6,8 @@
 import torch
-from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils
+from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
+from torch_mlir.torchscript.e2e_test.reporting import report_results
 from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
 from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
@@ -26,6 +27,12 @@ def MmModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 4), tu.rand(4, 4))
+# CHECK: SUCCESS "MmModule_basic2"
+@register_test_case(module_factory=lambda: MmModule())
+def MmModule_basic2(module, tu: TestUtils):
+    module.forward(tu.rand(4, 4), tu.rand(4, 4))
 def main():
     config = TorchScriptTestConfig()
     results = run_tests(GLOBAL_TEST_REGISTRY, config)


@@ -6,7 +6,8 @@
 import torch
-from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils
+from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
+from torch_mlir.torchscript.e2e_test.reporting import report_results
 from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
 from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
@@ -25,8 +26,11 @@ class MmModule(torch.nn.Module):
 # TODO: Refine error messages.
-# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward"
 # CHECK: FAILURE "MmModule_basic"
+# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward"
+# CHECK: tensor stats : min={{.*}}, max={{.*}}, mean={{.*}}
+# CHECK: golden tensor stats: min={{.*}}, max={{.*}}, mean={{.*}}
+# CHECK-NOT: ALL PASS
 @register_test_case(module_factory=lambda: MmModule())
 def MmModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 4), tu.rand(4, 4))