mirror of https://github.com/llvm/torch-mlir
Add E2E support for tests with heavy dependencies (heavydep tests).
The tests use the same (pure-Python) test framework as the normal torchscript_e2e_test.sh, but the tests are added in `build_tools/torchscript_e2e_heavydep_tests` instead of `frontends/pytorch/e2e_testing/torchscript`. Any needed dependencies can easily be configured in generate_serialized_tests.sh. We add an initial machine translation model with a complex set of dependencies to seed the curriculum there. I verified that this model gets to the point of MLIR import (it fails there with a segfault due to not being able to import the "Any" type). This required moving a few files from the `torch_mlir` Python module into multiple modules to isolate the code that depends on our C++ extensions (which now live in `torch_mlir` and `torch_mlir_torchscript_e2e_test_configs`) from the pure Python code (which now lives in `torch_mlir_torchscript`). This is an entirely mechanical change, and lots of imports needed to be updated. The dependency graph is: ``` torch_mlir_torchscript_e2e_test_configs / | / | / | V V torch_mlir_torchscript torch_mlir ``` The `torch_mlir_torchscript_e2e_test_configs` are then dependency-injected into the `torch_mlir_torchscript` modules to successfully assemble a working test harness (the code was already structured this way, but this new file organization allows the isolation from C++ code to actually happen). This isolation is critical to allowing the serialized programs to be transported across PyTorch versions and for the test harness to be used seamlessly to generate the heavydep tests. Also: - Extend `_Tracer` class to support nested property (submodule) accesses. Recommended review order: - "user-level" docs in README.md - code in `build_tools/torchscript_e2e_heavydep_tests`. - changes in `torch_mlir_torchscript/e2e_test/framework.py` - misc mechanical changes.pull/268/head
parent
f168cacd6d
commit
453e29ea05
44
README.md
44
README.md
|
@ -129,11 +129,6 @@ export LDFLAGS=-fuse-ld=$(which ld.lld-$LLVM_VERSION)
|
||||||
### Vanilla - numpy-only, no pytorch
|
### Vanilla - numpy-only, no pytorch
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
# Install PyTorch. We currently track and require the nighly build.
|
|
||||||
# If a usable PyTorch package is installed, the default cmake settings will
|
|
||||||
# enable the PyTorch frontend.
|
|
||||||
pip3 install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
|
||||||
|
|
||||||
# Configure npcomp.
|
# Configure npcomp.
|
||||||
cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release .
|
cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release .
|
||||||
|
|
||||||
|
@ -147,6 +142,11 @@ ninja check-npcomp
|
||||||
### With PyTorch integration
|
### With PyTorch integration
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
# Install PyTorch. We currently track and require the nighly build.
|
||||||
|
# If a usable PyTorch package is installed, the default cmake settings will
|
||||||
|
# enable the PyTorch frontend.
|
||||||
|
pip3 install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||||
|
|
||||||
cmake -DNPCOMP_ENABLE_PYTORCH=ON ...
|
cmake -DNPCOMP_ENABLE_PYTORCH=ON ...
|
||||||
ninja check-frontends-pytorch # If building with PyTorch
|
ninja check-frontends-pytorch # If building with PyTorch
|
||||||
```
|
```
|
||||||
|
@ -210,6 +210,40 @@ echo 'PYTHONPATH="${PYTHONPATH}:/path/to/iree-build/bindings/python"' >> .env
|
||||||
tools/torchscript_e2e_test.sh --config=iree
|
tools/torchscript_e2e_test.sh --config=iree
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Additional end-to-end tests with heavy dependencies (heavydep tests)
|
||||||
|
|
||||||
|
Some end-to-end tests require additional dependencies which don't make sense to
|
||||||
|
include as part of the default npcomp setup. Additionally, these dependencies
|
||||||
|
often don't work with the same HEAD PyTorch version that npcomp builds against
|
||||||
|
at the C++ level.
|
||||||
|
|
||||||
|
We have a self-contained script that generates all the needed artifacts from a
|
||||||
|
self-contained virtual environment. It can be used like so:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# Build the virtual environment in the specified directory and generate the
|
||||||
|
# serialized test artifacts in the other specified directory.
|
||||||
|
# This command is safe to re-run if you have already built the virtual
|
||||||
|
# environment and just changed the tests.
|
||||||
|
build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh \
|
||||||
|
path/to/heavydep_venv \
|
||||||
|
path/to/heavydep_serialized_tests
|
||||||
|
|
||||||
|
# Add the --serialized-test-dir flag to point at the directory containing the
|
||||||
|
# serialized tests. All other functionality is the same as the normal invocation
|
||||||
|
# of torchscript_e2e_test.sh, but the serialized tests will be available.
|
||||||
|
tools/torchscript_e2e_test.sh --serialized-test-dir=/t/heavydep_serialized_tests
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that the heavy dep tests are generally quite challenging, and we don't have
|
||||||
|
any that work yet. The tests use the same (pure-Python) test framework as the
|
||||||
|
normal torchscript_e2e_test.sh, but the tests are added in
|
||||||
|
`build_tools/torchscript_e2e_heavydep_tests` instead of
|
||||||
|
`frontends/pytorch/e2e_testing/torchscript`.
|
||||||
|
|
||||||
|
We rely critically on serialized TorchScript compatibility across PyTorch
|
||||||
|
versions to transport the tests + pure-Python compatibility of the `torch`
|
||||||
|
API, which has worked well so far.
|
||||||
|
|
||||||
### VSCode with a Docker Dev Image
|
### VSCode with a Docker Dev Image
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,120 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
# Basic machine translation (MT) program.
|
||||||
|
#
|
||||||
|
# This is an ID's to ID's end-to-end program
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
import os
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from fairseq.data.dictionary import Dictionary
|
||||||
|
from fairseq.models.fairseq_encoder import EncoderOut
|
||||||
|
from fairseq.modules.multihead_attention import MultiheadAttention
|
||||||
|
from fairseq.models.transformer import TransformerModel
|
||||||
|
from fairseq.models import (
|
||||||
|
FairseqEncoder,
|
||||||
|
FairseqEncoderDecoderModel,
|
||||||
|
FairseqIncrementalDecoder,
|
||||||
|
)
|
||||||
|
from fairseq.sequence_generator import SequenceGenerator
|
||||||
|
from fairseq.tasks.fairseq_task import LegacyFairseqTask
|
||||||
|
from fairseq import utils
|
||||||
|
|
||||||
|
from torch_mlir_torchscript.e2e_test.framework import TestUtils
|
||||||
|
from torch_mlir_torchscript.e2e_test.registry import register_test_case
|
||||||
|
from torch_mlir_torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
|
DEFAULT_TEST_VOCAB_SIZE = 100
|
||||||
|
|
||||||
|
|
||||||
|
class DummyTask(LegacyFairseqTask):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__(args)
|
||||||
|
self.dictionary = _get_dummy_dictionary()
|
||||||
|
if getattr(self.args, "ctc", False):
|
||||||
|
self.dictionary.add_symbol("<ctc_blank>")
|
||||||
|
self.src_dict = self.dictionary
|
||||||
|
self.tgt_dict = self.dictionary
|
||||||
|
|
||||||
|
@property
|
||||||
|
def source_dictionary(self):
|
||||||
|
return self.src_dict
|
||||||
|
|
||||||
|
@property
|
||||||
|
def target_dictionary(self):
|
||||||
|
return self.dictionary
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
|
||||||
|
dummy_dict = Dictionary()
|
||||||
|
for id, _ in enumerate(range(vocab_size)):
|
||||||
|
dummy_dict.add_symbol("{}".format(id), n=1000)
|
||||||
|
return dummy_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dummy_task_and_parser():
|
||||||
|
parser = argparse.ArgumentParser(description="test_dummy_s2s_task",
|
||||||
|
argument_default=argparse.SUPPRESS)
|
||||||
|
DummyTask.add_args(parser)
|
||||||
|
args = parser.parse_args([])
|
||||||
|
task = DummyTask.setup_task(args)
|
||||||
|
return task, parser
|
||||||
|
|
||||||
|
|
||||||
|
class BasicMtModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
task, parser = _get_dummy_task_and_parser()
|
||||||
|
TransformerModel.add_args(parser)
|
||||||
|
args = parser.parse_args([])
|
||||||
|
args.encoder_layers = 2
|
||||||
|
args.decoder_layers = 1
|
||||||
|
transformer_model = TransformerModel.build_model(args, task)
|
||||||
|
self.sequence_generator = SequenceGenerator(
|
||||||
|
[transformer_model],
|
||||||
|
task.tgt_dict,
|
||||||
|
beam_size=2,
|
||||||
|
no_repeat_ngram_size=2,
|
||||||
|
max_len_b=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Support list/dict returns from functions.
|
||||||
|
# This will allow us to handle a variable number of sentences.
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([2, -1], torch.long, True),
|
||||||
|
([2], torch.long, True),
|
||||||
|
])
|
||||||
|
def forward(self, src_tokens, src_lengths):
|
||||||
|
sample = {
|
||||||
|
'net_input': {
|
||||||
|
'src_tokens': src_tokens,
|
||||||
|
'src_lengths': src_lengths
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result = self.sequence_generator(sample)
|
||||||
|
return result[0][0]["tokens"], result[1][0]["tokens"]
|
||||||
|
|
||||||
|
EOS = BasicMtModule().sequence_generator.eos
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: BasicMtModule())
|
||||||
|
def BasicMtModule_basic(module, tu: TestUtils):
|
||||||
|
# Imagine random sentences from the vocabulary. Use a subset of the
|
||||||
|
# vocabulary that doesn't overlap with the EOS (which is usually the number
|
||||||
|
# 2).
|
||||||
|
MAX_SENTENCE_LENGTH = 10
|
||||||
|
src_tokens = torch.randint(DEFAULT_TEST_VOCAB_SIZE // 4,
|
||||||
|
DEFAULT_TEST_VOCAB_SIZE // 3,
|
||||||
|
(2, MAX_SENTENCE_LENGTH)).long()
|
||||||
|
# Append end-of-sentence symbol to the end of each sentence.
|
||||||
|
src_tokens = torch.cat((src_tokens, torch.LongTensor([[EOS], [EOS]])), -1)
|
||||||
|
src_lengths = torch.LongTensor([7, 10])
|
||||||
|
module.forward(src_tokens, src_lengths)
|
|
@ -0,0 +1,28 @@
|
||||||
|
#!/bin/bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Check that only two arugments are passed
|
||||||
|
if [ "$#" -ne 2 ]; then
|
||||||
|
echo "Usage: $0 <venv_dir> <serialized_test_dir>"
|
||||||
|
echo 'Description:
|
||||||
|
venv_dir: directory to put the Python venv used for generating the serialized tests
|
||||||
|
serialized_test_dir: directory to write the generated serialized tests to
|
||||||
|
'
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
venv_dir=$1
|
||||||
|
serialized_test_dir=$2
|
||||||
|
here="$(realpath $(dirname $0))"
|
||||||
|
npcomp_src_root="$here/../../"
|
||||||
|
|
||||||
|
mkdir -p $venv_dir
|
||||||
|
mkdir -p $serialized_test_dir
|
||||||
|
python3 -m venv $venv_dir
|
||||||
|
source $venv_dir/bin/activate
|
||||||
|
python3 -m pip install fairseq fvcore sacremoses subword-nmt
|
||||||
|
|
||||||
|
cd "$npcomp_src_root"
|
||||||
|
export PYTHONPATH=${PYTHONPATH-}
|
||||||
|
source "$npcomp_src_root/.env"
|
||||||
|
python3 -m build_tools.torchscript_e2e_heavydep_tests.main --output_dir=$serialized_test_dir
|
|
@ -0,0 +1,46 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir_torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
|
||||||
|
from torch_mlir_torchscript.e2e_test.framework import SerializableTest, generate_golden_trace
|
||||||
|
from torch_mlir_torchscript.annotations import extract_serializable_annotations
|
||||||
|
|
||||||
|
from . import basic_mt
|
||||||
|
|
||||||
|
|
||||||
|
def _get_argparse():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Generate assets for TorchScript E2E tests")
|
||||||
|
parser.add_argument("--output_dir", help="The directory to put assets in.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = _get_argparse().parse_args()
|
||||||
|
serializable_tests = []
|
||||||
|
for test in GLOBAL_TEST_REGISTRY:
|
||||||
|
trace = generate_golden_trace(test)
|
||||||
|
module = torch.jit.script(test.program_factory())
|
||||||
|
torchscript_module_bytes = module.save_to_buffer({
|
||||||
|
"annotations.pkl":
|
||||||
|
pickle.dumps(extract_serializable_annotations(module))
|
||||||
|
})
|
||||||
|
serializable_tests.append(
|
||||||
|
SerializableTest(unique_name=test.unique_name,
|
||||||
|
program=torchscript_module_bytes,
|
||||||
|
trace=trace))
|
||||||
|
for test in serializable_tests:
|
||||||
|
with open(os.path.join(args.output_dir, f"{test.unique_name}.pkl"),
|
||||||
|
"wb") as f:
|
||||||
|
pickle.dump(test, f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -4,9 +4,9 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import TestUtils
|
from torch_mlir_torchscript.e2e_test.framework import TestUtils
|
||||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case
|
from torch_mlir_torchscript.e2e_test.registry import register_test_case
|
||||||
from torch_mlir.torchscript.annotations import annotate_args, export
|
from torch_mlir_torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
|
@ -4,9 +4,9 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import TestUtils
|
from torch_mlir_torchscript.e2e_test.framework import TestUtils
|
||||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case
|
from torch_mlir_torchscript.e2e_test.registry import register_test_case
|
||||||
from torch_mlir.torchscript.annotations import annotate_args, export
|
from torch_mlir_torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
class BatchNorm1DModule(torch.nn.Module):
|
class BatchNorm1DModule(torch.nn.Module):
|
||||||
|
|
|
@ -3,9 +3,9 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch_mlir.torchscript.e2e_test.framework import TestUtils
|
from torch_mlir_torchscript.e2e_test.framework import TestUtils
|
||||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case
|
from torch_mlir_torchscript.e2e_test.registry import register_test_case
|
||||||
from torch_mlir.torchscript.annotations import annotate_args, export
|
from torch_mlir_torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
|
@ -4,9 +4,9 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import TestUtils
|
from torch_mlir_torchscript.e2e_test.framework import TestUtils
|
||||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case
|
from torch_mlir_torchscript.e2e_test.registry import register_test_case
|
||||||
from torch_mlir.torchscript.annotations import annotate_args, export
|
from torch_mlir_torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
# TODO: Support scalar !torch.int/!torch.float variants. Add support to
|
# TODO: Support scalar !torch.int/!torch.float variants. Add support to
|
||||||
# ReduceOpVariants to implement them in terms of the tensor-only variants +
|
# ReduceOpVariants to implement them in terms of the tensor-only variants +
|
||||||
|
|
|
@ -3,15 +3,17 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import run_tests
|
from torch_mlir_torchscript.e2e_test.framework import run_tests
|
||||||
from torch_mlir.torchscript.e2e_test.reporting import report_results
|
from torch_mlir_torchscript.e2e_test.reporting import report_results
|
||||||
from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
|
from torch_mlir_torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
|
||||||
|
|
||||||
# Available test configs.
|
# Available test configs.
|
||||||
from torch_mlir.torchscript.e2e_test.configs import (
|
from torch_mlir_torchscript_e2e_test_configs import (
|
||||||
NpcompBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
|
NpcompBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -56,6 +58,14 @@ Regular expression specifying which tests to include in this run.
|
||||||
default=False,
|
default=False,
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='report test results with additional detail')
|
help='report test results with additional detail')
|
||||||
|
parser.add_argument('--serialized-test-dir', default=None, type=str, help='''
|
||||||
|
The directory containing serialized pre-built tests.
|
||||||
|
Right now, these are additional tests which require heavy Python dependencies
|
||||||
|
to generate (or cannot even be generated with the version of PyTorch used by
|
||||||
|
npcomp).
|
||||||
|
See `build_tools/torchscript_e2e_heavydep_tests/generate_serialized_tests.sh`
|
||||||
|
for more information on building these artifacts.
|
||||||
|
''')
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -71,9 +81,16 @@ def main():
|
||||||
elif args.config == 'torchscript':
|
elif args.config == 'torchscript':
|
||||||
config = TorchScriptTestConfig()
|
config = TorchScriptTestConfig()
|
||||||
|
|
||||||
|
all_tests = list(GLOBAL_TEST_REGISTRY)
|
||||||
|
if args.serialized_test_dir:
|
||||||
|
for root, dirs, files in os.walk(args.serialized_test_dir):
|
||||||
|
for filename in files:
|
||||||
|
with open(os.path.join(root, filename), 'rb') as f:
|
||||||
|
all_tests.append(pickle.load(f).as_test())
|
||||||
|
|
||||||
# Find the selected tests, and emit a diagnostic if none are found.
|
# Find the selected tests, and emit a diagnostic if none are found.
|
||||||
tests = [
|
tests = [
|
||||||
test for test in GLOBAL_TEST_REGISTRY
|
test for test in all_tests
|
||||||
if re.match(args.filter, test.unique_name)
|
if re.match(args.filter, test.unique_name)
|
||||||
]
|
]
|
||||||
if len(tests) == 0:
|
if len(tests) == 0:
|
||||||
|
@ -81,7 +98,7 @@ def main():
|
||||||
f'ERROR: the provided filter {args.filter!r} does not match any tests'
|
f'ERROR: the provided filter {args.filter!r} does not match any tests'
|
||||||
)
|
)
|
||||||
print('The available tests are:')
|
print('The available tests are:')
|
||||||
for test in GLOBAL_TEST_REGISTRY:
|
for test in all_tests:
|
||||||
print(test.unique_name)
|
print(test.unique_name)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
|
@ -5,9 +5,9 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import TestUtils
|
from torch_mlir_torchscript.e2e_test.framework import TestUtils
|
||||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case
|
from torch_mlir_torchscript.e2e_test.registry import register_test_case
|
||||||
from torch_mlir.torchscript.annotations import annotate_args, export
|
from torch_mlir_torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
|
@ -5,9 +5,9 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import TestUtils
|
from torch_mlir_torchscript.e2e_test.framework import TestUtils
|
||||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case
|
from torch_mlir_torchscript.e2e_test.registry import register_test_case
|
||||||
from torch_mlir.torchscript.annotations import annotate_args, export
|
from torch_mlir_torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
|
@ -5,9 +5,9 @@
|
||||||
import torch
|
import torch
|
||||||
import torchvision.models as models
|
import torchvision.models as models
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import TestUtils
|
from torch_mlir_torchscript.e2e_test.framework import TestUtils
|
||||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case
|
from torch_mlir_torchscript.e2e_test.registry import register_test_case
|
||||||
from torch_mlir.torchscript.annotations import annotate_args, export
|
from torch_mlir_torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
|
@ -16,58 +16,18 @@ import torch_mlir
|
||||||
# to be expressed conveniently and gives clearer error reports when
|
# to be expressed conveniently and gives clearer error reports when
|
||||||
# the annotations aren't acceptable.
|
# the annotations aren't acceptable.
|
||||||
|
|
||||||
|
# This module is kept separate from torch_mlir_torchscript.annotations so that
|
||||||
def export(fn):
|
# we can use that module from code without C++ dependencies, which prevent us
|
||||||
"""Decorator that tells the npcomp compiler that a method is exported.
|
# from interfacing the test framework across environments.
|
||||||
|
|
||||||
By default, no methods are exported, which is very important for
|
|
||||||
the compiler, because otherwise most Torch programs consist of a sea
|
|
||||||
of tiny exported functions with no rank or dtype information
|
|
||||||
(see `annotate_args`), which the compiler cannot do much with.
|
|
||||||
|
|
||||||
Note that this is different from `torch.jit.export`, which controls
|
|
||||||
which methods are scripted in the first place. For non-`forward` methods,
|
|
||||||
using this decorator usually means you also need `torch.jit.export`.
|
|
||||||
Conceptually, this decorator is annotating the scripted module, but is
|
|
||||||
applied to the original `torch.nn.Module` for convenience.
|
|
||||||
"""
|
|
||||||
fn._npcomp_export = True
|
|
||||||
return fn
|
|
||||||
|
|
||||||
|
|
||||||
ArgAnnotation = Tuple[List[int], torch.dtype]
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Replace with py3 extended argument annotations when available.
|
|
||||||
# See https://www.python.org/dev/peps/pep-0593/
|
|
||||||
def annotate_args(annotations: List[Optional[ArgAnnotation]]):
|
|
||||||
"""Decorator that tells the npcomp compiler information about arguments.
|
|
||||||
|
|
||||||
The `annotations` should be a list of the same length as the number of
|
|
||||||
argument to the method (including `self`). Each list entry is either:
|
|
||||||
- None, corresponding to providing the compiler with no information.
|
|
||||||
- A 2-tuple consisting of a shape and a dtype, such as
|
|
||||||
`([2, 3, 4], torch.float32)`. A dimension with an unknown size can be
|
|
||||||
indicated by using `-1` as the size. This provides the compiler a
|
|
||||||
guarantee that the argument will always dynamically have the described
|
|
||||||
shape and dtype.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# TODO: Check the number of arguments matches the number of arg annotations.
|
|
||||||
def decorator(fn):
|
|
||||||
fn._npcomp_arg_annotations = annotations
|
|
||||||
return fn
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
# Utilities for extracting decorated information into torch_mlir.ClassAnnotator.
|
# Utilities for extracting decorated information into torch_mlir.ClassAnnotator.
|
||||||
|
|
||||||
|
|
||||||
def _recursively_extract_annotations(
|
def _recursively_extract_annotations(
|
||||||
module: torch.nn.Module, scripted: torch.jit.ScriptModule,
|
module: torch.nn.Module, scripted: torch.jit.ScriptModule,
|
||||||
class_annotator: torch_mlir.ClassAnnotator):
|
class_annotator: torch_mlir.ClassAnnotator):
|
||||||
assert module.__class__.__name__ == scripted.original_name, "script module does not come from specified module"
|
assert module.__class__.__name__ == scripted.original_name or (
|
||||||
|
isinstance(module, torch.jit.RecursiveScriptModule) and module is
|
||||||
|
scripted), "script module does not come from specified module"
|
||||||
|
|
||||||
# Extract information on methods.
|
# Extract information on methods.
|
||||||
for method_name, scripted_method in scripted.__dict__.items():
|
for method_name, scripted_method in scripted.__dict__.items():
|
|
@ -0,0 +1,128 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple, NamedTuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Decorators
|
||||||
|
|
||||||
|
# Currently, these decorators are very low-level and map 1:1 with
|
||||||
|
# methods on `torch_mlir.ClassAnnotator`. Eventually, we expect there to
|
||||||
|
# be a more elaborate Python layer which allows all the different annotations
|
||||||
|
# to be expressed conveniently and gives clearer error reports when
|
||||||
|
# the annotations aren't acceptable.
|
||||||
|
|
||||||
|
# This module is kept separate from torch_mlir.torchscript_annotations so that
|
||||||
|
# we can use this from code without C++ dependencies, which prevent us from
|
||||||
|
# interfacing the test framework across environments.
|
||||||
|
|
||||||
|
# Attribute names used for annotations.
|
||||||
|
# These should be kept in sync with their use in
|
||||||
|
# `frontends/pytorch/python/torch_mlir/torchscript_annotations.py`.
|
||||||
|
NPCOMP_EXPORT_ATTR_NAME = '_npcomp_export'
|
||||||
|
NPCOMP_ARG_ANNOTATIONS_ATTR_NAME = '_npcomp_arg_annotations'
|
||||||
|
|
||||||
|
|
||||||
|
def export(fn):
|
||||||
|
"""Decorator that tells the npcomp compiler that a method is exported.
|
||||||
|
|
||||||
|
By default, no methods are exported, which is very important for
|
||||||
|
the compiler, because otherwise most Torch programs consist of a sea
|
||||||
|
of tiny exported functions with no rank or dtype information
|
||||||
|
(see `annotate_args`), which the compiler cannot do much with.
|
||||||
|
|
||||||
|
Note that this is different from `torch.jit.export`, which controls
|
||||||
|
which methods are scripted in the first place. For non-`forward` methods,
|
||||||
|
using this decorator usually means you also need `torch.jit.export`.
|
||||||
|
Conceptually, this decorator is annotating the scripted module, but is
|
||||||
|
applied to the original `torch.nn.Module` for convenience.
|
||||||
|
"""
|
||||||
|
setattr(fn, NPCOMP_EXPORT_ATTR_NAME, True)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
ArgAnnotation = Tuple[List[int], torch.dtype]
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Replace with py3 extended argument annotations when available.
|
||||||
|
# See https://www.python.org/dev/peps/pep-0593/
|
||||||
|
def annotate_args(annotations: List[Optional[ArgAnnotation]]):
|
||||||
|
"""Decorator that tells the npcomp compiler information about arguments.
|
||||||
|
|
||||||
|
The `annotations` should be a list of the same length as the number of
|
||||||
|
argument to the method (including `self`). Each list entry is either:
|
||||||
|
- None, corresponding to providing the compiler with no information.
|
||||||
|
- A 2-tuple consisting of a shape and a dtype, such as
|
||||||
|
`([2, 3, 4], torch.float32)`. A dimension with an unknown size can be
|
||||||
|
indicated by using `-1` as the size. This provides the compiler a
|
||||||
|
guarantee that the argument will always dynamically have the described
|
||||||
|
shape and dtype.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: Check the number of arguments matches the number of arg annotations.
|
||||||
|
def decorator(fn):
|
||||||
|
setattr(fn, NPCOMP_ARG_ANNOTATIONS_ATTR_NAME, annotations)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class SerializableMethodAnnotation(NamedTuple):
|
||||||
|
method_name: str
|
||||||
|
export: Optional[bool]
|
||||||
|
arg_annotations: Optional[List[ArgAnnotation]]
|
||||||
|
|
||||||
|
|
||||||
|
class SerializableModuleAnnotations(NamedTuple):
|
||||||
|
method_annotations: List[SerializableMethodAnnotation]
|
||||||
|
submodule_annotations: List[Tuple[str, "SerializableModuleAnnotations"]]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_serializable_annotations(
|
||||||
|
module: torch.nn.Module) -> SerializableModuleAnnotations:
|
||||||
|
module_annotations = SerializableModuleAnnotations(
|
||||||
|
method_annotations=[], submodule_annotations=[])
|
||||||
|
# Extract information on methods.
|
||||||
|
for method_name, method in module.__dict__.items():
|
||||||
|
# See if it is a method.
|
||||||
|
if not callable(method):
|
||||||
|
continue
|
||||||
|
export = None
|
||||||
|
arg_annotations = None
|
||||||
|
if hasattr(method, NPCOMP_EXPORT_ATTR_NAME):
|
||||||
|
export = method._npcomp_export
|
||||||
|
if hasattr(method, NPCOMP_ARG_ANNOTATIONS_ATTR_NAME):
|
||||||
|
arg_annotations = method._npcomp_arg_annotations
|
||||||
|
if export is not None and arg_annotations is not None:
|
||||||
|
module_annotations.method_annotations.append(
|
||||||
|
SerializableMethodAnnotation(method_name=method_name,
|
||||||
|
export=export,
|
||||||
|
arg_annotations=arg_annotations))
|
||||||
|
|
||||||
|
# Recurse.
|
||||||
|
for name, child in module.named_children():
|
||||||
|
annotations = extract_serializable_annotations(child)
|
||||||
|
module_annotations.submodule_annotations.append((name, annotations))
|
||||||
|
return module_annotations
|
||||||
|
|
||||||
|
|
||||||
|
def apply_serializable_annotations(module: torch.nn.Module,
|
||||||
|
annotations: SerializableModuleAnnotations):
|
||||||
|
# Apply annotations to methods.
|
||||||
|
for method_annotation in annotations.method_annotations:
|
||||||
|
# Imitate use of the decorators to keep a source of truth there.
|
||||||
|
if method_annotation.export is not None:
|
||||||
|
setattr(module, method_annotation.method_name,
|
||||||
|
export(getattr(module, method_annotation.method_name)))
|
||||||
|
if method_annotation.arg_annotations is not None:
|
||||||
|
setattr(
|
||||||
|
module, method_annotation.method_name,
|
||||||
|
annotate_args(method_annotation.arg_annotations)(getattr(
|
||||||
|
module, method_annotation.method_name)))
|
||||||
|
|
||||||
|
# Recurse.
|
||||||
|
for name, submodule_annotations in annotations.submodule_annotations:
|
||||||
|
child = getattr(module, name)
|
||||||
|
apply_serializable_annotations(child, submodule_annotations)
|
|
@ -22,8 +22,13 @@ compiling or TorchScript'ing).
|
||||||
import abc
|
import abc
|
||||||
from typing import Any, Callable, List, NamedTuple, Optional, TypeVar
|
from typing import Any, Callable, List, NamedTuple, Optional, TypeVar
|
||||||
|
|
||||||
|
import io
|
||||||
|
import pickle
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from ..annotations import apply_serializable_annotations
|
||||||
|
|
||||||
|
|
||||||
class TraceItem(NamedTuple):
|
class TraceItem(NamedTuple):
|
||||||
# The externally visible symbol name that is called.
|
# The externally visible symbol name that is called.
|
||||||
|
@ -162,6 +167,57 @@ class Test(NamedTuple):
|
||||||
program_invoker: Callable[[Any, TestUtils], None]
|
program_invoker: Callable[[Any, TestUtils], None]
|
||||||
|
|
||||||
|
|
||||||
|
class SerializableTest(NamedTuple):
|
||||||
|
"""A self-contained representation of a test that can be pickled.
|
||||||
|
|
||||||
|
We use serialized TorchScript programs here for two reasons:
|
||||||
|
1. The PyTorch pickling story isn't great, so in order to reliably pickle
|
||||||
|
this class, we rely on having the serialized bytes for the TorchScript
|
||||||
|
module already given to us.
|
||||||
|
2. The choice of a TorchScript module vs `torch.nn.Module` boils down to
|
||||||
|
the fact that `torch.nn.Module` cannot be deserialized without pulling
|
||||||
|
in the same set of Python dependencies that were used to serialize it
|
||||||
|
in the first place. This would defeat one of the
|
||||||
|
main use cases of this class, which is to transport a test from an
|
||||||
|
environment with a set of heavy dependencies to a dependency-light one.
|
||||||
|
Since TorchScript modules are self-contained, they fit the bill
|
||||||
|
perfectly.
|
||||||
|
"""
|
||||||
|
# See unique_name on `Test`.
|
||||||
|
unique_name: str
|
||||||
|
# Serialized TorchScript program.
|
||||||
|
program: bytes
|
||||||
|
# Trace for execution testing.
|
||||||
|
trace: Trace
|
||||||
|
|
||||||
|
def as_test(self) -> Test:
|
||||||
|
"""Create a `Test` from this class."""
|
||||||
|
# Conform the serialized program to the interface expected by Test.
|
||||||
|
# This is a bit of a hack, but it's the only way to keep the layering
|
||||||
|
# straight.
|
||||||
|
def factory():
|
||||||
|
_extra_files = {"annotations.pkl": ""}
|
||||||
|
module = torch.jit.load(io.BytesIO(self.program),
|
||||||
|
_extra_files=_extra_files)
|
||||||
|
# Load the pickled annotations.
|
||||||
|
annotations = pickle.loads(_extra_files["annotations.pkl"])
|
||||||
|
apply_serializable_annotations(module, annotations)
|
||||||
|
return module
|
||||||
|
|
||||||
|
def invoker(module, tu):
|
||||||
|
for item in self.trace:
|
||||||
|
attr = module
|
||||||
|
for part in item.symbol.split("."):
|
||||||
|
attr = getattr(attr, part)
|
||||||
|
attr(*item.inputs)
|
||||||
|
|
||||||
|
return Test(
|
||||||
|
unique_name=self.unique_name,
|
||||||
|
program_factory=factory,
|
||||||
|
program_invoker=invoker,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestResult(NamedTuple):
|
class TestResult(NamedTuple):
|
||||||
# Stable unique name for error reporting and test suite configuration.
|
# Stable unique name for error reporting and test suite configuration.
|
||||||
#
|
#
|
||||||
|
@ -188,39 +244,44 @@ class TestResult(NamedTuple):
|
||||||
class _Tracer:
|
class _Tracer:
|
||||||
"""Wrapper around a `torch.nn.Module` that records calls into it.
|
"""Wrapper around a `torch.nn.Module` that records calls into it.
|
||||||
|
|
||||||
The inputs and outputs of each call are recorded in a Trace.
|
The inputs and outputs of each call are recorded in a Trace. Recursive
|
||||||
|
property accesses are also traced.
|
||||||
"""
|
"""
|
||||||
module: torch.nn.Module
|
def __init__(self, wrapped, property_base_path: List[str], trace: Trace):
|
||||||
trace: Trace
|
self.__wrapped__ = wrapped
|
||||||
|
self.__trace__ = trace
|
||||||
|
self.__property_base_path__ = property_base_path
|
||||||
|
|
||||||
def __init__(self, module: torch.nn.Module):
|
def __call__(self, *args, **kwargs):
|
||||||
self.module = module
|
raw_outputs = self.__wrapped__(*args, **kwargs)
|
||||||
self.trace = []
|
if isinstance(raw_outputs, torch.Tensor):
|
||||||
|
outputs = [raw_outputs]
|
||||||
|
elif isinstance(raw_outputs, tuple) and all(
|
||||||
|
isinstance(o, torch.Tensor) for o in raw_outputs):
|
||||||
|
outputs = raw_outputs
|
||||||
|
else:
|
||||||
|
raise Exception("unimplemented: non-Tensor output from function")
|
||||||
|
self.__trace__.append(
|
||||||
|
TraceItem(symbol=".".join(self.__property_base_path__),
|
||||||
|
inputs=args,
|
||||||
|
outputs=outputs))
|
||||||
|
return raw_outputs
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
# TODO: Handle `module.foo.bar.baz` nesting.
|
return _Tracer(getattr(self.__wrapped__, name),
|
||||||
# For now, we are limited to attributes of the top-level module.
|
self.__property_base_path__ + [name], self.__trace__)
|
||||||
def invoke(*args):
|
|
||||||
raw_outputs = getattr(self.module, name)(*args)
|
|
||||||
if isinstance(raw_outputs, torch.Tensor):
|
|
||||||
outputs = [raw_outputs]
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
"unimplemented: non-Tensor output from function")
|
|
||||||
self.trace.append(
|
|
||||||
TraceItem(symbol=name, inputs=args, outputs=outputs))
|
|
||||||
return raw_outputs
|
|
||||||
|
|
||||||
return invoke
|
|
||||||
|
|
||||||
def get_trace(self):
|
|
||||||
return self.trace
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_golden_trace(test: Test) -> Trace:
|
def generate_golden_trace(test: Test) -> Trace:
|
||||||
tracer = _Tracer(test.program_factory())
|
"""Generate a trace with the original program.
|
||||||
|
|
||||||
|
If the original program is deterministic, then this the produced trace is
|
||||||
|
suitable as a golden trace to compare against.
|
||||||
|
"""
|
||||||
|
trace = []
|
||||||
|
tracer = _Tracer(test.program_factory(), [], trace)
|
||||||
test.program_invoker(tracer, TestUtils())
|
test.program_invoker(tracer, TestUtils())
|
||||||
return tracer.get_trace()
|
return trace
|
||||||
|
|
||||||
|
|
||||||
def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
|
def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
|
||||||
|
@ -229,9 +290,16 @@ def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]:
|
||||||
for test in tests:
|
for test in tests:
|
||||||
# TODO: Precompile everything in parallel.
|
# TODO: Precompile everything in parallel.
|
||||||
try:
|
try:
|
||||||
golden_trace = _generate_golden_trace(test)
|
golden_trace = generate_golden_trace(test)
|
||||||
compiled = config.compile(test.program_factory())
|
compiled = config.compile(test.program_factory())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# Useful for debugging:
|
||||||
|
# ```
|
||||||
|
# raise
|
||||||
|
# ```
|
||||||
|
# This will give the full traceback rather than giving just
|
||||||
|
# the stringified exception in the report.
|
||||||
|
# TODO: Capture the traceback and make it available in the report.
|
||||||
results.append(
|
results.append(
|
||||||
TestResult(unique_name=test.unique_name,
|
TestResult(unique_name=test.unique_name,
|
||||||
compilation_error=str(e),
|
compilation_error=str(e),
|
|
@ -161,7 +161,7 @@ class SingleTestReport:
|
||||||
f = io.StringIO()
|
f = io.StringIO()
|
||||||
p = lambda *x: print(*x, file=f)
|
p = lambda *x: print(*x, file=f)
|
||||||
if self.result.compilation_error is not None:
|
if self.result.compilation_error is not None:
|
||||||
return 'compilation error' + self.result.compilation_error
|
return 'Compilation error: ' + self.result.compilation_error
|
||||||
for report in self.item_reports:
|
for report in self.item_reports:
|
||||||
if report.failed:
|
if report.failed:
|
||||||
p(report.error_str())
|
p(report.error_str())
|
|
@ -7,7 +7,7 @@ from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
from torch_mlir_torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
|
|
||||||
|
|
||||||
class NativeTorchTestConfig(TestConfig):
|
class NativeTorchTestConfig(TestConfig):
|
|
@ -15,8 +15,8 @@ from mlir.passmanager import PassManager
|
||||||
import torch_mlir
|
import torch_mlir
|
||||||
from npcomp.compiler.pytorch.backend import refjit
|
from npcomp.compiler.pytorch.backend import refjit
|
||||||
from npcomp.compiler.pytorch.backend.abc import NpcompBackend
|
from npcomp.compiler.pytorch.backend.abc import NpcompBackend
|
||||||
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
from torch_mlir_torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
from torch_mlir.torchscript.annotations import extract_annotations
|
from torch_mlir.torchscript_annotations import extract_annotations
|
||||||
|
|
||||||
class PrettyErrorReportForIrOperation(object):
|
class PrettyErrorReportForIrOperation(object):
|
||||||
def __init__(self, module, module_name_for_ir_dump: str):
|
def __init__(self, module, module_name_for_ir_dump: str):
|
|
@ -7,7 +7,7 @@ from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
from torch_mlir_torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
|
|
||||||
|
|
||||||
class TorchScriptTestConfig(TestConfig):
|
class TorchScriptTestConfig(TestConfig):
|
||||||
|
@ -24,7 +24,10 @@ class TorchScriptTestConfig(TestConfig):
|
||||||
|
|
||||||
result: Trace = []
|
result: Trace = []
|
||||||
for item in trace:
|
for item in trace:
|
||||||
outputs = getattr(artifact, item.symbol)(*item.inputs)
|
attr = artifact
|
||||||
|
for part in item.symbol.split('.'):
|
||||||
|
attr = getattr(attr, part)
|
||||||
|
outputs = attr(*item.inputs)
|
||||||
if isinstance(outputs, torch.Tensor):
|
if isinstance(outputs, torch.Tensor):
|
||||||
outputs = [outputs]
|
outputs = [outputs]
|
||||||
result.append(
|
result.append(
|
|
@ -40,6 +40,11 @@ setup(
|
||||||
"": "./python",
|
"": "./python",
|
||||||
},
|
},
|
||||||
packages=find_packages("./python", include=[
|
packages=find_packages("./python", include=[
|
||||||
"torch_mlir", "torch_mlir.*",
|
"torch_mlir",
|
||||||
|
"torch_mlir.*",
|
||||||
|
"torch_mlir_torchscript",
|
||||||
|
"torch_mlir_torchscript.*",
|
||||||
|
"torch_mlir_torchscript_e2e_test_configs",
|
||||||
|
"torch_mlir_torchscript_e2e_test_configs.*",
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,9 +7,8 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch_mlir
|
import torch_mlir
|
||||||
from torch_mlir.torchscript.annotations import (
|
from torch_mlir_torchscript.annotations import annotate_args, export
|
||||||
annotate_args, export, extract_annotations
|
from torch_mlir.torchscript_annotations import extract_annotations
|
||||||
)
|
|
||||||
|
|
||||||
class MmModule(torch.nn.Module):
|
class MmModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -6,10 +6,10 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
|
from torch_mlir_torchscript.e2e_test.framework import run_tests, TestUtils
|
||||||
from torch_mlir.torchscript.e2e_test.reporting import report_results
|
from torch_mlir_torchscript.e2e_test.reporting import report_results
|
||||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
from torch_mlir_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||||
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
|
from torch_mlir_torchscript_e2e_test_configs import TorchScriptTestConfig
|
||||||
|
|
||||||
|
|
||||||
class MmModule(torch.nn.Module):
|
class MmModule(torch.nn.Module):
|
||||||
|
|
|
@ -6,10 +6,10 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
|
from torch_mlir_torchscript.e2e_test.framework import run_tests, TestUtils
|
||||||
from torch_mlir.torchscript.e2e_test.reporting import report_results
|
from torch_mlir_torchscript.e2e_test.reporting import report_results
|
||||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
from torch_mlir_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||||
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
|
from torch_mlir_torchscript_e2e_test_configs import TorchScriptTestConfig
|
||||||
|
|
||||||
|
|
||||||
class MmModule(torch.nn.Module):
|
class MmModule(torch.nn.Module):
|
||||||
|
@ -26,7 +26,7 @@ class MmModule(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
# CHECK: FAIL - "MmModule_basic"
|
# CHECK: FAIL - "MmModule_basic"
|
||||||
# CHECK: compilation error
|
# CHECK: Compilation error:
|
||||||
# Assume that the diagnostic from the TorchScript compiler will at least contain
|
# Assume that the diagnostic from the TorchScript compiler will at least contain
|
||||||
# the offending "return 3".
|
# the offending "return 3".
|
||||||
# CHECK: return 3
|
# CHECK: return 3
|
||||||
|
|
|
@ -6,10 +6,10 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch_mlir.torchscript.e2e_test.framework import run_tests, TestUtils
|
from torch_mlir_torchscript.e2e_test.framework import run_tests, TestUtils
|
||||||
from torch_mlir.torchscript.e2e_test.reporting import report_results
|
from torch_mlir_torchscript.e2e_test.reporting import report_results
|
||||||
from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
from torch_mlir_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||||
from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig
|
from torch_mlir_torchscript_e2e_test_configs import TorchScriptTestConfig
|
||||||
|
|
||||||
|
|
||||||
class MmModule(torch.nn.Module):
|
class MmModule(torch.nn.Module):
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir_torchscript.e2e_test.framework import run_tests, TestUtils
|
||||||
|
from torch_mlir_torchscript.e2e_test.reporting import report_results
|
||||||
|
from torch_mlir_torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY
|
||||||
|
from torch_mlir_torchscript_e2e_test_configs import TorchScriptTestConfig
|
||||||
|
|
||||||
|
class Submodule2(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
return torch.mm(lhs, rhs)
|
||||||
|
|
||||||
|
class Submodule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.m2 = Submodule2()
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleWithSubmodule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.m = Submodule()
|
||||||
|
|
||||||
|
|
||||||
|
# CHECK: PASS - "ModuleWithSubmodule_basic"
|
||||||
|
@register_test_case(module_factory=lambda: ModuleWithSubmodule())
|
||||||
|
def ModuleWithSubmodule_basic(module, tu: TestUtils):
|
||||||
|
module.m.m2.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = TorchScriptTestConfig()
|
||||||
|
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||||
|
report_results(results, set())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue