mirror of https://github.com/llvm/torch-mlir
Add vision models (resnet18 to start).
Also, - improve error reporting of e2e framework.pull/209/head
parent
390d39a96c
commit
8f96901943
|
@ -4,7 +4,8 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results
|
from torch_mlir.torchscript.e2e_test.framework import run_tests
|
||||||
|
from torch_mlir.torchscript.e2e_test.reporting import report_results
|
||||||
from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
|
from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
|
||||||
|
|
||||||
# Available test configs.
|
# Available test configs.
|
||||||
|
@ -20,6 +21,7 @@ from torch_mlir.torchscript.e2e_test.configs import (
|
||||||
# be run from a specific directory.
|
# be run from a specific directory.
|
||||||
# TODO: Find out best practices for python "main" files.
|
# TODO: Find out best practices for python "main" files.
|
||||||
import basic
|
import basic
|
||||||
|
import vision_models
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||||
|
@ -29,7 +31,7 @@ def main():
|
||||||
help='''
|
help='''
|
||||||
Meaning of options:
|
Meaning of options:
|
||||||
"refbackend": run through npcomp's RefBackend.
|
"refbackend": run through npcomp's RefBackend.
|
||||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic).
|
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||||
''')
|
''')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision.models as models
|
||||||
|
|
||||||
|
from torch_mlir.torchscript.e2e_test.framework import TestUtils
|
||||||
|
from torch_mlir.torchscript.e2e_test.registry import register_test_case
|
||||||
|
from torch_mlir.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class Resnet18Module(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
# Reset seed to make model deterministic.
|
||||||
|
torch.manual_seed(0)
|
||||||
|
self.resnet = models.resnet18()
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, 3, -1, -1], torch.float32),
|
||||||
|
])
|
||||||
|
def forward(self, img):
|
||||||
|
return self.resnet.forward(img)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Resnet18Module())
|
||||||
|
def Resnet18Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(1, 3, 224, 224))
|
|
@ -230,31 +230,3 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
|
||||||
trace=trace,
|
trace=trace,
|
||||||
golden_trace=golden_trace))
|
golden_trace=golden_trace))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def report_results(results: List[TestResult]):
|
|
||||||
"""Provide a basic error report summarizing various TestResult's."""
|
|
||||||
for result in results:
|
|
||||||
failed = False
|
|
||||||
for item_num, (item, golden_item) in enumerate(
|
|
||||||
zip(result.trace, result.golden_trace)):
|
|
||||||
assert item.symbol == golden_item.symbol
|
|
||||||
assert len(item.inputs) == len(golden_item.inputs)
|
|
||||||
assert len(item.outputs) == len(golden_item.outputs)
|
|
||||||
for input, golden_input in zip(item.inputs, golden_item.inputs):
|
|
||||||
assert torch.allclose(input, golden_input)
|
|
||||||
for output_num, (output, golden_output) in enumerate(
|
|
||||||
zip(item.outputs, golden_item.outputs)):
|
|
||||||
# TODO: Refine error message. Things to consider:
|
|
||||||
# - Very large tensors -- don't spew, but give useful info
|
|
||||||
# - Smaller tensors / primitives -- want to show exact values
|
|
||||||
# - Machine parseable format?
|
|
||||||
if not torch.allclose(output, golden_output):
|
|
||||||
print(
|
|
||||||
f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"'
|
|
||||||
)
|
|
||||||
failed = True
|
|
||||||
if failed:
|
|
||||||
print('FAILURE "{}"'.format(result.unique_name))
|
|
||||||
else:
|
|
||||||
print('SUCCESS "{}"'.format(result.unique_name))
|
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
"""
|
||||||
|
Utilities for reporting the results of the test framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .framework import TestResult
|
||||||
|
|
||||||
|
class _TensorStats:
|
||||||
|
def __init__(self, tensor):
|
||||||
|
self.min = torch.min(tensor)
|
||||||
|
self.max = torch.max(tensor)
|
||||||
|
self.mean = torch.mean(tensor)
|
||||||
|
def __str__(self):
|
||||||
|
return f'min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
|
||||||
|
|
||||||
|
|
||||||
|
def _print_detailed_tensor_diff(tensor, golden_tensor):
|
||||||
|
if tensor.size() != golden_tensor.size():
|
||||||
|
print(
|
||||||
|
f'Tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
|
||||||
|
)
|
||||||
|
return
|
||||||
|
print('tensor stats : ', _TensorStats(tensor))
|
||||||
|
print('golden tensor stats: ', _TensorStats(golden_tensor))
|
||||||
|
|
||||||
|
def report_results(results: List[TestResult]):
|
||||||
|
"""Provide a basic error report summarizing various TestResult's."""
|
||||||
|
any_failed = False
|
||||||
|
for result in results:
|
||||||
|
failed = False
|
||||||
|
for item_num, (item, golden_item) in enumerate(
|
||||||
|
zip(result.trace, result.golden_trace)):
|
||||||
|
assert item.symbol == golden_item.symbol
|
||||||
|
assert len(item.inputs) == len(golden_item.inputs)
|
||||||
|
assert len(item.outputs) == len(golden_item.outputs)
|
||||||
|
for input, golden_input in zip(item.inputs, golden_item.inputs):
|
||||||
|
assert torch.allclose(input, golden_input)
|
||||||
|
for output_num, (output, golden_output) in enumerate(
|
||||||
|
zip(item.outputs, golden_item.outputs)):
|
||||||
|
# TODO: Refine error message. Things to consider:
|
||||||
|
# - Very large tensors -- don't spew, but give useful info
|
||||||
|
# - Smaller tensors / primitives -- want to show exact values
|
||||||
|
# - Machine parseable format?
|
||||||
|
if not torch.allclose(output, golden_output):
|
||||||
|
if not failed:
|
||||||
|
print('FAILURE "{}"'.format(result.unique_name))
|
||||||
|
failed = any_failed = True
|
||||||
|
print(
|
||||||
|
f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"'
|
||||||
|
)
|
||||||
|
_print_detailed_tensor_diff(output, golden_output)
|
||||||
|
if not failed:
|
||||||
|
print('SUCCESS "{}"'.format(result.unique_name))
|
|
@ -6,7 +6,8 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils
|
from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
|
||||||
|
from torch_mlir.torchscript.e2e_test.reporting import report_results
|
||||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||||
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
|
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
|
||||||
|
|
||||||
|
@ -26,6 +27,12 @@ def MmModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||||
|
|
||||||
|
|
||||||
|
# CHECK: SUCCESS "MmModule_basic2"
|
||||||
|
@register_test_case(module_factory=lambda: MmModule())
|
||||||
|
def MmModule_basic2(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
config = TorchScriptTestConfig()
|
config = TorchScriptTestConfig()
|
||||||
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||||
|
|
|
@ -6,7 +6,8 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils
|
from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
|
||||||
|
from torch_mlir.torchscript.e2e_test.reporting import report_results
|
||||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||||
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
|
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
|
||||||
|
|
||||||
|
@ -25,8 +26,11 @@ class MmModule(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
# TODO: Refine error messages.
|
# TODO: Refine error messages.
|
||||||
# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward"
|
|
||||||
# CHECK: FAILURE "MmModule_basic"
|
# CHECK: FAILURE "MmModule_basic"
|
||||||
|
# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward"
|
||||||
|
# CHECK: tensor stats : min={{.*}}, max={{.*}}, mean={{.*}}
|
||||||
|
# CHECK: golden tensor stats: min={{.*}}, max={{.*}}, mean={{.*}}
|
||||||
|
# CHECK-NOT: ALL PASS
|
||||||
@register_test_case(module_factory=lambda: MmModule())
|
@register_test_case(module_factory=lambda: MmModule())
|
||||||
def MmModule_basic(module, tu: TestUtils):
|
def MmModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||||
|
|
Loading…
Reference in New Issue