mirror of https://github.com/llvm/torch-mlir
Add end-to-end testing framework for TorchScript.
The E2E tests can be run with ``` npcpy frontends/pytorch/e2e_testing/torchscript/main.py ``` This commit adds a couple items supporting that end, including new sugar for annotations (no more raw use of ClassAnnotator!). Recommended review order: 1. `frontends/pytorch/e2e_testing/torchscript/main.py` for the harness + `basic.py` in that directory for examples of tests. 2. Annotation sugar in `frontends/pytorch/python/torch_mlir/torchscript/annotations.py` and unittest in `frontends/pytorch/test/ivalue_import/annotations/sugar.py` 3. Global test registry / sugar in `frontends/pytorch/python/torch_mlir/torchscript/e2e_test/registry.py` 4. `frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py` for the meat of the testing framework (start at `run_tests`), and looking at the backend configs in `frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs` for examples of backends. This is likely the bulk of review time. 5. Unit tests of the framework logic in `frontends/pytorch/test/torchscript_e2e_test` There's TODO's scattered throughout, but this seems functional enough to start pulling stuff into and kicking the tires. A few missing pieces: 1. Marking test expected pass/fail per backend. 2. Figuring out how best to fit this into dev workflows. 3. IREE TestConfig. Also, forgive this Python newbie... Any advice on Python code structure / library design would be much appreciated.pull/208/head
parent
fef1733e12
commit
39d50ccf0d
|
@ -0,0 +1,69 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir.torchscript.e2e_test.framework import TestUtils
|
||||||
|
from torch_mlir.torchscript.e2e_test.registry import register_test_case
|
||||||
|
from torch_mlir.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class MmModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32),
|
||||||
|
([-1, -1], torch.float32),
|
||||||
|
])
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
return torch.mm(lhs, rhs)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: MmModule())
|
||||||
|
def MmModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: MmModule())
|
||||||
|
def MmModule_chained(module, tu: TestUtils):
|
||||||
|
res = module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||||
|
module.forward(res, res)
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class TanhModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([2, 3, -1], torch.float32),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.tanh(x)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TanhModule())
|
||||||
|
def TanhModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 3, 1))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class MmTanhModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32),
|
||||||
|
([-1, -1], torch.float32),
|
||||||
|
])
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
return torch.tanh(self.matmul(lhs, rhs))
|
||||||
|
def matmul(self, lhs, rhs):
|
||||||
|
return torch.mm(lhs, rhs)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: MmTanhModule())
|
||||||
|
def MmTanhModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 2), tu.rand(2, 4))
|
|
@ -0,0 +1,40 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results
|
||||||
|
from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
|
||||||
|
|
||||||
|
# Available test configs.
|
||||||
|
from torch_mlir.torchscript.e2e_test.configs import (
|
||||||
|
RefBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import tests to register them in the global registry.
|
||||||
|
import basic
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||||
|
parser.add_argument('--config',
|
||||||
|
choices=['native_torch', 'torchscript', 'refbackend'],
|
||||||
|
default='refbackend',
|
||||||
|
help='''
|
||||||
|
Meaning of options:
|
||||||
|
"refbackend": run through npcomp's RefBackend.
|
||||||
|
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic).
|
||||||
|
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||||
|
''')
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.config == 'refbackend':
|
||||||
|
config = RefBackendTestConfig()
|
||||||
|
elif args.config == 'native_torch':
|
||||||
|
config = NativeTorchTestConfig()
|
||||||
|
elif args.config == 'torchscript':
|
||||||
|
config = TorchScriptTestConfig()
|
||||||
|
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||||
|
report_results(results)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
|
@ -0,0 +1,95 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
# Decorators
|
||||||
|
|
||||||
|
# Currently, these decorators are very low-level and map 1:1 with
|
||||||
|
# methods on `torch_mlir.ClassAnnotator`. Eventually, we expect there to
|
||||||
|
# be a more elaborate Python layer which allows all the different annotations
|
||||||
|
# to be expressed conveniently and gives clearer error reports when
|
||||||
|
# the annotations aren't acceptable.
|
||||||
|
|
||||||
|
|
||||||
|
def export(fn):
|
||||||
|
"""Decorator that tells the npcomp compiler that a method is exported.
|
||||||
|
|
||||||
|
By default, no methods are exported, which is very important for
|
||||||
|
the compiler, because otherwise most Torch programs consist of a sea
|
||||||
|
of tiny exported functions with no rank or dtype information
|
||||||
|
(see `annotate_args`), which the compiler cannot do much with.
|
||||||
|
|
||||||
|
Note that this is different from `torch.jit.export`, which controls
|
||||||
|
which methods are scripted in the first place. For non-`forward` methods,
|
||||||
|
using this decorator usually means you also need `torch.jit.export`.
|
||||||
|
Conceptually, this decorator is annotating the scripted module, but is
|
||||||
|
applied to the original `torch.nn.Module` for convenience.
|
||||||
|
"""
|
||||||
|
fn._npcomp_export = True
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
ArgAnnotation = Tuple[List[int], torch.dtype]
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Replace with py3 extended argument annotations when available.
|
||||||
|
# See https://www.python.org/dev/peps/pep-0593/
|
||||||
|
def annotate_args(annotations: List[Optional[ArgAnnotation]]):
|
||||||
|
"""Decorator that tells the npcomp compiler information about arguments.
|
||||||
|
|
||||||
|
The `annotations` should be a list of the same length as the number of
|
||||||
|
argument to the method (including `self`). Each list entry is either:
|
||||||
|
- None, corresponding to providing the compiler with no information.
|
||||||
|
- A 2-tuple consisting of a shape and a dtype, such as
|
||||||
|
`([2, 3, 4], torch.float32)`. A dimension with an unknown size can be
|
||||||
|
indicated by using `-1` as the size. This provides the compiler a
|
||||||
|
guarantee that the argument will always dynamically have the described
|
||||||
|
shape and dtype.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: Check the number of arguments matches the number of arg annotations.
|
||||||
|
def decorator(fn):
|
||||||
|
fn._npcomp_arg_annotations = annotations
|
||||||
|
return fn
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
# Utilities for extracting decorated information into torch_mlir.ClassAnnotator.
|
||||||
|
|
||||||
|
|
||||||
|
def _recursively_extract_annotations(
|
||||||
|
module: torch.nn.Module, scripted: torch.jit.ScriptModule,
|
||||||
|
class_annotator: torch_mlir.ClassAnnotator):
|
||||||
|
assert module.__class__.__name__ == scripted.original_name, "script module does not come from specified module"
|
||||||
|
|
||||||
|
# Extract information on methods.
|
||||||
|
for method_name, scripted_method in scripted.__dict__.items():
|
||||||
|
if not isinstance(scripted_method, torch.ScriptMethod):
|
||||||
|
continue
|
||||||
|
method = getattr(module, method_name)
|
||||||
|
if hasattr(method, '_npcomp_export'):
|
||||||
|
class_annotator.exportPath(scripted._c._type(), [method_name])
|
||||||
|
if hasattr(method, '_npcomp_arg_annotations'):
|
||||||
|
class_annotator.annotateShapesAndDtypes(
|
||||||
|
scripted._c._type(), [method_name],
|
||||||
|
method._npcomp_arg_annotations)
|
||||||
|
# Recurse.
|
||||||
|
for name, child in module.named_children():
|
||||||
|
scripted_child = getattr(scripted, name)
|
||||||
|
_recursively_extract_annotations(child, scripted_child,
|
||||||
|
class_annotator)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_annotations(program: torch.nn.Module,
|
||||||
|
scripted: torch.jit.ScriptModule,
|
||||||
|
class_annotator: torch_mlir.ClassAnnotator):
|
||||||
|
"""Populate the ClassAnnotator with annotations extracted from `program`."""
|
||||||
|
class_annotator.exportNone(scripted._c._type())
|
||||||
|
_recursively_extract_annotations(program, scripted, class_annotator)
|
|
@ -0,0 +1,7 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
from .ref_backend import RefBackendTestConfig
|
||||||
|
from .native_torch import NativeTorchTestConfig
|
||||||
|
from .torchscript import TorchScriptTestConfig
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import copy
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
|
|
||||||
|
|
||||||
|
class NativeTorchTestConfig(TestConfig):
|
||||||
|
"""TestConfig that just runs the torch.nn.Module without compiling"""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
|
||||||
|
return program
|
||||||
|
|
||||||
|
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
||||||
|
# TODO: Deepcopy the torch.nn.Module, so that if the program is
|
||||||
|
# stateful then it does not mutate the original compiled program.
|
||||||
|
result: Trace = []
|
||||||
|
for item in trace:
|
||||||
|
outputs = getattr(artifact, item.symbol)(*item.inputs)
|
||||||
|
if isinstance(outputs, torch.Tensor):
|
||||||
|
outputs = [outputs]
|
||||||
|
result.append(
|
||||||
|
TraceItem(symbol=item.symbol,
|
||||||
|
inputs=item.inputs,
|
||||||
|
outputs=outputs))
|
||||||
|
return result
|
|
@ -0,0 +1,47 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import torch_mlir
|
||||||
|
from npcomp.compiler.pytorch.backend import refjit, frontend_lowering
|
||||||
|
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
|
from torch_mlir.torchscript.annotations import extract_annotations
|
||||||
|
|
||||||
|
|
||||||
|
class RefBackendTestConfig(TestConfig):
|
||||||
|
"""TestConfig that just runs the torch.nn.Module through RefBackend."""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.backend = refjit.CompilerBackend()
|
||||||
|
|
||||||
|
def compile(self, program: torch.nn.Module) -> Any:
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
scripted = torch.jit.script(program)
|
||||||
|
class_annotator = torch_mlir.ClassAnnotator()
|
||||||
|
|
||||||
|
extract_annotations(program, scripted, class_annotator)
|
||||||
|
|
||||||
|
mb.import_module(scripted._c, class_annotator)
|
||||||
|
# Lower module in place.
|
||||||
|
frontend_lowering.lower_object_graph(mb.module)
|
||||||
|
return self.backend.compile(mb.module)
|
||||||
|
|
||||||
|
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||||
|
jit_module = self.backend.load(artifact)
|
||||||
|
result: Trace = []
|
||||||
|
for item in trace:
|
||||||
|
numpy_inputs = [t.numpy() for t in item.inputs]
|
||||||
|
outputs = getattr(jit_module, item.symbol)(*numpy_inputs)
|
||||||
|
if isinstance(outputs, np.ndarray):
|
||||||
|
outputs = [outputs]
|
||||||
|
torch_outputs = [torch.tensor(ndarray) for ndarray in outputs]
|
||||||
|
result.append(
|
||||||
|
TraceItem(symbol=item.symbol,
|
||||||
|
inputs=item.inputs,
|
||||||
|
outputs=torch_outputs))
|
||||||
|
return result
|
|
@ -0,0 +1,34 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import copy
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
|
|
||||||
|
|
||||||
|
class TorchScriptTestConfig(TestConfig):
|
||||||
|
"""TestConfig that runs the torch.nn.Module through TorchScript"""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def compile(self, program: torch.nn.Module) -> torch.jit.ScriptModule:
|
||||||
|
return torch.jit.script(program)
|
||||||
|
|
||||||
|
def run(self, artifact: torch.jit.ScriptModule, trace: Trace) -> Trace:
|
||||||
|
# TODO: Deepcopy the torch.jit.ScriptModule, so that if the program is
|
||||||
|
# stateful then it does not mutate the original compiled program.
|
||||||
|
|
||||||
|
result: Trace = []
|
||||||
|
for item in trace:
|
||||||
|
outputs = getattr(artifact, item.symbol)(*item.inputs)
|
||||||
|
if isinstance(outputs, torch.Tensor):
|
||||||
|
outputs = [outputs]
|
||||||
|
result.append(
|
||||||
|
TraceItem(symbol=item.symbol,
|
||||||
|
inputs=item.inputs,
|
||||||
|
outputs=outputs))
|
||||||
|
return result
|
|
@ -0,0 +1,260 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
"""
|
||||||
|
# End-to-end testing framework for TorchScript.
|
||||||
|
|
||||||
|
For the purposes of this framework, "end to end" means the first "end" is
|
||||||
|
a `torch.nn.Module`, and the second "end" is execution.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
A program for this testing framework is considered to be a `torch.nn.Module`,
|
||||||
|
which has a public interface consisting of its methods and instance attributes.
|
||||||
|
|
||||||
|
A test in the framework consists conceputally of a list of calls into
|
||||||
|
the methods of a module (TODO: extend to instance attributes). It is expected
|
||||||
|
that the outputs match between the program run on a backend (controlled by
|
||||||
|
a TestConfig) and a golden trace obtained by running on native Torch (without
|
||||||
|
compiling or TorchScript'ing).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import abc
|
||||||
|
from typing import Any, Callable, List, NamedTuple, TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class TraceItem(NamedTuple):
|
||||||
|
# The externally visible symbol name that is called.
|
||||||
|
# For example `"forward"` or `"submodule.forward"`.
|
||||||
|
symbol: str
|
||||||
|
# The list of inputs to the call.
|
||||||
|
inputs: List[torch.Tensor] # TODO: Support more types.
|
||||||
|
# The outputs from the call.
|
||||||
|
# Sometimes this field is treated as golden outputs from a test.
|
||||||
|
# Sometimes this field is treated as ignored, such as the input trace
|
||||||
|
# provided to `TestConfig.run`.
|
||||||
|
outputs: List[torch.Tensor] # TODO: Support more types
|
||||||
|
|
||||||
|
|
||||||
|
# A trace of invocations to the program.
|
||||||
|
# This is an ordered sequence of external invocations to a program's
|
||||||
|
# public boundary.
|
||||||
|
Trace = List[TraceItem]
|
||||||
|
|
||||||
|
# A type shared between the result of `TestConfig.compile` and the input
|
||||||
|
# to `TestConfig.run`. Each backend will likely have a different definition of
|
||||||
|
# this type.
|
||||||
|
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfig(abc.ABC):
|
||||||
|
"""The interface implemented by backends to run tests.
|
||||||
|
|
||||||
|
The testing framework expects to be able to call `compile` to compile
|
||||||
|
a torch.nn.Module, and then pass the compiled artifact to `run` to run it.
|
||||||
|
|
||||||
|
Note that the definition of "compiled artifact" here is quite loose, and
|
||||||
|
this interface allows for many different use cases besides simple testing.
|
||||||
|
|
||||||
|
For example, this interface can be overridden to be a "data collector"
|
||||||
|
to gather information across all the test cases. For example,
|
||||||
|
a compiler backend could override "compile" to just return some IR at a
|
||||||
|
useful intermediate abstraction level (rather than the final compiled
|
||||||
|
artifact), and then have "run" save this intermediate IR + the trace as
|
||||||
|
input to some lower-level software stack's testing format.
|
||||||
|
|
||||||
|
The set of TestConfig's is expected to be pluggable and provided by
|
||||||
|
users to suit their own needs. We provide a few configs out of the box
|
||||||
|
in the `configs` submodule of this package, but those are intended
|
||||||
|
to be for basic inspiration and enough for our own testing.
|
||||||
|
Backends to npcomp will likely have more elaborate TestConfig's, such
|
||||||
|
as `compile` being "compile for such-and-such DSP with these vectorization
|
||||||
|
cost model flags" and `run` being "connect to Android phone with
|
||||||
|
device ID 1234 and upload a program to run on it's DSP core, and also set
|
||||||
|
power throttling settings to 'performance'".
|
||||||
|
|
||||||
|
That is also why this class is not called "backend", as it
|
||||||
|
encapsulates potentially many specific details of the test configuration
|
||||||
|
process as well. There isn't a general way to disentangle test configuration
|
||||||
|
from the compile/run process specific to a logical backend, since each
|
||||||
|
backend (compiler backend and runtime target) will have an arbitrarily
|
||||||
|
wild and wonderful set of possible configurations that we cannot predict.
|
||||||
|
"""
|
||||||
|
# This is not a frontend-lowered module, to allow various testing at the PyTorch level.
|
||||||
|
# We can have a helper class NpcompBackendTestConfig which does that.
|
||||||
|
@abc.abstractmethod
|
||||||
|
def compile(self, program: torch.nn.Module) -> CompiledArtifact:
|
||||||
|
"""Compile the provided torch.nn.Module into a compiled artifact"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Any should match result of `compile`.
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def run(self, artifact: CompiledArtifact, trace: Trace) -> Trace:
|
||||||
|
"""Run the compiled artifact produced by `compile`.
|
||||||
|
|
||||||
|
The backend should load the compiled artifact and call the
|
||||||
|
symbol names listed in `trace` with their respective inputs (the outputs
|
||||||
|
of `trace` should be ignored). A new identical trace with outputs
|
||||||
|
populated should be returned.
|
||||||
|
|
||||||
|
This method should assume that `artifact` is being shared with
|
||||||
|
multiple parallel invocations of `run`, and so it should not be mutated.
|
||||||
|
This property is typicaly trivially satisfied for a true
|
||||||
|
"compiled artifact", but some backends don't directly involve a
|
||||||
|
compiled artifact per se (like a backend for which `CompiledArtifact` is
|
||||||
|
`torch.nn.Module` and `run` just invokes the torch.nn.Module itself)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
artifact: a compiled artifact produced by `compile`.
|
||||||
|
trace: The external invocations to stimulate the module.
|
||||||
|
Returns:
|
||||||
|
A trace with outputs recorded according to the results of running
|
||||||
|
on this backend.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Utilities for common testing trace generation.
|
||||||
|
# Also, resets the random seed for reproducibility.
|
||||||
|
# TODO: If generating in parallel, how to have manual_seed be local?
|
||||||
|
class TestUtils:
|
||||||
|
"""Utilities for executing a test.
|
||||||
|
|
||||||
|
Test cases are provided an instance of this class to make test cases
|
||||||
|
more succinct.
|
||||||
|
|
||||||
|
For reproducibility, this class also resets the random seed.
|
||||||
|
TODO: Figure out how to seed reset properly scoped to just a test case
|
||||||
|
(such as when running tests in parallel)
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
# TODO: Add zeros/ones/etc. as convenient.
|
||||||
|
def rand(self, *sizes):
|
||||||
|
if len(sizes) == 0:
|
||||||
|
return torch.rand([])
|
||||||
|
return torch.rand(*sizes)
|
||||||
|
|
||||||
|
|
||||||
|
class Test(NamedTuple):
|
||||||
|
"""A description of a test as produced by the test frontend.
|
||||||
|
"""
|
||||||
|
# Stable name for error reporting.
|
||||||
|
#
|
||||||
|
# This name's stability is also useful for backend, which want to
|
||||||
|
# generate their own lower-level test suites based on this framework.
|
||||||
|
#
|
||||||
|
# It is expected that those backends will need additional
|
||||||
|
# metadata to describe their test configurations, so having a unique
|
||||||
|
# key to keep that information associated is important.
|
||||||
|
unique_name: str
|
||||||
|
# A callable which produces the module under test.
|
||||||
|
# This is a callable to allow lazily creating the module.
|
||||||
|
program_factory: Callable[[], torch.nn.Module]
|
||||||
|
# A callable which provides external stimuli to the module.
|
||||||
|
# The first parameter is a torch.nn.Module (or a `_Tracer` wrapping that
|
||||||
|
# module, actually).
|
||||||
|
# The secon parameter is a `TestUtils` instance for convenience.
|
||||||
|
program_invoker: Callable[[Any, TestUtils], None]
|
||||||
|
|
||||||
|
|
||||||
|
class TestResult(NamedTuple):
|
||||||
|
# Stable unique name for error reporting and test suite configuration.
|
||||||
|
#
|
||||||
|
# Tests frequently need some additional data (such as expected pass/fail
|
||||||
|
# status, desired test configurations, etc.), and this gives a key to
|
||||||
|
# associate to. This avoids extending this class arbitrarily for every
|
||||||
|
# possible requirement from the test framework.
|
||||||
|
#
|
||||||
|
# This name is also useful for backends that are generating their own
|
||||||
|
# lower-level test suites from this framework for the same reasons, though
|
||||||
|
# those reasons are stronger because we cannot simply extend this
|
||||||
|
# class.
|
||||||
|
unique_name: str # Should match Test.unique_name for corresponding test.
|
||||||
|
# The trace produced by the backend.
|
||||||
|
trace: Trace
|
||||||
|
# The golden trace which `trace` is expected to match.
|
||||||
|
golden_trace: Trace
|
||||||
|
|
||||||
|
|
||||||
|
class _Tracer:
|
||||||
|
"""Wrapper around a `torch.nn.Module` that records calls into it.
|
||||||
|
|
||||||
|
The inputs and outputs of each call are recorded in a Trace.
|
||||||
|
"""
|
||||||
|
module: torch.nn.Module
|
||||||
|
trace: Trace
|
||||||
|
|
||||||
|
def __init__(self, module: torch.nn.Module):
|
||||||
|
self.module = module
|
||||||
|
self.trace = []
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
# TODO: Handle `module.foo.bar.baz` nesting.
|
||||||
|
# For now, we are limited to attributes of the top-level module.
|
||||||
|
def invoke(*args):
|
||||||
|
raw_outputs = getattr(self.module, name)(*args)
|
||||||
|
if isinstance(raw_outputs, torch.Tensor):
|
||||||
|
outputs = [raw_outputs]
|
||||||
|
self.trace.append(
|
||||||
|
TraceItem(symbol=name, inputs=args, outputs=outputs))
|
||||||
|
return raw_outputs
|
||||||
|
|
||||||
|
return invoke
|
||||||
|
|
||||||
|
def get_trace(self):
|
||||||
|
return self.trace
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_golden_trace(test: Test) -> Trace:
|
||||||
|
tracer = _Tracer(test.program_factory())
|
||||||
|
test.program_invoker(tracer, TestUtils())
|
||||||
|
return tracer.get_trace()
|
||||||
|
|
||||||
|
|
||||||
|
def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
|
||||||
|
"""Invoke the given `Test`'s with the provided `TestConfig`."""
|
||||||
|
results = []
|
||||||
|
for test in tests:
|
||||||
|
golden_trace = _generate_golden_trace(test)
|
||||||
|
# TODO: Precompile everything in parallel.
|
||||||
|
compiled = config.compile(test.program_factory())
|
||||||
|
# TODO: Run in parallel.
|
||||||
|
trace = config.run(compiled, golden_trace)
|
||||||
|
results.append(
|
||||||
|
TestResult(unique_name=test.unique_name,
|
||||||
|
trace=trace,
|
||||||
|
golden_trace=golden_trace))
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def report_results(results: List[TestResult]):
|
||||||
|
"""Provide a basic error report summarizing various TestResult's."""
|
||||||
|
for result in results:
|
||||||
|
failed = False
|
||||||
|
for item_num, (item, golden_item) in enumerate(
|
||||||
|
zip(result.trace, result.golden_trace)):
|
||||||
|
assert item.symbol == golden_item.symbol
|
||||||
|
assert len(item.inputs) == len(golden_item.inputs)
|
||||||
|
assert len(item.outputs) == len(golden_item.outputs)
|
||||||
|
for input, golden_input in zip(item.inputs, golden_item.inputs):
|
||||||
|
assert torch.allclose(input, golden_input)
|
||||||
|
for output_num, (output, golden_output) in enumerate(
|
||||||
|
zip(item.outputs, golden_item.outputs)):
|
||||||
|
# TODO: Refine error message. Things to consider:
|
||||||
|
# - Very large tensors -- don't spew, but give useful info
|
||||||
|
# - Smaller tensors / primitives -- want to show exact values
|
||||||
|
# - Machine parseable format?
|
||||||
|
if not torch.allclose(output, golden_output):
|
||||||
|
print(
|
||||||
|
f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"'
|
||||||
|
)
|
||||||
|
failed = True
|
||||||
|
if failed:
|
||||||
|
print('FAILURE "{}"'.format(result.unique_name))
|
||||||
|
else:
|
||||||
|
print('SUCCESS "{}"'.format(result.unique_name))
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .framework import Test
|
||||||
|
|
||||||
|
# The global registry of tests.
|
||||||
|
GLOBAL_TEST_REGISTRY = []
|
||||||
|
|
||||||
|
|
||||||
|
def register_test_case(module_factory: Callable[[], torch.nn.Module]):
|
||||||
|
"""Convenient decorator-based test registration.
|
||||||
|
|
||||||
|
Adds a `framework.Test` to the global test registry based on the decorated
|
||||||
|
function. The test's `unique_name` is taken from the function name, the
|
||||||
|
test's `program_factory` is taken from `module_factory`, and the
|
||||||
|
`program_invoker` is the decorated function.
|
||||||
|
"""
|
||||||
|
def decorator(f):
|
||||||
|
GLOBAL_TEST_REGISTRY.append(
|
||||||
|
Test(unique_name=f.__name__,
|
||||||
|
program_factory=module_factory,
|
||||||
|
program_invoker=f))
|
||||||
|
return f
|
||||||
|
|
||||||
|
return decorator
|
|
@ -0,0 +1,50 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import torch_mlir
|
||||||
|
from torch_mlir.torchscript.annotations import (
|
||||||
|
annotate_args, export, extract_annotations
|
||||||
|
)
|
||||||
|
|
||||||
|
class MmModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 4], torch.float32),
|
||||||
|
([4, 5], torch.float32),
|
||||||
|
])
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
return torch.mm(lhs, rhs)
|
||||||
|
|
||||||
|
module = MmModule()
|
||||||
|
annotator = torch_mlir.ClassAnnotator()
|
||||||
|
extract_annotations(module, torch.jit.script(module), annotator)
|
||||||
|
print(annotator)
|
||||||
|
|
||||||
|
# CHECK: ClassAnnotator {
|
||||||
|
# CHECK: ClassAnnotation('__torch__.MmModule') {
|
||||||
|
# CHECK: MethodAnnotation('forward') {
|
||||||
|
# CHECK: isExported = true
|
||||||
|
# CHECK: argAnnotations =
|
||||||
|
# CHECK: ArgAnnotation(0) {
|
||||||
|
# CHECK: dtype = <none>
|
||||||
|
# CHECK: shape = <none>
|
||||||
|
# CHECK: }
|
||||||
|
# CHECK: ArgAnnotation(1) {
|
||||||
|
# CHECK: dtype = Float
|
||||||
|
# CHECK: shape = [3, 4]
|
||||||
|
# CHECK: }
|
||||||
|
# CHECK: ArgAnnotation(2) {
|
||||||
|
# CHECK: dtype = Float
|
||||||
|
# CHECK: shape = [4, 5]
|
||||||
|
# CHECK: }
|
||||||
|
# CHECK: }
|
||||||
|
# CHECK: }
|
||||||
|
# CHECK: }
|
|
@ -0,0 +1,2 @@
|
||||||
|
This directory is for testing the e2e_test framework itself.
|
||||||
|
It is not for holding e2e tests themselves!!!
|
|
@ -0,0 +1,36 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils
|
||||||
|
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||||
|
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MmModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
return torch.mm(lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Refine messages.
|
||||||
|
# CHECK: SUCCESS "MmModule_basic"
|
||||||
|
@register_test_case(module_factory=lambda: MmModule())
|
||||||
|
def MmModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = TorchScriptTestConfig()
|
||||||
|
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||||
|
report_results(results)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
|
@ -0,0 +1,42 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils
|
||||||
|
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||||
|
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MmModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
# Use torch.jit.is_scripting() to fake a miscompile.
|
||||||
|
# The non-scripted code will take one path, and the scripted code
|
||||||
|
# will take another path.
|
||||||
|
if torch.jit.is_scripting():
|
||||||
|
return torch.mm(rhs, lhs)
|
||||||
|
return torch.mm(lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Refine error messages.
|
||||||
|
# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward"
|
||||||
|
# CHECK: FAILURE "MmModule_basic"
|
||||||
|
@register_test_case(module_factory=lambda: MmModule())
|
||||||
|
def MmModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = TorchScriptTestConfig()
|
||||||
|
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||||
|
report_results(results)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue