mirror of https://github.com/llvm/torch-mlir
Add end-to-end testing framework for TorchScript.
The E2E tests can be run with ``` npcpy frontends/pytorch/e2e_testing/torchscript/main.py ``` This commit adds a couple items supporting that end, including new sugar for annotations (no more raw use of ClassAnnotator!). Recommended review order: 1. `frontends/pytorch/e2e_testing/torchscript/main.py` for the harness + `basic.py` in that directory for examples of tests. 2. Annotation sugar in `frontends/pytorch/python/torch_mlir/torchscript/annotations.py` and unittest in `frontends/pytorch/test/ivalue_import/annotations/sugar.py` 3. Global test registry / sugar in `frontends/pytorch/python/torch_mlir/torchscript/e2e_test/registry.py` 4. `frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py` for the meat of the testing framework (start at `run_tests`), and looking at the backend configs in `frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs` for examples of backends. This is likely the bulk of review time. 5. Unit tests of the framework logic in `frontends/pytorch/test/torchscript_e2e_test` There's TODO's scattered throughout, but this seems functional enough to start pulling stuff into and kicking the tires. A few missing pieces: 1. Marking test expected pass/fail per backend. 2. Figuring out how best to fit this into dev workflows. 3. IREE TestConfig. Also, forgive this Python newbie... Any advice on Python code structure / library design would be much appreciated.pull/208/head
parent
fef1733e12
commit
39d50ccf0d
|
@ -0,0 +1,69 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir.torchscript.e2e_test.framework import TestUtils
|
||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case
|
||||
from torch_mlir.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MmModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32),
|
||||
([-1, -1], torch.float32),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.mm(lhs, rhs)
|
||||
|
||||
@register_test_case(module_factory=lambda: MmModule())
|
||||
def MmModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||
|
||||
@register_test_case(module_factory=lambda: MmModule())
|
||||
def MmModule_chained(module, tu: TestUtils):
|
||||
res = module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||
module.forward(res, res)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TanhModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 3, -1], torch.float32),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.tanh(x)
|
||||
|
||||
@register_test_case(module_factory=lambda: TanhModule())
|
||||
def TanhModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 1))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MmTanhModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32),
|
||||
([-1, -1], torch.float32),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.tanh(self.matmul(lhs, rhs))
|
||||
def matmul(self, lhs, rhs):
|
||||
return torch.mm(lhs, rhs)
|
||||
|
||||
@register_test_case(module_factory=lambda: MmTanhModule())
|
||||
def MmTanhModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 2), tu.rand(2, 4))
|
|
@ -0,0 +1,40 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import argparse
|
||||
|
||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results
|
||||
from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
|
||||
|
||||
# Available test configs.
|
||||
from torch_mlir.torchscript.e2e_test.configs import (
|
||||
RefBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
|
||||
)
|
||||
|
||||
# Import tests to register them in the global registry.
|
||||
import basic
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||
parser.add_argument('--config',
|
||||
choices=['native_torch', 'torchscript', 'refbackend'],
|
||||
default='refbackend',
|
||||
help='''
|
||||
Meaning of options:
|
||||
"refbackend": run through npcomp's RefBackend.
|
||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic).
|
||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||
''')
|
||||
args = parser.parse_args()
|
||||
if args.config == 'refbackend':
|
||||
config = RefBackendTestConfig()
|
||||
elif args.config == 'native_torch':
|
||||
config = NativeTorchTestConfig()
|
||||
elif args.config == 'torchscript':
|
||||
config = TorchScriptTestConfig()
|
||||
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||
report_results(results)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,95 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import torch_mlir
|
||||
|
||||
# Decorators
|
||||
|
||||
# Currently, these decorators are very low-level and map 1:1 with
|
||||
# methods on `torch_mlir.ClassAnnotator`. Eventually, we expect there to
|
||||
# be a more elaborate Python layer which allows all the different annotations
|
||||
# to be expressed conveniently and gives clearer error reports when
|
||||
# the annotations aren't acceptable.
|
||||
|
||||
|
||||
def export(fn):
|
||||
"""Decorator that tells the npcomp compiler that a method is exported.
|
||||
|
||||
By default, no methods are exported, which is very important for
|
||||
the compiler, because otherwise most Torch programs consist of a sea
|
||||
of tiny exported functions with no rank or dtype information
|
||||
(see `annotate_args`), which the compiler cannot do much with.
|
||||
|
||||
Note that this is different from `torch.jit.export`, which controls
|
||||
which methods are scripted in the first place. For non-`forward` methods,
|
||||
using this decorator usually means you also need `torch.jit.export`.
|
||||
Conceptually, this decorator is annotating the scripted module, but is
|
||||
applied to the original `torch.nn.Module` for convenience.
|
||||
"""
|
||||
fn._npcomp_export = True
|
||||
return fn
|
||||
|
||||
|
||||
ArgAnnotation = Tuple[List[int], torch.dtype]
|
||||
|
||||
|
||||
# TODO: Replace with py3 extended argument annotations when available.
|
||||
# See https://www.python.org/dev/peps/pep-0593/
|
||||
def annotate_args(annotations: List[Optional[ArgAnnotation]]):
|
||||
"""Decorator that tells the npcomp compiler information about arguments.
|
||||
|
||||
The `annotations` should be a list of the same length as the number of
|
||||
argument to the method (including `self`). Each list entry is either:
|
||||
- None, corresponding to providing the compiler with no information.
|
||||
- A 2-tuple consisting of a shape and a dtype, such as
|
||||
`([2, 3, 4], torch.float32)`. A dimension with an unknown size can be
|
||||
indicated by using `-1` as the size. This provides the compiler a
|
||||
guarantee that the argument will always dynamically have the described
|
||||
shape and dtype.
|
||||
"""
|
||||
|
||||
# TODO: Check the number of arguments matches the number of arg annotations.
|
||||
def decorator(fn):
|
||||
fn._npcomp_arg_annotations = annotations
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# Utilities for extracting decorated information into torch_mlir.ClassAnnotator.
|
||||
|
||||
|
||||
def _recursively_extract_annotations(
|
||||
module: torch.nn.Module, scripted: torch.jit.ScriptModule,
|
||||
class_annotator: torch_mlir.ClassAnnotator):
|
||||
assert module.__class__.__name__ == scripted.original_name, "script module does not come from specified module"
|
||||
|
||||
# Extract information on methods.
|
||||
for method_name, scripted_method in scripted.__dict__.items():
|
||||
if not isinstance(scripted_method, torch.ScriptMethod):
|
||||
continue
|
||||
method = getattr(module, method_name)
|
||||
if hasattr(method, '_npcomp_export'):
|
||||
class_annotator.exportPath(scripted._c._type(), [method_name])
|
||||
if hasattr(method, '_npcomp_arg_annotations'):
|
||||
class_annotator.annotateShapesAndDtypes(
|
||||
scripted._c._type(), [method_name],
|
||||
method._npcomp_arg_annotations)
|
||||
# Recurse.
|
||||
for name, child in module.named_children():
|
||||
scripted_child = getattr(scripted, name)
|
||||
_recursively_extract_annotations(child, scripted_child,
|
||||
class_annotator)
|
||||
|
||||
|
||||
def extract_annotations(program: torch.nn.Module,
|
||||
scripted: torch.jit.ScriptModule,
|
||||
class_annotator: torch_mlir.ClassAnnotator):
|
||||
"""Populate the ClassAnnotator with annotations extracted from `program`."""
|
||||
class_annotator.exportNone(scripted._c._type())
|
||||
_recursively_extract_annotations(program, scripted, class_annotator)
|
|
@ -0,0 +1,7 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .ref_backend import RefBackendTestConfig
|
||||
from .native_torch import NativeTorchTestConfig
|
||||
from .torchscript import TorchScriptTestConfig
|
|
@ -0,0 +1,33 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
|
||||
|
||||
class NativeTorchTestConfig(TestConfig):
|
||||
"""TestConfig that just runs the torch.nn.Module without compiling"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def compile(self, program: torch.nn.Module) -> torch.nn.Module:
|
||||
return program
|
||||
|
||||
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
||||
# TODO: Deepcopy the torch.nn.Module, so that if the program is
|
||||
# stateful then it does not mutate the original compiled program.
|
||||
result: Trace = []
|
||||
for item in trace:
|
||||
outputs = getattr(artifact, item.symbol)(*item.inputs)
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = [outputs]
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
outputs=outputs))
|
||||
return result
|
|
@ -0,0 +1,47 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import torch_mlir
|
||||
from npcomp.compiler.pytorch.backend import refjit, frontend_lowering
|
||||
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
from torch_mlir.torchscript.annotations import extract_annotations
|
||||
|
||||
|
||||
class RefBackendTestConfig(TestConfig):
|
||||
"""TestConfig that just runs the torch.nn.Module through RefBackend."""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.backend = refjit.CompilerBackend()
|
||||
|
||||
def compile(self, program: torch.nn.Module) -> Any:
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
scripted = torch.jit.script(program)
|
||||
class_annotator = torch_mlir.ClassAnnotator()
|
||||
|
||||
extract_annotations(program, scripted, class_annotator)
|
||||
|
||||
mb.import_module(scripted._c, class_annotator)
|
||||
# Lower module in place.
|
||||
frontend_lowering.lower_object_graph(mb.module)
|
||||
return self.backend.compile(mb.module)
|
||||
|
||||
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||
jit_module = self.backend.load(artifact)
|
||||
result: Trace = []
|
||||
for item in trace:
|
||||
numpy_inputs = [t.numpy() for t in item.inputs]
|
||||
outputs = getattr(jit_module, item.symbol)(*numpy_inputs)
|
||||
if isinstance(outputs, np.ndarray):
|
||||
outputs = [outputs]
|
||||
torch_outputs = [torch.tensor(ndarray) for ndarray in outputs]
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
outputs=torch_outputs))
|
||||
return result
|
|
@ -0,0 +1,34 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
|
||||
|
||||
class TorchScriptTestConfig(TestConfig):
|
||||
"""TestConfig that runs the torch.nn.Module through TorchScript"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def compile(self, program: torch.nn.Module) -> torch.jit.ScriptModule:
|
||||
return torch.jit.script(program)
|
||||
|
||||
def run(self, artifact: torch.jit.ScriptModule, trace: Trace) -> Trace:
|
||||
# TODO: Deepcopy the torch.jit.ScriptModule, so that if the program is
|
||||
# stateful then it does not mutate the original compiled program.
|
||||
|
||||
result: Trace = []
|
||||
for item in trace:
|
||||
outputs = getattr(artifact, item.symbol)(*item.inputs)
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = [outputs]
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
outputs=outputs))
|
||||
return result
|
|
@ -0,0 +1,260 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
"""
|
||||
# End-to-end testing framework for TorchScript.
|
||||
|
||||
For the purposes of this framework, "end to end" means the first "end" is
|
||||
a `torch.nn.Module`, and the second "end" is execution.
|
||||
|
||||
## Architecture
|
||||
|
||||
A program for this testing framework is considered to be a `torch.nn.Module`,
|
||||
which has a public interface consisting of its methods and instance attributes.
|
||||
|
||||
A test in the framework consists conceputally of a list of calls into
|
||||
the methods of a module (TODO: extend to instance attributes). It is expected
|
||||
that the outputs match between the program run on a backend (controlled by
|
||||
a TestConfig) and a golden trace obtained by running on native Torch (without
|
||||
compiling or TorchScript'ing).
|
||||
"""
|
||||
|
||||
import abc
|
||||
from typing import Any, Callable, List, NamedTuple, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TraceItem(NamedTuple):
|
||||
# The externally visible symbol name that is called.
|
||||
# For example `"forward"` or `"submodule.forward"`.
|
||||
symbol: str
|
||||
# The list of inputs to the call.
|
||||
inputs: List[torch.Tensor] # TODO: Support more types.
|
||||
# The outputs from the call.
|
||||
# Sometimes this field is treated as golden outputs from a test.
|
||||
# Sometimes this field is treated as ignored, such as the input trace
|
||||
# provided to `TestConfig.run`.
|
||||
outputs: List[torch.Tensor] # TODO: Support more types
|
||||
|
||||
|
||||
# A trace of invocations to the program.
|
||||
# This is an ordered sequence of external invocations to a program's
|
||||
# public boundary.
|
||||
Trace = List[TraceItem]
|
||||
|
||||
# A type shared between the result of `TestConfig.compile` and the input
|
||||
# to `TestConfig.run`. Each backend will likely have a different definition of
|
||||
# this type.
|
||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||
|
||||
|
||||
class TestConfig(abc.ABC):
|
||||
"""The interface implemented by backends to run tests.
|
||||
|
||||
The testing framework expects to be able to call `compile` to compile
|
||||
a torch.nn.Module, and then pass the compiled artifact to `run` to run it.
|
||||
|
||||
Note that the definition of "compiled artifact" here is quite loose, and
|
||||
this interface allows for many different use cases besides simple testing.
|
||||
|
||||
For example, this interface can be overridden to be a "data collector"
|
||||
to gather information across all the test cases. For example,
|
||||
a compiler backend could override "compile" to just return some IR at a
|
||||
useful intermediate abstraction level (rather than the final compiled
|
||||
artifact), and then have "run" save this intermediate IR + the trace as
|
||||
input to some lower-level software stack's testing format.
|
||||
|
||||
The set of TestConfig's is expected to be pluggable and provided by
|
||||
users to suit their own needs. We provide a few configs out of the box
|
||||
in the `configs` submodule of this package, but those are intended
|
||||
to be for basic inspiration and enough for our own testing.
|
||||
Backends to npcomp will likely have more elaborate TestConfig's, such
|
||||
as `compile` being "compile for such-and-such DSP with these vectorization
|
||||
cost model flags" and `run` being "connect to Android phone with
|
||||
device ID 1234 and upload a program to run on it's DSP core, and also set
|
||||
power throttling settings to 'performance'".
|
||||
|
||||
That is also why this class is not called "backend", as it
|
||||
encapsulates potentially many specific details of the test configuration
|
||||
process as well. There isn't a general way to disentangle test configuration
|
||||
from the compile/run process specific to a logical backend, since each
|
||||
backend (compiler backend and runtime target) will have an arbitrarily
|
||||
wild and wonderful set of possible configurations that we cannot predict.
|
||||
"""
|
||||
# This is not a frontend-lowered module, to allow various testing at the PyTorch level.
|
||||
# We can have a helper class NpcompBackendTestConfig which does that.
|
||||
@abc.abstractmethod
|
||||
def compile(self, program: torch.nn.Module) -> CompiledArtifact:
|
||||
"""Compile the provided torch.nn.Module into a compiled artifact"""
|
||||
pass
|
||||
|
||||
# Any should match result of `compile`.
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self, artifact: CompiledArtifact, trace: Trace) -> Trace:
|
||||
"""Run the compiled artifact produced by `compile`.
|
||||
|
||||
The backend should load the compiled artifact and call the
|
||||
symbol names listed in `trace` with their respective inputs (the outputs
|
||||
of `trace` should be ignored). A new identical trace with outputs
|
||||
populated should be returned.
|
||||
|
||||
This method should assume that `artifact` is being shared with
|
||||
multiple parallel invocations of `run`, and so it should not be mutated.
|
||||
This property is typicaly trivially satisfied for a true
|
||||
"compiled artifact", but some backends don't directly involve a
|
||||
compiled artifact per se (like a backend for which `CompiledArtifact` is
|
||||
`torch.nn.Module` and `run` just invokes the torch.nn.Module itself)
|
||||
|
||||
Args:
|
||||
artifact: a compiled artifact produced by `compile`.
|
||||
trace: The external invocations to stimulate the module.
|
||||
Returns:
|
||||
A trace with outputs recorded according to the results of running
|
||||
on this backend.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# Utilities for common testing trace generation.
|
||||
# Also, resets the random seed for reproducibility.
|
||||
# TODO: If generating in parallel, how to have manual_seed be local?
|
||||
class TestUtils:
|
||||
"""Utilities for executing a test.
|
||||
|
||||
Test cases are provided an instance of this class to make test cases
|
||||
more succinct.
|
||||
|
||||
For reproducibility, this class also resets the random seed.
|
||||
TODO: Figure out how to seed reset properly scoped to just a test case
|
||||
(such as when running tests in parallel)
|
||||
"""
|
||||
def __init__(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
# TODO: Add zeros/ones/etc. as convenient.
|
||||
def rand(self, *sizes):
|
||||
if len(sizes) == 0:
|
||||
return torch.rand([])
|
||||
return torch.rand(*sizes)
|
||||
|
||||
|
||||
class Test(NamedTuple):
|
||||
"""A description of a test as produced by the test frontend.
|
||||
"""
|
||||
# Stable name for error reporting.
|
||||
#
|
||||
# This name's stability is also useful for backend, which want to
|
||||
# generate their own lower-level test suites based on this framework.
|
||||
#
|
||||
# It is expected that those backends will need additional
|
||||
# metadata to describe their test configurations, so having a unique
|
||||
# key to keep that information associated is important.
|
||||
unique_name: str
|
||||
# A callable which produces the module under test.
|
||||
# This is a callable to allow lazily creating the module.
|
||||
program_factory: Callable[[], torch.nn.Module]
|
||||
# A callable which provides external stimuli to the module.
|
||||
# The first parameter is a torch.nn.Module (or a `_Tracer` wrapping that
|
||||
# module, actually).
|
||||
# The secon parameter is a `TestUtils` instance for convenience.
|
||||
program_invoker: Callable[[Any, TestUtils], None]
|
||||
|
||||
|
||||
class TestResult(NamedTuple):
|
||||
# Stable unique name for error reporting and test suite configuration.
|
||||
#
|
||||
# Tests frequently need some additional data (such as expected pass/fail
|
||||
# status, desired test configurations, etc.), and this gives a key to
|
||||
# associate to. This avoids extending this class arbitrarily for every
|
||||
# possible requirement from the test framework.
|
||||
#
|
||||
# This name is also useful for backends that are generating their own
|
||||
# lower-level test suites from this framework for the same reasons, though
|
||||
# those reasons are stronger because we cannot simply extend this
|
||||
# class.
|
||||
unique_name: str # Should match Test.unique_name for corresponding test.
|
||||
# The trace produced by the backend.
|
||||
trace: Trace
|
||||
# The golden trace which `trace` is expected to match.
|
||||
golden_trace: Trace
|
||||
|
||||
|
||||
class _Tracer:
|
||||
"""Wrapper around a `torch.nn.Module` that records calls into it.
|
||||
|
||||
The inputs and outputs of each call are recorded in a Trace.
|
||||
"""
|
||||
module: torch.nn.Module
|
||||
trace: Trace
|
||||
|
||||
def __init__(self, module: torch.nn.Module):
|
||||
self.module = module
|
||||
self.trace = []
|
||||
|
||||
def __getattr__(self, name):
|
||||
# TODO: Handle `module.foo.bar.baz` nesting.
|
||||
# For now, we are limited to attributes of the top-level module.
|
||||
def invoke(*args):
|
||||
raw_outputs = getattr(self.module, name)(*args)
|
||||
if isinstance(raw_outputs, torch.Tensor):
|
||||
outputs = [raw_outputs]
|
||||
self.trace.append(
|
||||
TraceItem(symbol=name, inputs=args, outputs=outputs))
|
||||
return raw_outputs
|
||||
|
||||
return invoke
|
||||
|
||||
def get_trace(self):
|
||||
return self.trace
|
||||
|
||||
|
||||
def _generate_golden_trace(test: Test) -> Trace:
|
||||
tracer = _Tracer(test.program_factory())
|
||||
test.program_invoker(tracer, TestUtils())
|
||||
return tracer.get_trace()
|
||||
|
||||
|
||||
def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
|
||||
"""Invoke the given `Test`'s with the provided `TestConfig`."""
|
||||
results = []
|
||||
for test in tests:
|
||||
golden_trace = _generate_golden_trace(test)
|
||||
# TODO: Precompile everything in parallel.
|
||||
compiled = config.compile(test.program_factory())
|
||||
# TODO: Run in parallel.
|
||||
trace = config.run(compiled, golden_trace)
|
||||
results.append(
|
||||
TestResult(unique_name=test.unique_name,
|
||||
trace=trace,
|
||||
golden_trace=golden_trace))
|
||||
return results
|
||||
|
||||
|
||||
def report_results(results: List[TestResult]):
|
||||
"""Provide a basic error report summarizing various TestResult's."""
|
||||
for result in results:
|
||||
failed = False
|
||||
for item_num, (item, golden_item) in enumerate(
|
||||
zip(result.trace, result.golden_trace)):
|
||||
assert item.symbol == golden_item.symbol
|
||||
assert len(item.inputs) == len(golden_item.inputs)
|
||||
assert len(item.outputs) == len(golden_item.outputs)
|
||||
for input, golden_input in zip(item.inputs, golden_item.inputs):
|
||||
assert torch.allclose(input, golden_input)
|
||||
for output_num, (output, golden_output) in enumerate(
|
||||
zip(item.outputs, golden_item.outputs)):
|
||||
# TODO: Refine error message. Things to consider:
|
||||
# - Very large tensors -- don't spew, but give useful info
|
||||
# - Smaller tensors / primitives -- want to show exact values
|
||||
# - Machine parseable format?
|
||||
if not torch.allclose(output, golden_output):
|
||||
print(
|
||||
f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"'
|
||||
)
|
||||
failed = True
|
||||
if failed:
|
||||
print('FAILURE "{}"'.format(result.unique_name))
|
||||
else:
|
||||
print('SUCCESS "{}"'.format(result.unique_name))
|
|
@ -0,0 +1,30 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from .framework import Test
|
||||
|
||||
# The global registry of tests.
|
||||
GLOBAL_TEST_REGISTRY = []
|
||||
|
||||
|
||||
def register_test_case(module_factory: Callable[[], torch.nn.Module]):
|
||||
"""Convenient decorator-based test registration.
|
||||
|
||||
Adds a `framework.Test` to the global test registry based on the decorated
|
||||
function. The test's `unique_name` is taken from the function name, the
|
||||
test's `program_factory` is taken from `module_factory`, and the
|
||||
`program_invoker` is the decorated function.
|
||||
"""
|
||||
def decorator(f):
|
||||
GLOBAL_TEST_REGISTRY.append(
|
||||
Test(unique_name=f.__name__,
|
||||
program_factory=module_factory,
|
||||
program_invoker=f))
|
||||
return f
|
||||
|
||||
return decorator
|
|
@ -0,0 +1,50 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import torch
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.torchscript.annotations import (
|
||||
annotate_args, export, extract_annotations
|
||||
)
|
||||
|
||||
class MmModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4], torch.float32),
|
||||
([4, 5], torch.float32),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.mm(lhs, rhs)
|
||||
|
||||
module = MmModule()
|
||||
annotator = torch_mlir.ClassAnnotator()
|
||||
extract_annotations(module, torch.jit.script(module), annotator)
|
||||
print(annotator)
|
||||
|
||||
# CHECK: ClassAnnotator {
|
||||
# CHECK: ClassAnnotation('__torch__.MmModule') {
|
||||
# CHECK: MethodAnnotation('forward') {
|
||||
# CHECK: isExported = true
|
||||
# CHECK: argAnnotations =
|
||||
# CHECK: ArgAnnotation(0) {
|
||||
# CHECK: dtype = <none>
|
||||
# CHECK: shape = <none>
|
||||
# CHECK: }
|
||||
# CHECK: ArgAnnotation(1) {
|
||||
# CHECK: dtype = Float
|
||||
# CHECK: shape = [3, 4]
|
||||
# CHECK: }
|
||||
# CHECK: ArgAnnotation(2) {
|
||||
# CHECK: dtype = Float
|
||||
# CHECK: shape = [4, 5]
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
# CHECK: }
|
|
@ -0,0 +1,2 @@
|
|||
This directory is for testing the e2e_test framework itself.
|
||||
It is not for holding e2e tests themselves!!!
|
|
@ -0,0 +1,36 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils
|
||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
|
||||
|
||||
|
||||
class MmModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.mm(lhs, rhs)
|
||||
|
||||
|
||||
# TODO: Refine messages.
|
||||
# CHECK: SUCCESS "MmModule_basic"
|
||||
@register_test_case(module_factory=lambda: MmModule())
|
||||
def MmModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||
|
||||
|
||||
def main():
|
||||
config = TorchScriptTestConfig()
|
||||
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||
report_results(results)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,42 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils
|
||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
|
||||
|
||||
|
||||
class MmModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, lhs, rhs):
|
||||
# Use torch.jit.is_scripting() to fake a miscompile.
|
||||
# The non-scripted code will take one path, and the scripted code
|
||||
# will take another path.
|
||||
if torch.jit.is_scripting():
|
||||
return torch.mm(rhs, lhs)
|
||||
return torch.mm(lhs, rhs)
|
||||
|
||||
|
||||
# TODO: Refine error messages.
|
||||
# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward"
|
||||
# CHECK: FAILURE "MmModule_basic"
|
||||
@register_test_case(module_factory=lambda: MmModule())
|
||||
def MmModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||
|
||||
|
||||
def main():
|
||||
config = TorchScriptTestConfig()
|
||||
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||
report_results(results)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue