Add end-to-end testing framework for TorchScript.

The E2E tests can be run with
```
npcpy frontends/pytorch/e2e_testing/torchscript/main.py
```

This commit adds a couple of items supporting that end, including new sugar
for annotations (no more raw use of ClassAnnotator!).
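
For reference, here is what the sugar looks like on a module (a minimal
snippet mirroring the `basic.py` examples below; `None` annotates `self`,
and `-1` marks a dimension of unknown size):

```python
import torch

from torch_mlir.torchscript.annotations import annotate_args, export


class MmModule(torch.nn.Module):
    # `@export` marks the method as visible to the npcomp compiler.
    # `@annotate_args` supplies per-argument shape/dtype guarantees.
    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32),
        ([-1, -1], torch.float32),
    ])
    def forward(self, lhs, rhs):
        return torch.mm(lhs, rhs)
```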

Recommended review order:

1. `frontends/pytorch/e2e_testing/torchscript/main.py` for
   the harness + `basic.py` in that directory for examples of tests.
2. Annotation sugar in `frontends/pytorch/python/torch_mlir/torchscript/annotations.py`
   and unittest in `frontends/pytorch/test/ivalue_import/annotations/sugar.py`
3. Global test registry / sugar in
   `frontends/pytorch/python/torch_mlir/torchscript/e2e_test/registry.py`
4. `frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py`
   for the meat of the testing framework (start at `run_tests`), and
   the backend configs in
   `frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs`
   for examples of backends. This is likely the bulk of review time.
   (A minimal usage sketch of the framework follows this list.)
5. Unit tests of the framework logic in `frontends/pytorch/test/torchscript_e2e_test`
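
As orientation before diving in, here is a minimal sketch of how the pieces
fit together, using a hypothetical `AddModule` (all APIs shown are from the
files below):

```python
import torch

from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig


class AddModule(torch.nn.Module):
    def forward(self, lhs, rhs):
        return lhs + rhs


# Registers "AddModule_basic" in the global registry; the decorated function
# provides the stimuli, from which a golden trace is recorded on native Torch.
@register_test_case(module_factory=lambda: AddModule())
def AddModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 4), tu.rand(4, 4))


# Each registered test is compiled and run through the chosen TestConfig,
# and the resulting trace is compared against the golden trace.
report_results(run_tests(GLOBAL_TEST_REGISTRY, TorchScriptTestConfig()))
```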

There are TODOs scattered throughout, but this seems functional enough to
start pulling tests into and kicking the tires. A few missing pieces:

1. Marking test expected pass/fail per backend.
2. Figuring out how best to fit this into dev workflows.
3. IREE TestConfig.

Also, forgive this Python newbie... Any advice on Python code structure
/ library design would be much appreciated.

@@ -0,0 +1,69 @@
```python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch

from torch_mlir.torchscript.e2e_test.framework import TestUtils
from torch_mlir.torchscript.e2e_test.registry import register_test_case
from torch_mlir.torchscript.annotations import annotate_args, export

# ==============================================================================


class MmModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32),
        ([-1, -1], torch.float32),
    ])
    def forward(self, lhs, rhs):
        return torch.mm(lhs, rhs)


@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 4), tu.rand(4, 4))


@register_test_case(module_factory=lambda: MmModule())
def MmModule_chained(module, tu: TestUtils):
    res = module.forward(tu.rand(4, 4), tu.rand(4, 4))
    module.forward(res, res)

# ==============================================================================


class TanhModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 3, -1], torch.float32),
    ])
    def forward(self, x):
        return torch.tanh(x)


@register_test_case(module_factory=lambda: TanhModule())
def TanhModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 1))

# ==============================================================================


class MmTanhModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32),
        ([-1, -1], torch.float32),
    ])
    def forward(self, lhs, rhs):
        return torch.tanh(self.matmul(lhs, rhs))

    def matmul(self, lhs, rhs):
        return torch.mm(lhs, rhs)


@register_test_case(module_factory=lambda: MmTanhModule())
def MmTanhModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 2), tu.rand(2, 4))
```

@@ -0,0 +1,40 @@
```python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse

from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results
from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY

# Available test configs.
from torch_mlir.torchscript.e2e_test.configs import (
    RefBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
)

# Import tests to register them in the global registry.
import basic


def main():
    parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
    parser.add_argument('--config',
                        choices=['native_torch', 'torchscript', 'refbackend'],
                        default='refbackend',
                        help='''
Meaning of options:
"refbackend": run through npcomp's RefBackend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying the model is deterministic).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
''')
    args = parser.parse_args()
    if args.config == 'refbackend':
        config = RefBackendTestConfig()
    elif args.config == 'native_torch':
        config = NativeTorchTestConfig()
    elif args.config == 'torchscript':
        config = TorchScriptTestConfig()
    results = run_tests(GLOBAL_TEST_REGISTRY, config)
    report_results(results)


if __name__ == '__main__':
    main()
```
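
For example, assuming the same `npcpy` wrapper shown at the top, a specific
config can be selected from the command line:

```
npcpy frontends/pytorch/e2e_testing/torchscript/main.py --config=torchscript
```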

@@ -0,0 +1,95 @@
```python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import List, Optional, Tuple

import torch

import torch_mlir

# Decorators

# Currently, these decorators are very low-level and map 1:1 with
# methods on `torch_mlir.ClassAnnotator`. Eventually, we expect there to
# be a more elaborate Python layer which allows all the different annotations
# to be expressed conveniently and gives clearer error reports when
# the annotations aren't acceptable.


def export(fn):
    """Decorator that tells the npcomp compiler that a method is exported.

    By default, no methods are exported, which is very important for
    the compiler, because otherwise most Torch programs consist of a sea
    of tiny exported functions with no rank or dtype information
    (see `annotate_args`), which the compiler cannot do much with.

    Note that this is different from `torch.jit.export`, which controls
    which methods are scripted in the first place. For non-`forward` methods,
    using this decorator usually means you also need `torch.jit.export`.
    Conceptually, this decorator is annotating the scripted module, but is
    applied to the original `torch.nn.Module` for convenience.
    """
    fn._npcomp_export = True
    return fn


ArgAnnotation = Tuple[List[int], torch.dtype]


# TODO: Replace with py3 extended argument annotations when available.
# See https://www.python.org/dev/peps/pep-0593/
def annotate_args(annotations: List[Optional[ArgAnnotation]]):
    """Decorator that tells the npcomp compiler information about arguments.

    The `annotations` should be a list of the same length as the number of
    arguments to the method (including `self`). Each list entry is either:
    - None, corresponding to providing the compiler with no information.
    - A 2-tuple consisting of a shape and a dtype, such as
      `([2, 3, 4], torch.float32)`. A dimension with an unknown size can be
      indicated by using `-1` as the size. This provides the compiler a
      guarantee that the argument will always dynamically have the described
      shape and dtype.
    """

    # TODO: Check that the number of arguments matches the number of arg annotations.
    def decorator(fn):
        fn._npcomp_arg_annotations = annotations
        return fn

    return decorator


# Utilities for extracting decorated information into torch_mlir.ClassAnnotator.


def _recursively_extract_annotations(
        module: torch.nn.Module, scripted: torch.jit.ScriptModule,
        class_annotator: torch_mlir.ClassAnnotator):
    assert module.__class__.__name__ == scripted.original_name, \
        "script module does not come from specified module"
    # Extract information on methods.
    for method_name, scripted_method in scripted.__dict__.items():
        if not isinstance(scripted_method, torch.ScriptMethod):
            continue
        method = getattr(module, method_name)
        if hasattr(method, '_npcomp_export'):
            class_annotator.exportPath(scripted._c._type(), [method_name])
        if hasattr(method, '_npcomp_arg_annotations'):
            class_annotator.annotateShapesAndDtypes(
                scripted._c._type(), [method_name],
                method._npcomp_arg_annotations)
    # Recurse.
    for name, child in module.named_children():
        scripted_child = getattr(scripted, name)
        _recursively_extract_annotations(child, scripted_child,
                                         class_annotator)


def extract_annotations(program: torch.nn.Module,
                        scripted: torch.jit.ScriptModule,
                        class_annotator: torch_mlir.ClassAnnotator):
    """Populate the ClassAnnotator with annotations extracted from `program`."""
    class_annotator.exportNone(scripted._c._type())
    _recursively_extract_annotations(program, scripted, class_annotator)
```

@@ -0,0 +1,7 @@
```python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .ref_backend import RefBackendTestConfig
from .native_torch import NativeTorchTestConfig
from .torchscript import TorchScriptTestConfig
```

@@ -0,0 +1,33 @@
```python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import copy
from typing import Any

import torch

from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem


class NativeTorchTestConfig(TestConfig):
    """TestConfig that just runs the torch.nn.Module without compiling."""
    def __init__(self):
        super().__init__()

    def compile(self, program: torch.nn.Module) -> torch.nn.Module:
        return program

    def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
        # TODO: Deepcopy the torch.nn.Module, so that if the program is
        # stateful then it does not mutate the original compiled program.
        result: Trace = []
        for item in trace:
            outputs = getattr(artifact, item.symbol)(*item.inputs)
            if isinstance(outputs, torch.Tensor):
                outputs = [outputs]
            result.append(
                TraceItem(symbol=item.symbol,
                          inputs=item.inputs,
                          outputs=outputs))
        return result
```

@@ -0,0 +1,47 @@
```python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Any

import numpy as np
import torch

import torch_mlir
from npcomp.compiler.pytorch.backend import refjit, frontend_lowering
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
from torch_mlir.torchscript.annotations import extract_annotations


class RefBackendTestConfig(TestConfig):
    """TestConfig that just runs the torch.nn.Module through RefBackend."""
    def __init__(self):
        super().__init__()
        self.backend = refjit.CompilerBackend()

    def compile(self, program: torch.nn.Module) -> Any:
        mb = torch_mlir.ModuleBuilder()
        scripted = torch.jit.script(program)
        class_annotator = torch_mlir.ClassAnnotator()
        extract_annotations(program, scripted, class_annotator)
        mb.import_module(scripted._c, class_annotator)
        # Lower the module in place.
        frontend_lowering.lower_object_graph(mb.module)
        return self.backend.compile(mb.module)

    def run(self, artifact: Any, trace: Trace) -> Trace:
        jit_module = self.backend.load(artifact)
        result: Trace = []
        for item in trace:
            numpy_inputs = [t.numpy() for t in item.inputs]
            outputs = getattr(jit_module, item.symbol)(*numpy_inputs)
            if isinstance(outputs, np.ndarray):
                outputs = [outputs]
            torch_outputs = [torch.tensor(ndarray) for ndarray in outputs]
            result.append(
                TraceItem(symbol=item.symbol,
                          inputs=item.inputs,
                          outputs=torch_outputs))
        return result
```

@@ -0,0 +1,34 @@
```python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import copy
from typing import Any

import torch

from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem


class TorchScriptTestConfig(TestConfig):
    """TestConfig that runs the torch.nn.Module through TorchScript."""
    def __init__(self):
        super().__init__()

    def compile(self, program: torch.nn.Module) -> torch.jit.ScriptModule:
        return torch.jit.script(program)

    def run(self, artifact: torch.jit.ScriptModule, trace: Trace) -> Trace:
        # TODO: Deepcopy the torch.jit.ScriptModule, so that if the program is
        # stateful then it does not mutate the original compiled program.
        result: Trace = []
        for item in trace:
            outputs = getattr(artifact, item.symbol)(*item.inputs)
            if isinstance(outputs, torch.Tensor):
                outputs = [outputs]
            result.append(
                TraceItem(symbol=item.symbol,
                          inputs=item.inputs,
                          outputs=outputs))
        return result
```

@@ -0,0 +1,260 @@
```python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
# End-to-end testing framework for TorchScript.

For the purposes of this framework, "end to end" means the first "end" is
a `torch.nn.Module`, and the second "end" is execution.

## Architecture

A program for this testing framework is considered to be a `torch.nn.Module`,
which has a public interface consisting of its methods and instance attributes.

A test in the framework consists conceptually of a list of calls into
the methods of a module (TODO: extend to instance attributes). It is expected
that the outputs match between the program run on a backend (controlled by
a TestConfig) and a golden trace obtained by running on native Torch (without
compiling or TorchScript'ing).
"""

import abc
from typing import Any, Callable, List, NamedTuple, TypeVar

import torch


class TraceItem(NamedTuple):
    # The externally visible symbol name that is called.
    # For example `"forward"` or `"submodule.forward"`.
    symbol: str
    # The list of inputs to the call.
    inputs: List[torch.Tensor]  # TODO: Support more types.
    # The outputs from the call.
    # Sometimes this field is treated as golden outputs from a test.
    # Sometimes this field is treated as ignored, such as the input trace
    # provided to `TestConfig.run`.
    outputs: List[torch.Tensor]  # TODO: Support more types.


# A trace of invocations to the program.
# This is an ordered sequence of external invocations to a program's
# public boundary.
Trace = List[TraceItem]

# A type shared between the result of `TestConfig.compile` and the input
# to `TestConfig.run`. Each backend will likely have a different definition of
# this type.
CompiledArtifact = TypeVar('CompiledArtifact')


class TestConfig(abc.ABC):
    """The interface implemented by backends to run tests.

    The testing framework expects to be able to call `compile` to compile
    a torch.nn.Module, and then pass the compiled artifact to `run` to run it.

    Note that the definition of "compiled artifact" here is quite loose, and
    this interface allows for many different use cases besides simple testing.
    For example, this interface can be overridden to be a "data collector"
    to gather information across all the test cases. For example,
    a compiler backend could override "compile" to just return some IR at a
    useful intermediate abstraction level (rather than the final compiled
    artifact), and then have "run" save this intermediate IR + the trace as
    input to some lower-level software stack's testing format.

    The set of TestConfig's is expected to be pluggable and provided by
    users to suit their own needs. We provide a few configs out of the box
    in the `configs` submodule of this package, but those are intended
    to be for basic inspiration and enough for our own testing.
    Backends to npcomp will likely have more elaborate TestConfig's, such
    as `compile` being "compile for such-and-such DSP with these vectorization
    cost model flags" and `run` being "connect to Android phone with
    device ID 1234 and upload a program to run on its DSP core, and also set
    power throttling settings to 'performance'".

    That is also why this class is not called "backend", as it
    encapsulates potentially many specific details of the test configuration
    process as well. There isn't a general way to disentangle test configuration
    from the compile/run process specific to a logical backend, since each
    backend (compiler backend and runtime target) will have an arbitrarily
    wild and wonderful set of possible configurations that we cannot predict.
    """
    # This is not a frontend-lowered module, to allow various testing at the
    # PyTorch level. We can have a helper class NpcompBackendTestConfig which
    # does that.
    @abc.abstractmethod
    def compile(self, program: torch.nn.Module) -> CompiledArtifact:
        """Compile the provided torch.nn.Module into a compiled artifact."""
        pass

    # Any should match result of `compile`.
    @abc.abstractmethod
    def run(self, artifact: CompiledArtifact, trace: Trace) -> Trace:
        """Run the compiled artifact produced by `compile`.

        The backend should load the compiled artifact and call the
        symbol names listed in `trace` with their respective inputs (the
        outputs of `trace` should be ignored). A new identical trace with
        outputs populated should be returned.

        This method should assume that `artifact` is being shared with
        multiple parallel invocations of `run`, and so it should not be
        mutated. This property is typically trivially satisfied for a true
        "compiled artifact", but some backends don't directly involve a
        compiled artifact per se (like a backend for which `CompiledArtifact`
        is `torch.nn.Module` and `run` just invokes the torch.nn.Module
        itself).

        Args:
            artifact: a compiled artifact produced by `compile`.
            trace: The external invocations to stimulate the module.
        Returns:
            A trace with outputs recorded according to the results of running
            on this backend.
        """
        pass


# Utilities for common testing trace generation.
# Also, resets the random seed for reproducibility.
# TODO: If generating in parallel, how to have manual_seed be local?
class TestUtils:
    """Utilities for executing a test.

    Test cases are provided an instance of this class to make test cases
    more succinct.

    For reproducibility, this class also resets the random seed.
    TODO: Figure out how to reset the seed properly scoped to just a test case
    (such as when running tests in parallel).
    """
    def __init__(self):
        torch.manual_seed(0)

    # TODO: Add zeros/ones/etc. as convenient.
    def rand(self, *sizes):
        if len(sizes) == 0:
            return torch.rand([])
        return torch.rand(*sizes)


class Test(NamedTuple):
    """A description of a test as produced by the test frontend.
    """
    # Stable name for error reporting.
    #
    # This name's stability is also useful for backends, which want to
    # generate their own lower-level test suites based on this framework.
    #
    # It is expected that those backends will need additional
    # metadata to describe their test configurations, so having a unique
    # key to keep that information associated is important.
    unique_name: str
    # A callable which produces the module under test.
    # This is a callable to allow lazily creating the module.
    program_factory: Callable[[], torch.nn.Module]
    # A callable which provides external stimuli to the module.
    # The first parameter is a torch.nn.Module (or a `_Tracer` wrapping that
    # module, actually).
    # The second parameter is a `TestUtils` instance for convenience.
    program_invoker: Callable[[Any, TestUtils], None]


class TestResult(NamedTuple):
    # Stable unique name for error reporting and test suite configuration.
    #
    # Tests frequently need some additional data (such as expected pass/fail
    # status, desired test configurations, etc.), and this gives a key to
    # associate it to. This avoids extending this class arbitrarily for every
    # possible requirement from the test framework.
    #
    # This name is also useful for backends that are generating their own
    # lower-level test suites from this framework for the same reasons, though
    # those reasons are stronger because we cannot simply extend this
    # class.
    unique_name: str  # Should match Test.unique_name for corresponding test.
    # The trace produced by the backend.
    trace: Trace
    # The golden trace which `trace` is expected to match.
    golden_trace: Trace


class _Tracer:
    """Wrapper around a `torch.nn.Module` that records calls into it.

    The inputs and outputs of each call are recorded in a Trace.
    """
    module: torch.nn.Module
    trace: Trace

    def __init__(self, module: torch.nn.Module):
        self.module = module
        self.trace = []

    def __getattr__(self, name):
        # TODO: Handle `module.foo.bar.baz` nesting.
        # For now, we are limited to attributes of the top-level module.
        def invoke(*args):
            raw_outputs = getattr(self.module, name)(*args)
            if isinstance(raw_outputs, torch.Tensor):
                outputs = [raw_outputs]
            else:
                # TODO: Support more types.
                outputs = list(raw_outputs)
            self.trace.append(
                TraceItem(symbol=name, inputs=args, outputs=outputs))
            return raw_outputs

        return invoke

    def get_trace(self):
        return self.trace


def _generate_golden_trace(test: Test) -> Trace:
    tracer = _Tracer(test.program_factory())
    test.program_invoker(tracer, TestUtils())
    return tracer.get_trace()


def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
    """Invoke the given `Test`'s with the provided `TestConfig`."""
    results = []
    for test in tests:
        golden_trace = _generate_golden_trace(test)
        # TODO: Precompile everything in parallel.
        compiled = config.compile(test.program_factory())
        # TODO: Run in parallel.
        trace = config.run(compiled, golden_trace)
        results.append(
            TestResult(unique_name=test.unique_name,
                       trace=trace,
                       golden_trace=golden_trace))
    return results


def report_results(results: List[TestResult]):
    """Provide a basic error report summarizing various TestResult's."""
    for result in results:
        failed = False
        for item_num, (item, golden_item) in enumerate(
                zip(result.trace, result.golden_trace)):
            assert item.symbol == golden_item.symbol
            assert len(item.inputs) == len(golden_item.inputs)
            assert len(item.outputs) == len(golden_item.outputs)
            for input, golden_input in zip(item.inputs, golden_item.inputs):
                assert torch.allclose(input, golden_input)
            for output_num, (output, golden_output) in enumerate(
                    zip(item.outputs, golden_item.outputs)):
                # TODO: Refine error message. Things to consider:
                # - Very large tensors -- don't spew, but give useful info
                # - Smaller tensors / primitives -- want to show exact values
                # - Machine parseable format?
                if not torch.allclose(output, golden_output):
                    print(
                        f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"'
                    )
                    failed = True
        if failed:
            print('FAILURE "{}"'.format(result.unique_name))
        else:
            print('SUCCESS "{}"'.format(result.unique_name))
```
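
Since `TestConfig` is deliberately pluggable, a user-defined config is just a
subclass. As a sketch of the "data collector" idea from the docstring above
(hypothetical, not part of this commit), a config could capture TorchScript IR
instead of executing anything:

```python
import torch

from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace


class IRCaptureTestConfig(TestConfig):
    """Hypothetical config: `compile` captures TorchScript IR as a string,
    and `run` saves it and echoes the golden trace instead of executing."""
    def __init__(self):
        super().__init__()
        self.captured_ir = {}

    def compile(self, program: torch.nn.Module) -> str:
        # Return IR at the TorchScript level rather than a runnable artifact.
        return str(torch.jit.script(program).graph)

    def run(self, artifact: str, trace: Trace) -> Trace:
        # No execution: stash the IR keyed by the first symbol called, and
        # return the input trace unchanged (so results trivially "pass").
        if trace:
            self.captured_ir[trace[0].symbol] = artifact
        return trace
```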

@@ -0,0 +1,30 @@
```python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable

import torch

from .framework import Test

# The global registry of tests.
GLOBAL_TEST_REGISTRY = []


def register_test_case(module_factory: Callable[[], torch.nn.Module]):
    """Convenient decorator-based test registration.

    Adds a `framework.Test` to the global test registry based on the decorated
    function. The test's `unique_name` is taken from the function name, the
    test's `program_factory` is taken from `module_factory`, and the
    `program_invoker` is the decorated function.
    """
    def decorator(f):
        GLOBAL_TEST_REGISTRY.append(
            Test(unique_name=f.__name__,
                 program_factory=module_factory,
                 program_invoker=f))
        return f

    return decorator
```

@@ -0,0 +1,50 @@
```python
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

# RUN: %PYTHON %s | FileCheck %s

import torch

import torch_mlir
from torch_mlir.torchscript.annotations import (
    annotate_args, export, extract_annotations
)


class MmModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4], torch.float32),
        ([4, 5], torch.float32),
    ])
    def forward(self, lhs, rhs):
        return torch.mm(lhs, rhs)


module = MmModule()
annotator = torch_mlir.ClassAnnotator()
extract_annotations(module, torch.jit.script(module), annotator)
print(annotator)

# CHECK: ClassAnnotator {
# CHECK:   ClassAnnotation('__torch__.MmModule') {
# CHECK:     MethodAnnotation('forward') {
# CHECK:       isExported = true
# CHECK:       argAnnotations =
# CHECK:         ArgAnnotation(0) {
# CHECK:           dtype = <none>
# CHECK:           shape = <none>
# CHECK:         }
# CHECK:         ArgAnnotation(1) {
# CHECK:           dtype = Float
# CHECK:           shape = [3, 4]
# CHECK:         }
# CHECK:         ArgAnnotation(2) {
# CHECK:           dtype = Float
# CHECK:           shape = [4, 5]
# CHECK:         }
# CHECK:     }
# CHECK:   }
# CHECK: }
```

@@ -0,0 +1,2 @@
This directory is for testing the e2e_test framework itself.
It is not for holding e2e tests themselves!!!

@@ -0,0 +1,36 @@
```python
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

# RUN: %PYTHON %s | FileCheck %s

import torch

from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig


class MmModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, lhs, rhs):
        return torch.mm(lhs, rhs)


# TODO: Refine messages.
# CHECK: SUCCESS "MmModule_basic"
@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 4), tu.rand(4, 4))


def main():
    config = TorchScriptTestConfig()
    results = run_tests(GLOBAL_TEST_REGISTRY, config)
    report_results(results)


if __name__ == '__main__':
    main()
```

@@ -0,0 +1,42 @@
```python
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

# RUN: %PYTHON %s | FileCheck %s

import torch

from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig


class MmModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, lhs, rhs):
        # Use torch.jit.is_scripting() to fake a miscompile.
        # The non-scripted code will take one path, and the scripted code
        # will take another path.
        if torch.jit.is_scripting():
            return torch.mm(rhs, lhs)
        return torch.mm(lhs, rhs)


# TODO: Refine error messages.
# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward"
# CHECK: FAILURE "MmModule_basic"
@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 4), tu.rand(4, 4))


def main():
    config = TorchScriptTestConfig()
    results = run_tests(GLOBAL_TEST_REGISTRY, config)
    report_results(results)


if __name__ == '__main__':
    main()
```