Add E2E support for tests with heavy dependencies (heavydep tests).

The tests use the same (pure-Python) test framework as the normal
torchscript_e2e_test.sh, but are added in
`build_tools/torchscript_e2e_heavydep_tests` instead of
`frontends/pytorch/e2e_testing/torchscript`. Any needed dependencies can
easily be configured in generate_serialized_tests.sh.
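
As a quick sketch of the intended workflow (the paths here are placeholders;
see the README changes below for the authoritative version):

```shell
# Generate the serialized heavydep tests. This sets up a dedicated venv with
# the heavy dependencies, leaving the main npcomp build untouched.
build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh \
  /path/to/heavydep_venv \
  /path/to/heavydep_serialized_tests

# Run them alongside the regular tests with the normal runner.
tools/torchscript_e2e_test.sh \
  --serialized-test-dir=/path/to/heavydep_serialized_tests
```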

We add an initial machine translation model with a complex set of
dependencies to seed the suite. I verified that this model gets to the
point of MLIR import (it currently fails there with a segfault because we
cannot yet import the "Any" type).

This required splitting a few files out of the `torch_mlir` Python module
into multiple modules, isolating the code that depends on our C++
extensions (which now lives in `torch_mlir` and
`torch_mlir_torchscript_e2e_test_configs`) from the pure-Python code
(which now lives in `torch_mlir_torchscript`). This is an entirely
mechanical change, though it required updating a lot of imports.
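
The mechanical import updates look like this (before/after pair taken from
the test files in this change):

```python
# Before:
#   from torch_mlir.torchscript.e2e_test.framework import TestUtils
# After:
from torch_mlir_torchscript.e2e_test.framework import TestUtils
```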

The dependency graph is:
```
      torch_mlir_torchscript_e2e_test_configs
                /                  \
               /                    \
              V                      V
 torch_mlir_torchscript          torch_mlir
```

The configs in `torch_mlir_torchscript_e2e_test_configs` are then
dependency-injected into the `torch_mlir_torchscript` modules to assemble a
working test harness (the code was already structured this way, but this
new file organization allows the isolation from C++ code to actually
happen). This isolation is critical for transporting the serialized
programs across PyTorch versions and for using the test harness seamlessly
to generate the heavydep tests.
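
For illustration, this is how a config is injected into the pure-Python
framework (mirroring the lit tests in this change):

```python
# C++-dependent config, imported from its own top-level package.
from torch_mlir_torchscript_e2e_test_configs import TorchScriptTestConfig
# Pure-Python test framework.
from torch_mlir_torchscript.e2e_test.framework import run_tests
from torch_mlir_torchscript.e2e_test.reporting import report_results
from torch_mlir_torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY

# The config is constructed by the caller and passed in; the framework
# never imports it, so the framework stays free of C++ dependencies.
config = TorchScriptTestConfig()
results = run_tests(GLOBAL_TEST_REGISTRY, config)
report_results(results, set())
```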

Also:
- Extend the `_Tracer` class to support nested property (submodule)
  accesses (sketched below).
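
A minimal sketch of what the extended tracer records (the module names here
are made up for illustration, and `_Tracer` is internal to the framework):

```python
import torch

from torch_mlir_torchscript.e2e_test.framework import _Tracer

class Inner(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

# Each attribute access returns a new _Tracer with an extended property
# path, so calls through submodules record dotted symbols.
trace = []
tracer = _Tracer(Outer(), [], trace)
tracer.inner.forward(torch.ones(2))
print(trace[0].symbol)  # inner.forward
```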

Recommended review order:
- "user-level" docs in README.md
- code in `build_tools/torchscript_e2e_heavydep_tests`
- changes in `torch_mlir_torchscript/e2e_test/framework.py`
- misc mechanical changes
pull/268/head
Sean Silva 2021-07-09 12:22:45 -07:00
parent f168cacd6d
commit 453e29ea05
29 changed files with 582 additions and 128 deletions

View File

@@ -129,11 +129,6 @@ export LDFLAGS=-fuse-ld=$(which ld.lld-$LLVM_VERSION)
 ### Vanilla - numpy-only, no pytorch
 ```shell
-# Install PyTorch. We currently track and require the nightly build.
-# If a usable PyTorch package is installed, the default cmake settings will
-# enable the PyTorch frontend.
-pip3 install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
-
 # Configure npcomp.
 cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release .
@@ -147,6 +142,11 @@ ninja check-npcomp
 ### With PyTorch integration
 ```shell
+# Install PyTorch. We currently track and require the nightly build.
+# If a usable PyTorch package is installed, the default cmake settings will
+# enable the PyTorch frontend.
+pip3 install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
+
 cmake -DNPCOMP_ENABLE_PYTORCH=ON ...
 ninja check-frontends-pytorch # If building with PyTorch
 ```
@@ -210,6 +210,40 @@ echo 'PYTHONPATH="${PYTHONPATH}:/path/to/iree-build/bindings/python"' >> .env
 tools/torchscript_e2e_test.sh --config=iree
 ```
+
+### Additional end-to-end tests with heavy dependencies (heavydep tests)
+
+Some end-to-end tests require additional dependencies which don't make sense to
+include as part of the default npcomp setup. Additionally, these dependencies
+often don't work with the same HEAD PyTorch version that npcomp builds against
+at the C++ level.
+
+We have a self-contained script that generates all the needed artifacts from a
+self-contained virtual environment. It can be used like so:
+
+```shell
+# Build the virtual environment in the specified directory and generate the
+# serialized test artifacts in the other specified directory.
+# This command is safe to re-run if you have already built the virtual
+# environment and just changed the tests.
+build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh \
+  path/to/heavydep_venv \
+  path/to/heavydep_serialized_tests
+
+# Add the --serialized-test-dir flag to point at the directory containing the
+# serialized tests. All other functionality is the same as the normal
+# invocation of torchscript_e2e_test.sh, but the serialized tests will be
+# available.
+tools/torchscript_e2e_test.sh --serialized-test-dir=/t/heavydep_serialized_tests
+```
+
+Note that the heavydep tests are generally quite challenging, and we don't have
+any that work yet. The tests use the same (pure-Python) test framework as the
+normal torchscript_e2e_test.sh, but the tests are added in
+`build_tools/torchscript_e2e_heavydep_tests` instead of
+`frontends/pytorch/e2e_testing/torchscript`.
+
+We rely critically on serialized TorchScript compatibility across PyTorch
+versions to transport the tests, plus pure-Python compatibility of the `torch`
+API; this has worked well so far.
+
 ### VSCode with a Docker Dev Image

View File

@@ -0,0 +1,120 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Basic machine translation (MT) program.
#
# This is an IDs-to-IDs end-to-end program.

import argparse
import tempfile
import unittest
import os
from typing import Dict, Optional

import torch
import torch.nn.functional as F
from fairseq.data.dictionary import Dictionary
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.modules.multihead_attention import MultiheadAttention
from fairseq.models.transformer import TransformerModel
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
)
from fairseq.sequence_generator import SequenceGenerator
from fairseq.tasks.fairseq_task import LegacyFairseqTask
from fairseq import utils

from torch_mlir_torchscript.e2e_test.framework import TestUtils
from torch_mlir_torchscript.e2e_test.registry import register_test_case
from torch_mlir_torchscript.annotations import annotate_args, export

DEFAULT_TEST_VOCAB_SIZE = 100


class DummyTask(LegacyFairseqTask):
    def __init__(self, args):
        super().__init__(args)
        self.dictionary = _get_dummy_dictionary()
        if getattr(self.args, "ctc", False):
            self.dictionary.add_symbol("<ctc_blank>")
        self.src_dict = self.dictionary
        self.tgt_dict = self.dictionary

    @property
    def source_dictionary(self):
        return self.src_dict

    @property
    def target_dictionary(self):
        return self.dictionary


def _get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
    dummy_dict = Dictionary()
    for id, _ in enumerate(range(vocab_size)):
        dummy_dict.add_symbol("{}".format(id), n=1000)
    return dummy_dict


def _get_dummy_task_and_parser():
    parser = argparse.ArgumentParser(description="test_dummy_s2s_task",
                                     argument_default=argparse.SUPPRESS)
    DummyTask.add_args(parser)
    args = parser.parse_args([])
    task = DummyTask.setup_task(args)
    return task, parser


class BasicMtModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        task, parser = _get_dummy_task_and_parser()
        TransformerModel.add_args(parser)
        args = parser.parse_args([])
        args.encoder_layers = 2
        args.decoder_layers = 1
        transformer_model = TransformerModel.build_model(args, task)
        self.sequence_generator = SequenceGenerator(
            [transformer_model],
            task.tgt_dict,
            beam_size=2,
            no_repeat_ngram_size=2,
            max_len_b=10,
        )

    # TODO: Support list/dict returns from functions.
    # This will allow us to handle a variable number of sentences.
    @export
    @annotate_args([
        None,
        ([2, -1], torch.long, True),
        ([2], torch.long, True),
    ])
    def forward(self, src_tokens, src_lengths):
        sample = {
            'net_input': {
                'src_tokens': src_tokens,
                'src_lengths': src_lengths
            }
        }
        result = self.sequence_generator(sample)
        return result[0][0]["tokens"], result[1][0]["tokens"]


EOS = BasicMtModule().sequence_generator.eos


@register_test_case(module_factory=lambda: BasicMtModule())
def BasicMtModule_basic(module, tu: TestUtils):
    # Imagine random sentences from the vocabulary. Use a subset of the
    # vocabulary that doesn't overlap with the EOS (which is usually the
    # number 2).
    MAX_SENTENCE_LENGTH = 10
    src_tokens = torch.randint(DEFAULT_TEST_VOCAB_SIZE // 4,
                               DEFAULT_TEST_VOCAB_SIZE // 3,
                               (2, MAX_SENTENCE_LENGTH)).long()
    # Append the end-of-sentence symbol to the end of each sentence.
    src_tokens = torch.cat((src_tokens, torch.LongTensor([[EOS], [EOS]])), -1)
    src_lengths = torch.LongTensor([7, 10])
    module.forward(src_tokens, src_lengths)

View File

@@ -0,0 +1,28 @@
#!/bin/bash
set -euo pipefail

# Check that exactly two arguments are passed.
if [ "$#" -ne 2 ]; then
  echo "Usage: $0 <venv_dir> <serialized_test_dir>"
  echo 'Description:
  venv_dir: directory to put the Python venv used for generating the serialized tests
  serialized_test_dir: directory to write the generated serialized tests to
'
  exit 1
fi

venv_dir=$1
serialized_test_dir=$2
here="$(realpath $(dirname $0))"
npcomp_src_root="$here/../../"

mkdir -p $venv_dir
mkdir -p $serialized_test_dir

python3 -m venv $venv_dir
source $venv_dir/bin/activate
python3 -m pip install fairseq fvcore sacremoses subword-nmt

cd "$npcomp_src_root"
export PYTHONPATH=${PYTHONPATH-}
source "$npcomp_src_root/.env"
python3 -m build_tools.torchscript_e2e_heavydep_tests.main --output_dir=$serialized_test_dir

View File

@@ -0,0 +1,46 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import os
import pickle

import torch

from torch_mlir_torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
from torch_mlir_torchscript.e2e_test.framework import SerializableTest, generate_golden_trace
from torch_mlir_torchscript.annotations import extract_serializable_annotations

# Importing this module has the side effect of registering its tests in
# GLOBAL_TEST_REGISTRY.
from . import basic_mt


def _get_argparse():
    parser = argparse.ArgumentParser(
        description="Generate assets for TorchScript E2E tests")
    parser.add_argument("--output_dir", help="The directory to put assets in.")
    return parser


def main():
    args = _get_argparse().parse_args()
    serializable_tests = []
    for test in GLOBAL_TEST_REGISTRY:
        trace = generate_golden_trace(test)
        module = torch.jit.script(test.program_factory())
        torchscript_module_bytes = module.save_to_buffer({
            "annotations.pkl":
            pickle.dumps(extract_serializable_annotations(module))
        })
        serializable_tests.append(
            SerializableTest(unique_name=test.unique_name,
                             program=torchscript_module_bytes,
                             trace=trace))
    for test in serializable_tests:
        with open(os.path.join(args.output_dir, f"{test.unique_name}.pkl"),
                  "wb") as f:
            pickle.dump(test, f)


if __name__ == "__main__":
    main()

View File

@@ -4,9 +4,9 @@
 
 import torch
 
-from torch_mlir.torchscript.e2e_test.framework import TestUtils
-from torch_mlir.torchscript.e2e_test.registry import register_test_case
-from torch_mlir.torchscript.annotations import annotate_args, export
+from torch_mlir_torchscript.e2e_test.framework import TestUtils
+from torch_mlir_torchscript.e2e_test.registry import register_test_case
+from torch_mlir_torchscript.annotations import annotate_args, export
 
 # ==============================================================================

View File

@@ -4,9 +4,9 @@
 
 import torch
 
-from torch_mlir.torchscript.e2e_test.framework import TestUtils
-from torch_mlir.torchscript.e2e_test.registry import register_test_case
-from torch_mlir.torchscript.annotations import annotate_args, export
+from torch_mlir_torchscript.e2e_test.framework import TestUtils
+from torch_mlir_torchscript.e2e_test.registry import register_test_case
+from torch_mlir_torchscript.annotations import annotate_args, export
 
 # ==============================================================================
 
 class BatchNorm1DModule(torch.nn.Module):

View File

@@ -3,9 +3,9 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 import torch
 
-from torch_mlir.torchscript.e2e_test.framework import TestUtils
-from torch_mlir.torchscript.e2e_test.registry import register_test_case
-from torch_mlir.torchscript.annotations import annotate_args, export
+from torch_mlir_torchscript.e2e_test.framework import TestUtils
+from torch_mlir_torchscript.e2e_test.registry import register_test_case
+from torch_mlir_torchscript.annotations import annotate_args, export
 
 # ==============================================================================

View File

@@ -4,9 +4,9 @@
 
 import torch
 
-from torch_mlir.torchscript.e2e_test.framework import TestUtils
-from torch_mlir.torchscript.e2e_test.registry import register_test_case
-from torch_mlir.torchscript.annotations import annotate_args, export
+from torch_mlir_torchscript.e2e_test.framework import TestUtils
+from torch_mlir_torchscript.e2e_test.registry import register_test_case
+from torch_mlir_torchscript.annotations import annotate_args, export
 
 # TODO: Support scalar !torch.int/!torch.float variants. Add support to
 # ReduceOpVariants to implement them in terms of the tensor-only variants +

View File

@@ -3,15 +3,17 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 import argparse
+import os
+import pickle
 import re
 import sys
 
-from torch_mlir.torchscript.e2e_test.framework import run_tests
-from torch_mlir.torchscript.e2e_test.reporting import report_results
-from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
+from torch_mlir_torchscript.e2e_test.framework import run_tests
+from torch_mlir_torchscript.e2e_test.reporting import report_results
+from torch_mlir_torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
 
 # Available test configs.
-from torch_mlir.torchscript.e2e_test.configs import (
+from torch_mlir_torchscript_e2e_test_configs import (
     NpcompBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
 )
@@ -56,6 +58,14 @@ Regular expression specifying which tests to include in this run.
         default=False,
         action='store_true',
         help='report test results with additional detail')
+    parser.add_argument('--serialized-test-dir', default=None, type=str, help='''
+The directory containing serialized pre-built tests.
+Right now, these are additional tests which require heavy Python dependencies
+to generate (or cannot even be generated with the version of PyTorch used by
+npcomp).
+See `build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh`
+for more information on building these artifacts.
+''')
     return parser
 
 def main():
@@ -71,9 +81,16 @@ def main():
     elif args.config == 'torchscript':
         config = TorchScriptTestConfig()
 
+    all_tests = list(GLOBAL_TEST_REGISTRY)
+    if args.serialized_test_dir:
+        for root, dirs, files in os.walk(args.serialized_test_dir):
+            for filename in files:
+                with open(os.path.join(root, filename), 'rb') as f:
+                    all_tests.append(pickle.load(f).as_test())
+
     # Find the selected tests, and emit a diagnostic if none are found.
     tests = [
-        test for test in GLOBAL_TEST_REGISTRY
+        test for test in all_tests
        if re.match(args.filter, test.unique_name)
     ]
     if len(tests) == 0:
@@ -81,7 +98,7 @@ def main():
             f'ERROR: the provided filter {args.filter!r} does not match any tests'
         )
         print('The available tests are:')
-        for test in GLOBAL_TEST_REGISTRY:
+        for test in all_tests:
             print(test.unique_name)
         sys.exit(1)

View File

@@ -5,9 +5,9 @@
 
 import torch
 import torch.nn as nn
 
-from torch_mlir.torchscript.e2e_test.framework import TestUtils
-from torch_mlir.torchscript.e2e_test.registry import register_test_case
-from torch_mlir.torchscript.annotations import annotate_args, export
+from torch_mlir_torchscript.e2e_test.framework import TestUtils
+from torch_mlir_torchscript.e2e_test.registry import register_test_case
+from torch_mlir_torchscript.annotations import annotate_args, export
 
 # ==============================================================================

View File

@@ -5,9 +5,9 @@
 
 import torch
 from torch import nn
 
-from torch_mlir.torchscript.e2e_test.framework import TestUtils
-from torch_mlir.torchscript.e2e_test.registry import register_test_case
-from torch_mlir.torchscript.annotations import annotate_args, export
+from torch_mlir_torchscript.e2e_test.framework import TestUtils
+from torch_mlir_torchscript.e2e_test.registry import register_test_case
+from torch_mlir_torchscript.annotations import annotate_args, export
 
 # ==============================================================================

View File

@@ -5,9 +5,9 @@
 
 import torch
 import torchvision.models as models
 
-from torch_mlir.torchscript.e2e_test.framework import TestUtils
-from torch_mlir.torchscript.e2e_test.registry import register_test_case
-from torch_mlir.torchscript.annotations import annotate_args, export
+from torch_mlir_torchscript.e2e_test.framework import TestUtils
+from torch_mlir_torchscript.e2e_test.registry import register_test_case
+from torch_mlir_torchscript.annotations import annotate_args, export
 
 # ==============================================================================

View File

@@ -16,58 +16,18 @@ import torch_mlir
 # to be expressed conveniently and gives clearer error reports when
 # the annotations aren't acceptable.
 
+# This module is kept separate from torch_mlir_torchscript.annotations so that
+# we can use that module from code without C++ dependencies, which prevent us
+# from interfacing the test framework across environments.
 
-def export(fn):
-    """Decorator that tells the npcomp compiler that a method is exported.
-
-    By default, no methods are exported, which is very important for
-    the compiler, because otherwise most Torch programs consist of a sea
-    of tiny exported functions with no rank or dtype information
-    (see `annotate_args`), which the compiler cannot do much with.
-
-    Note that this is different from `torch.jit.export`, which controls
-    which methods are scripted in the first place. For non-`forward` methods,
-    using this decorator usually means you also need `torch.jit.export`.
-    Conceptually, this decorator is annotating the scripted module, but is
-    applied to the original `torch.nn.Module` for convenience.
-    """
-    fn._npcomp_export = True
-    return fn
-
-
-ArgAnnotation = Tuple[List[int], torch.dtype]
-
-
-# TODO: Replace with py3 extended argument annotations when available.
-# See https://www.python.org/dev/peps/pep-0593/
-def annotate_args(annotations: List[Optional[ArgAnnotation]]):
-    """Decorator that tells the npcomp compiler information about arguments.
-
-    The `annotations` should be a list of the same length as the number of
-    argument to the method (including `self`). Each list entry is either:
-    - None, corresponding to providing the compiler with no information.
-    - A 2-tuple consisting of a shape and a dtype, such as
-      `([2, 3, 4], torch.float32)`. A dimension with an unknown size can be
-      indicated by using `-1` as the size. This provides the compiler a
-      guarantee that the argument will always dynamically have the described
-      shape and dtype.
-    """
-    # TODO: Check the number of arguments matches the number of arg annotations.
-    def decorator(fn):
-        fn._npcomp_arg_annotations = annotations
-        return fn
-    return decorator
-
 # Utilities for extracting decorated information into torch_mlir.ClassAnnotator.
 
 def _recursively_extract_annotations(
         module: torch.nn.Module, scripted: torch.jit.ScriptModule,
         class_annotator: torch_mlir.ClassAnnotator):
-    assert module.__class__.__name__ == scripted.original_name, "script module does not come from specified module"
+    assert module.__class__.__name__ == scripted.original_name or (
+        isinstance(module, torch.jit.RecursiveScriptModule) and module is
+        scripted), "script module does not come from specified module"
 
     # Extract information on methods.
     for method_name, scripted_method in scripted.__dict__.items():

View File

@@ -0,0 +1,128 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import List, Optional, Tuple, NamedTuple

import torch

# Decorators

# Currently, these decorators are very low-level and map 1:1 with
# methods on `torch_mlir.ClassAnnotator`. Eventually, we expect there to
# be a more elaborate Python layer which allows all the different annotations
# to be expressed conveniently and gives clearer error reports when
# the annotations aren't acceptable.

# This module is kept separate from torch_mlir.torchscript_annotations so that
# we can use this from code without C++ dependencies, which prevent us from
# interfacing the test framework across environments.

# Attribute names used for annotations.
# These should be kept in sync with their use in
# `frontends/pytorch/python/torch_mlir/torchscript_annotations.py`.
NPCOMP_EXPORT_ATTR_NAME = '_npcomp_export'
NPCOMP_ARG_ANNOTATIONS_ATTR_NAME = '_npcomp_arg_annotations'


def export(fn):
    """Decorator that tells the npcomp compiler that a method is exported.

    By default, no methods are exported, which is very important for
    the compiler, because otherwise most Torch programs consist of a sea
    of tiny exported functions with no rank or dtype information
    (see `annotate_args`), which the compiler cannot do much with.

    Note that this is different from `torch.jit.export`, which controls
    which methods are scripted in the first place. For non-`forward` methods,
    using this decorator usually means you also need `torch.jit.export`.
    Conceptually, this decorator is annotating the scripted module, but is
    applied to the original `torch.nn.Module` for convenience.
    """
    setattr(fn, NPCOMP_EXPORT_ATTR_NAME, True)
    return fn


ArgAnnotation = Tuple[List[int], torch.dtype]


# TODO: Replace with py3 extended argument annotations when available.
# See https://www.python.org/dev/peps/pep-0593/
def annotate_args(annotations: List[Optional[ArgAnnotation]]):
    """Decorator that tells the npcomp compiler information about arguments.

    The `annotations` should be a list of the same length as the number of
    arguments to the method (including `self`). Each list entry is either:
    - None, corresponding to providing the compiler with no information.
    - A 2-tuple consisting of a shape and a dtype, such as
      `([2, 3, 4], torch.float32)`. A dimension with an unknown size can be
      indicated by using `-1` as the size. This provides the compiler a
      guarantee that the argument will always dynamically have the described
      shape and dtype.
    """

    # TODO: Check the number of arguments matches the number of arg annotations.
    def decorator(fn):
        setattr(fn, NPCOMP_ARG_ANNOTATIONS_ATTR_NAME, annotations)
        return fn

    return decorator


class SerializableMethodAnnotation(NamedTuple):
    method_name: str
    export: Optional[bool]
    arg_annotations: Optional[List[ArgAnnotation]]


class SerializableModuleAnnotations(NamedTuple):
    method_annotations: List[SerializableMethodAnnotation]
    submodule_annotations: List[Tuple[str, "SerializableModuleAnnotations"]]


def extract_serializable_annotations(
        module: torch.nn.Module) -> SerializableModuleAnnotations:
    module_annotations = SerializableModuleAnnotations(
        method_annotations=[], submodule_annotations=[])
    # Extract information on methods.
    for method_name, method in module.__dict__.items():
        # See if it is a method.
        if not callable(method):
            continue
        export = None
        arg_annotations = None
        if hasattr(method, NPCOMP_EXPORT_ATTR_NAME):
            export = method._npcomp_export
        if hasattr(method, NPCOMP_ARG_ANNOTATIONS_ATTR_NAME):
            arg_annotations = method._npcomp_arg_annotations
        if export is not None and arg_annotations is not None:
            module_annotations.method_annotations.append(
                SerializableMethodAnnotation(method_name=method_name,
                                             export=export,
                                             arg_annotations=arg_annotations))

    # Recurse.
    for name, child in module.named_children():
        annotations = extract_serializable_annotations(child)
        module_annotations.submodule_annotations.append((name, annotations))
    return module_annotations


def apply_serializable_annotations(module: torch.nn.Module,
                                   annotations: SerializableModuleAnnotations):
    # Apply annotations to methods.
    for method_annotation in annotations.method_annotations:
        # Imitate use of the decorators to keep a source of truth there.
        if method_annotation.export is not None:
            setattr(module, method_annotation.method_name,
                    export(getattr(module, method_annotation.method_name)))
        if method_annotation.arg_annotations is not None:
            setattr(
                module, method_annotation.method_name,
                annotate_args(method_annotation.arg_annotations)(getattr(
                    module, method_annotation.method_name)))

    # Recurse.
    for name, submodule_annotations in annotations.submodule_annotations:
        child = getattr(module, name)
        apply_serializable_annotations(child, submodule_annotations)

View File

@@ -22,8 +22,13 @@ compiling or TorchScript'ing).
 
 import abc
 from typing import Any, Callable, List, NamedTuple, Optional, TypeVar
+import io
+import pickle
 
 import torch
 
+from ..annotations import apply_serializable_annotations
+
 
 class TraceItem(NamedTuple):
     # The externally visible symbol name that is called.
@@ -162,6 +167,57 @@ class Test(NamedTuple):
     program_invoker: Callable[[Any, TestUtils], None]
 
 
+class SerializableTest(NamedTuple):
+    """A self-contained representation of a test that can be pickled.
+
+    We use serialized TorchScript programs here for two reasons:
+
+    1. The PyTorch pickling story isn't great, so in order to reliably pickle
+       this class, we rely on having the serialized bytes for the TorchScript
+       module already given to us.
+    2. The choice of a TorchScript module vs `torch.nn.Module` boils down to
+       the fact that `torch.nn.Module` cannot be deserialized without pulling
+       in the same set of Python dependencies that were used to serialize it
+       in the first place. This would defeat one of the main use cases of
+       this class, which is to transport a test from an environment with a
+       set of heavy dependencies to a dependency-light one. Since TorchScript
+       modules are self-contained, they fit the bill perfectly.
+    """
+    # See unique_name on `Test`.
+    unique_name: str
+    # Serialized TorchScript program.
+    program: bytes
+    # Trace for execution testing.
+    trace: Trace
+
+    def as_test(self) -> Test:
+        """Create a `Test` from this class."""
+        # Conform the serialized program to the interface expected by Test.
+        # This is a bit of a hack, but it's the only way to keep the layering
+        # straight.
+        def factory():
+            _extra_files = {"annotations.pkl": ""}
+            module = torch.jit.load(io.BytesIO(self.program),
+                                    _extra_files=_extra_files)
+            # Load the pickled annotations.
+            annotations = pickle.loads(_extra_files["annotations.pkl"])
+            apply_serializable_annotations(module, annotations)
+            return module
+
+        def invoker(module, tu):
+            for item in self.trace:
+                attr = module
+                for part in item.symbol.split("."):
+                    attr = getattr(attr, part)
+                attr(*item.inputs)
+
+        return Test(
+            unique_name=self.unique_name,
+            program_factory=factory,
+            program_invoker=invoker,
+        )
+
+
 class TestResult(NamedTuple):
     # Stable unique name for error reporting and test suite configuration.
     #
@@ -188,39 +244,44 @@ class TestResult(NamedTuple):
 class _Tracer:
     """Wrapper around a `torch.nn.Module` that records calls into it.
 
-    The inputs and outputs of each call are recorded in a Trace.
+    The inputs and outputs of each call are recorded in a Trace. Recursive
+    property accesses are also traced.
     """
-    module: torch.nn.Module
-    trace: Trace
-
-    def __init__(self, module: torch.nn.Module):
-        self.module = module
-        self.trace = []
+    def __init__(self, wrapped, property_base_path: List[str], trace: Trace):
+        self.__wrapped__ = wrapped
+        self.__trace__ = trace
+        self.__property_base_path__ = property_base_path
+
+    def __call__(self, *args, **kwargs):
+        raw_outputs = self.__wrapped__(*args, **kwargs)
+        if isinstance(raw_outputs, torch.Tensor):
+            outputs = [raw_outputs]
+        elif isinstance(raw_outputs, tuple) and all(
+                isinstance(o, torch.Tensor) for o in raw_outputs):
+            outputs = raw_outputs
+        else:
+            raise Exception("unimplemented: non-Tensor output from function")
+        self.__trace__.append(
+            TraceItem(symbol=".".join(self.__property_base_path__),
+                      inputs=args,
+                      outputs=outputs))
+        return raw_outputs
 
     def __getattr__(self, name):
-        # TODO: Handle `module.foo.bar.baz` nesting.
-        # For now, we are limited to attributes of the top-level module.
-        def invoke(*args):
-            raw_outputs = getattr(self.module, name)(*args)
-            if isinstance(raw_outputs, torch.Tensor):
-                outputs = [raw_outputs]
-            else:
-                raise Exception(
-                    "unimplemented: non-Tensor output from function")
-            self.trace.append(
-                TraceItem(symbol=name, inputs=args, outputs=outputs))
-            return raw_outputs
-        return invoke
-
-    def get_trace(self):
-        return self.trace
+        return _Tracer(getattr(self.__wrapped__, name),
+                       self.__property_base_path__ + [name], self.__trace__)
 
 
-def _generate_golden_trace(test: Test) -> Trace:
-    tracer = _Tracer(test.program_factory())
+def generate_golden_trace(test: Test) -> Trace:
+    """Generate a trace with the original program.
+
+    If the original program is deterministic, then the produced trace is
+    suitable as a golden trace to compare against.
+    """
+    trace = []
+    tracer = _Tracer(test.program_factory(), [], trace)
     test.program_invoker(tracer, TestUtils())
-    return tracer.get_trace()
+    return trace
 
 
 def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
@@ -229,9 +290,16 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
     for test in tests:
         # TODO: Precompile everything in parallel.
         try:
-            golden_trace = _generate_golden_trace(test)
+            golden_trace = generate_golden_trace(test)
             compiled = config.compile(test.program_factory())
         except Exception as e:
+            # Useful for debugging:
+            # ```
+            # raise
+            # ```
+            # This will give the full traceback rather than giving just
+            # the stringified exception in the report.
+            # TODO: Capture the traceback and make it available in the report.
             results.append(
                 TestResult(unique_name=test.unique_name,
                            compilation_error=str(e),

View File

@@ -161,7 +161,7 @@ class SingleTestReport:
         f = io.StringIO()
         p = lambda *x: print(*x, file=f)
         if self.result.compilation_error is not None:
-            return 'compilation error' + self.result.compilation_error
+            return 'Compilation error: ' + self.result.compilation_error
         for report in self.item_reports:
             if report.failed:
                 p(report.error_str())

View File

@@ -7,7 +7,7 @@ from typing import Any
 
 import torch
 
-from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
+from torch_mlir_torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
 
 
 class NativeTorchTestConfig(TestConfig):

View File

@@ -15,8 +15,8 @@ from mlir.passmanager import PassManager
 import torch_mlir
 from npcomp.compiler.pytorch.backend import refjit
 from npcomp.compiler.pytorch.backend.abc import NpcompBackend
-from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
-from torch_mlir.torchscript.annotations import extract_annotations
+from torch_mlir_torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
+from torch_mlir.torchscript_annotations import extract_annotations
 
 
 class PrettyErrorReportForIrOperation(object):
     def __init__(self, module, module_name_for_ir_dump: str):

View File

@@ -7,7 +7,7 @@ from typing import Any
 
 import torch
 
-from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
+from torch_mlir_torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
 
 
 class TorchScriptTestConfig(TestConfig):
@@ -24,7 +24,10 @@ class TorchScriptTestConfig(TestConfig):
         result: Trace = []
         for item in trace:
-            outputs = getattr(artifact, item.symbol)(*item.inputs)
+            attr = artifact
+            for part in item.symbol.split('.'):
+                attr = getattr(attr, part)
+            outputs = attr(*item.inputs)
             if isinstance(outputs, torch.Tensor):
                 outputs = [outputs]
             result.append(

View File

@@ -40,6 +40,11 @@ setup(
         "": "./python",
     },
     packages=find_packages("./python", include=[
-        "torch_mlir", "torch_mlir.*",
+        "torch_mlir",
+        "torch_mlir.*",
+        "torch_mlir_torchscript",
+        "torch_mlir_torchscript.*",
+        "torch_mlir_torchscript_e2e_test_configs",
+        "torch_mlir_torchscript_e2e_test_configs.*",
     ]),
 )

View File

@@ -7,9 +7,8 @@
 import torch
 
 import torch_mlir
-from torch_mlir.torchscript.annotations import (
-    annotate_args, export, extract_annotations
-)
+from torch_mlir_torchscript.annotations import annotate_args, export
+from torch_mlir.torchscript_annotations import extract_annotations
 
 class MmModule(torch.nn.Module):
     def __init__(self):

View File

@@ -6,10 +6,10 @@
 
 import torch
 
-from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
-from torch_mlir.torchscript.e2e_test.reporting import report_results
-from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
-from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
+from torch_mlir_torchscript.e2e_test.framework import run_tests, TestUtils
+from torch_mlir_torchscript.e2e_test.reporting import report_results
+from torch_mlir_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
+from torch_mlir_torchscript_e2e_test_configs import TorchScriptTestConfig
 
 class MmModule(torch.nn.Module):

View File

@@ -6,10 +6,10 @@
 
 import torch
 
-from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
-from torch_mlir.torchscript.e2e_test.reporting import report_results
-from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
-from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
+from torch_mlir_torchscript.e2e_test.framework import run_tests, TestUtils
+from torch_mlir_torchscript.e2e_test.reporting import report_results
+from torch_mlir_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
+from torch_mlir_torchscript_e2e_test_configs import TorchScriptTestConfig
 
 class MmModule(torch.nn.Module):
@@ -26,7 +26,7 @@ class MmModule(torch.nn.Module):
 
 # CHECK: FAIL - "MmModule_basic"
-# CHECK: compilation error
+# CHECK: Compilation error:
 # Assume that the diagnostic from the TorchScript compiler will at least contain
 # the offending "return 3".
 # CHECK: return 3

View File

@@ -6,10 +6,10 @@
 
 import torch
 
-from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
-from torch_mlir.torchscript.e2e_test.reporting import report_results
-from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
-from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
+from torch_mlir_torchscript.e2e_test.framework import run_tests, TestUtils
+from torch_mlir_torchscript.e2e_test.reporting import report_results
+from torch_mlir_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
+from torch_mlir_torchscript_e2e_test_configs import TorchScriptTestConfig
 
 class MmModule(torch.nn.Module):

View File

@@ -0,0 +1,46 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

# RUN: %PYTHON %s | FileCheck %s

import torch

from torch_mlir_torchscript.e2e_test.framework import run_tests, TestUtils
from torch_mlir_torchscript.e2e_test.reporting import report_results
from torch_mlir_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
from torch_mlir_torchscript_e2e_test_configs import TorchScriptTestConfig


class Submodule2(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, lhs, rhs):
        return torch.mm(lhs, rhs)


class Submodule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.m2 = Submodule2()


class ModuleWithSubmodule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.m = Submodule()


# CHECK: PASS - "ModuleWithSubmodule_basic"
@register_test_case(module_factory=lambda: ModuleWithSubmodule())
def ModuleWithSubmodule_basic(module, tu: TestUtils):
    module.m.m2.forward(tu.rand(4, 4), tu.rand(4, 4))


def main():
    config = TorchScriptTestConfig()
    results = run_tests(GLOBAL_TEST_REGISTRY, config)
    report_results(results, set())


if __name__ == '__main__':
    main()