mirror of https://github.com/llvm/torch-mlir
Add IREE support in TorchScript e2e tests.
- Add support for "expected failures" in test reporting. The new error reports look like [this](https://gist.github.com/silvasean/6ffd95e1d55302b699673da201da210d). - We will now be able to put these tests into CI, since the harness understand which tests are expected to pass and fail. - Refactor RefBackendTestConfig to NpcompBackendTestConfig which supports both RefBackend and IREE. - Add instructions for installing IREE dependencies (both from packages and for local builds of IREE) - Add `tools/torchscript_e2e_test.sh` for invoking the e2e test harness (this makes invoking a bit easier, as it doesn't rely on a loose Python invocation).pull/243/head
parent
79928cd2dd
commit
d5108b9dc1
31
README.md
31
README.md
|
@ -164,6 +164,37 @@ cd /src/mlir-npcomp
|
||||||
cmake --build /build/npcomp --target check-npcomp check-frontends-pytorch
|
cmake --build /build/npcomp --target check-npcomp check-frontends-pytorch
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### IREE Backend (from IREE packages)
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# We currently track and require the latest snapshot.
|
||||||
|
pip3 install iree-compiler-snapshot iree-runtime-snapshot -f https://github.com/google/iree/releases
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Run TorchScript E2E tests targeting IREE.
|
||||||
|
# Make sure to run "PyTorch Frontend" setup instructions first.
|
||||||
|
python frontends/pytorch/e2e_testing/torchscript/main.py --config=iree
|
||||||
|
```
|
||||||
|
|
||||||
|
### IREE Backend (from local IREE build)
|
||||||
|
|
||||||
|
This configuration is useful for iterating locally, as you can
|
||||||
|
poke/debug/rebuild things in IREE.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# Locally build IREE.
|
||||||
|
# See https://google.github.io/iree/building-from-source/getting-started/
|
||||||
|
# Make sure IREE is configured with `-DIREE_BUILD_PYTHON_BINDINGS=ON`.
|
||||||
|
|
||||||
|
echo 'PYTHONPATH="${PYTHONPATH}:/path/to/iree-build/bindings/python"' >> .env
|
||||||
|
|
||||||
|
# Run TorchScript E2E tests targeting IREE.
|
||||||
|
# Make sure to run "PyTorch Frontend" setup instructions first.
|
||||||
|
python frontends/pytorch/e2e_testing/torchscript/main.py --config=iree
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### VSCode with a Docker Dev Image
|
### VSCode with a Docker Dev Image
|
||||||
|
|
||||||
#### Start a docker dev container based on our image
|
#### Start a docker dev container based on our image
|
||||||
|
|
|
@ -12,31 +12,39 @@ from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
|
||||||
|
|
||||||
# Available test configs.
|
# Available test configs.
|
||||||
from torch_mlir.torchscript.e2e_test.configs import (
|
from torch_mlir.torchscript.e2e_test.configs import (
|
||||||
RefBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
|
NpcompBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from npcomp.compiler.pytorch.backend import is_iree_enabled
|
||||||
|
IREE_ENABLED = is_iree_enabled()
|
||||||
|
if IREE_ENABLED:
|
||||||
|
from npcomp.compiler.pytorch.backend.iree import IreeNpcompBackend
|
||||||
|
from npcomp.compiler.pytorch.backend.refjit import RefjitNpcompBackend
|
||||||
|
|
||||||
|
from .xfail_sets import XFAIL_SETS
|
||||||
|
|
||||||
# Import tests to register them in the global registry.
|
# Import tests to register them in the global registry.
|
||||||
# TODO: Use a relative import.
|
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
|
||||||
# That requires invoking this file as a "package" though, which makes it
|
# this script.
|
||||||
# not possible to just do `python main.py`. Instead, it requires something
|
from . import basic
|
||||||
# like `python -m tochscript.main` which is annoying because it can only
|
from . import vision_models
|
||||||
# be run from a specific directory.
|
from . import mlp
|
||||||
# TODO: Find out best practices for python "main" files.
|
from . import batchnorm
|
||||||
import basic
|
from . import quantized_models
|
||||||
import vision_models
|
from . import elementwise
|
||||||
import mlp
|
|
||||||
import batchnorm
|
|
||||||
import quantized_models
|
|
||||||
import elementwise
|
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
|
config_choices = ['native_torch', 'torchscript', 'refbackend']
|
||||||
|
if IREE_ENABLED:
|
||||||
|
config_choices += ['iree']
|
||||||
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||||
parser.add_argument('--config',
|
parser.add_argument('--config',
|
||||||
choices=['native_torch', 'torchscript', 'refbackend'],
|
choices=config_choices,
|
||||||
default='refbackend',
|
default='refbackend',
|
||||||
help='''
|
help=f'''
|
||||||
Meaning of options:
|
Meaning of options:
|
||||||
"refbackend": run through npcomp's RefBackend.
|
"refbackend": run through npcomp's RefBackend.
|
||||||
|
"iree"{'' if IREE_ENABLED else '(disabled)'}: run through npcomp's IREE backend.
|
||||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||||
''')
|
''')
|
||||||
|
@ -54,7 +62,9 @@ def main():
|
||||||
|
|
||||||
# Find the selected config.
|
# Find the selected config.
|
||||||
if args.config == 'refbackend':
|
if args.config == 'refbackend':
|
||||||
config = RefBackendTestConfig()
|
config = NpcompBackendTestConfig(RefjitNpcompBackend())
|
||||||
|
elif args.config == 'iree':
|
||||||
|
config = NpcompBackendTestConfig(IreeNpcompBackend())
|
||||||
elif args.config == 'native_torch':
|
elif args.config == 'native_torch':
|
||||||
config = NativeTorchTestConfig()
|
config = NativeTorchTestConfig()
|
||||||
elif args.config == 'torchscript':
|
elif args.config == 'torchscript':
|
||||||
|
@ -78,7 +88,8 @@ def main():
|
||||||
results = run_tests(tests, config)
|
results = run_tests(tests, config)
|
||||||
|
|
||||||
# Report the test results.
|
# Report the test results.
|
||||||
report_results(results, args.verbose)
|
failed = report_results(results, XFAIL_SETS[args.config], args.verbose)
|
||||||
|
sys.exit(1 if failed else 0)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
# This file describes the sets of tests expected to fail for each config.
|
||||||
|
# This information is deliberately kept in a side table, rather than
|
||||||
|
# in-situ on the test, as a deliberate layering decision: tests should
|
||||||
|
# have unique keys to identify them and enable side tables of various kinds
|
||||||
|
# (this includes down into lower parts of the stack, where a side table
|
||||||
|
# might be used to keep more elaborate sets of testing configurations).
|
||||||
|
|
||||||
|
XFAIL_SETS = {}
|
||||||
|
|
||||||
|
# Lists of tests that fail to even reach the backends.
|
||||||
|
# These represent further work needed in npcomp to lower them properly
|
||||||
|
# to the backend contract.
|
||||||
|
_common_npcomp_lowering_xfails = {
|
||||||
|
'ResNet18Module_basic',
|
||||||
|
'QuantizedMLP_basic',
|
||||||
|
}
|
||||||
|
|
||||||
|
XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails
|
||||||
|
|
||||||
|
XFAIL_SETS['iree'] = _common_npcomp_lowering_xfails | {
|
||||||
|
# https://github.com/google/iree/issues/6368
|
||||||
|
'MmDagModule_basic',
|
||||||
|
'Mlp1LayerModule_basic',
|
||||||
|
'Mlp2LayerModule_basic',
|
||||||
|
}
|
|
@ -21,7 +21,7 @@ with mb.capture_function("cos", [input]) as f:
|
||||||
result = torch.cos(input)
|
result = torch.cos(input)
|
||||||
f.returns([result])
|
f.returns([result])
|
||||||
|
|
||||||
backend = refjit.CompilerBackend()
|
backend = iree.IreeNpcompBackend()
|
||||||
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
||||||
|
|
||||||
logging.debug(f"Executing jit_module.cos")
|
logging.debug(f"Executing jit_module.cos")
|
||||||
|
|
|
@ -25,7 +25,7 @@ mb = torch_mlir.ModuleBuilder()
|
||||||
with mb.capture_function("test", [arg0, arg1]) as f:
|
with mb.capture_function("test", [arg0, arg1]) as f:
|
||||||
f.returns([fun(arg0, arg1)])
|
f.returns([fun(arg0, arg1)])
|
||||||
|
|
||||||
backend = refjit.CompilerBackend()
|
backend = iree.IreeNpcompBackend()
|
||||||
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
||||||
|
|
||||||
test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)
|
test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)
|
||||||
|
|
|
@ -22,7 +22,7 @@ with mb.capture_function("mm", [lhs, rhs]) as f:
|
||||||
result = torch.mm(lhs, rhs)
|
result = torch.mm(lhs, rhs)
|
||||||
f.returns([result])
|
f.returns([result])
|
||||||
|
|
||||||
backend = refjit.CompilerBackend()
|
backend = iree.IreeNpcompBackend()
|
||||||
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
||||||
|
|
||||||
test_utils.compare_outputs(torch.mm, jit_module.mm, lhs, rhs)
|
test_utils.compare_outputs(torch.mm, jit_module.mm, lhs, rhs)
|
||||||
|
|
|
@ -28,7 +28,7 @@ with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f:
|
||||||
result = mul_maximum(lhs, rhs, threshold, bias)
|
result = mul_maximum(lhs, rhs, threshold, bias)
|
||||||
f.returns([result])
|
f.returns([result])
|
||||||
|
|
||||||
backend = refjit.CompilerBackend()
|
backend = iree.IreeNpcompBackend()
|
||||||
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
||||||
|
|
||||||
test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs, rhs,
|
test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs, rhs,
|
||||||
|
|
|
@ -26,7 +26,7 @@ mb = torch_mlir.ModuleBuilder()
|
||||||
with mb.capture_function("test", [arg0]) as f:
|
with mb.capture_function("test", [arg0]) as f:
|
||||||
f.returns([fun(arg0)])
|
f.returns([fun(arg0)])
|
||||||
|
|
||||||
backend = refjit.CompilerBackend()
|
backend = iree.IreeNpcompBackend()
|
||||||
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
||||||
|
|
||||||
test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)
|
test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)
|
||||||
|
|
|
@ -48,7 +48,7 @@ class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
|
||||||
mb.import_module(recursivescriptmodule._c, class_annotator)
|
mb.import_module(recursivescriptmodule._c, class_annotator)
|
||||||
#mb.module.operation.print()
|
#mb.module.operation.print()
|
||||||
|
|
||||||
backend = refjit.CompilerBackend()
|
backend = iree.IreeNpcompBackend()
|
||||||
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
|
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
|
||||||
jit_module = backend.load(compiled)
|
jit_module = backend.load(compiled)
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
|
||||||
mb.import_module(recursivescriptmodule._c, class_annotator)
|
mb.import_module(recursivescriptmodule._c, class_annotator)
|
||||||
#mb.module.operation.print()
|
#mb.module.operation.print()
|
||||||
|
|
||||||
backend = refjit.CompilerBackend()
|
backend = iree.IreeNpcompBackend()
|
||||||
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
|
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
|
||||||
jit_module = backend.load(compiled)
|
jit_module = backend.load(compiled)
|
||||||
|
|
||||||
|
|
|
@ -34,13 +34,13 @@ class_annotator.exportNone(recursivescriptmodule._c._type())
|
||||||
class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward'])
|
class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward'])
|
||||||
class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
|
class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
|
||||||
None,
|
None,
|
||||||
([2, 3, -1], torch.float32)
|
([2, 3, -1], torch.float32, True)
|
||||||
])
|
])
|
||||||
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||||
mb.import_module(recursivescriptmodule._c, class_annotator)
|
mb.import_module(recursivescriptmodule._c, class_annotator)
|
||||||
#mb.module.operation.print()
|
#mb.module.operation.print()
|
||||||
|
|
||||||
backend = iree.CompilerBackend()
|
backend = iree.IreeNpcompBackend()
|
||||||
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
|
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
|
||||||
jit_module = backend.load(compiled)
|
jit_module = backend.load(compiled)
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,6 @@
|
||||||
# See https://llvm.org/LICENSE.txt for license information.
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
from .ref_backend import RefBackendTestConfig
|
from .npcomp_backend import NpcompBackendTestConfig
|
||||||
from .native_torch import NativeTorchTestConfig
|
from .native_torch import NativeTorchTestConfig
|
||||||
from .torchscript import TorchScriptTestConfig
|
from .torchscript import TorchScriptTestConfig
|
||||||
|
|
|
@ -14,15 +14,32 @@ from mlir.passmanager import PassManager
|
||||||
|
|
||||||
import torch_mlir
|
import torch_mlir
|
||||||
from npcomp.compiler.pytorch.backend import refjit
|
from npcomp.compiler.pytorch.backend import refjit
|
||||||
|
from npcomp.compiler.pytorch.backend.abc import NpcompBackend
|
||||||
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
from torch_mlir.torchscript.annotations import extract_annotations
|
from torch_mlir.torchscript.annotations import extract_annotations
|
||||||
|
|
||||||
|
class PrettyErrorReportForIrOperation(object):
|
||||||
|
def __init__(self, module, module_name_for_ir_dump: str):
|
||||||
|
sys.stderr = StringIO()
|
||||||
|
self.filename_for_ir_dump = os.path.join(tempfile.gettempdir(),
|
||||||
|
module_name_for_ir_dump + '.mlir')
|
||||||
|
self.asm_for_error_report = module.get_asm(
|
||||||
|
large_elements_limit=10, enable_debug_info=True)
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
def __exit__(self, type, value, traceback):
|
||||||
|
with open(self.filename_for_ir_dump, 'w') as f:
|
||||||
|
f.write(self.asm_for_error_report)
|
||||||
|
|
||||||
class RefBackendTestConfig(TestConfig):
|
class NpcompBackendTestConfig(TestConfig):
|
||||||
"""TestConfig that just runs the torch.nn.Module through RefBackend."""
|
"""Base class for TestConfig's that are implemented with npcomp.
|
||||||
def __init__(self):
|
|
||||||
|
This class handles all the common lowering that npcomp does before reaching
|
||||||
|
its backends.
|
||||||
|
"""
|
||||||
|
def __init__(self, backend: NpcompBackend):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.backend = refjit.CompilerBackend()
|
self.backend = backend
|
||||||
|
|
||||||
def compile(self, program: torch.nn.Module) -> Any:
|
def compile(self, program: torch.nn.Module) -> Any:
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
@ -79,14 +96,36 @@ $ npcomp-opt -{pipeline_str} {filename}
|
||||||
""") from None
|
""") from None
|
||||||
finally:
|
finally:
|
||||||
sys.stderr = sys.__stderr__
|
sys.stderr = sys.__stderr__
|
||||||
|
try:
|
||||||
|
sys.stderr = StringIO()
|
||||||
|
asm_for_error_report = mb.module.operation.get_asm(
|
||||||
|
large_elements_limit=10, enable_debug_info=True)
|
||||||
return self.backend.compile(mb.module)
|
return self.backend.compile(mb.module)
|
||||||
|
except Exception as e:
|
||||||
|
filename = os.path.join(tempfile.gettempdir(),
|
||||||
|
scripted.original_name + '.mlir')
|
||||||
|
with open(filename, 'w') as f:
|
||||||
|
f.write(asm_for_error_report)
|
||||||
|
raise Exception(f"""
|
||||||
|
NPCOMP Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
|
||||||
|
## Exception:
|
||||||
|
{e}
|
||||||
|
|
||||||
|
## Stderr:
|
||||||
|
{sys.stderr.getvalue()}
|
||||||
|
|
||||||
|
## Input IR has been saved in {filename}
|
||||||
|
""") from None
|
||||||
|
finally:
|
||||||
|
sys.stderr = sys.__stderr__
|
||||||
|
|
||||||
|
|
||||||
def run(self, artifact: Any, trace: Trace) -> Trace:
|
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||||
jit_module = self.backend.load(artifact)
|
backend_module = self.backend.load(artifact)
|
||||||
result: Trace = []
|
result: Trace = []
|
||||||
for item in trace:
|
for item in trace:
|
||||||
numpy_inputs = [t.numpy() for t in item.inputs]
|
numpy_inputs = [t.numpy() for t in item.inputs]
|
||||||
outputs = getattr(jit_module, item.symbol)(*numpy_inputs)
|
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||||
if isinstance(outputs, np.ndarray):
|
if isinstance(outputs, np.ndarray):
|
||||||
outputs = [outputs]
|
outputs = [outputs]
|
||||||
torch_outputs = [torch.tensor(ndarray) for ndarray in outputs]
|
torch_outputs = [torch.tensor(ndarray) for ndarray in outputs]
|
|
@ -5,8 +5,9 @@
|
||||||
Utilities for reporting the results of the test framework.
|
Utilities for reporting the results of the test framework.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Set
|
||||||
|
|
||||||
|
import collections
|
||||||
import io
|
import io
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
@ -70,7 +71,7 @@ class ValueReport:
|
||||||
assert self.failed
|
assert self.failed
|
||||||
if self.value.size() != self.golden_value.size():
|
if self.value.size() != self.golden_value.size():
|
||||||
return self.context.format_error(
|
return self.context.format_error(
|
||||||
f'tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
|
f'tensor shape mismatch: got {self.value.size()!r}, expected {self.golden_value.size()!r}'
|
||||||
)
|
)
|
||||||
f = io.StringIO()
|
f = io.StringIO()
|
||||||
p = lambda *x: print(*x, file=f)
|
p = lambda *x: print(*x, file=f)
|
||||||
|
@ -167,17 +168,60 @@ class SingleTestReport:
|
||||||
return f.getvalue()
|
return f.getvalue()
|
||||||
|
|
||||||
|
|
||||||
def report_results(results: List[TestResult], verbose: bool = False):
|
def report_results(results: List[TestResult],
|
||||||
"""Provide a basic error report summarizing various TestResult's.
|
expected_failures: Set[str],
|
||||||
|
verbose: bool = False):
|
||||||
|
"""Print a basic error report summarizing various TestResult's.
|
||||||
|
|
||||||
|
This report uses the PASS/FAIL/XPASS/XFAIL nomenclature of LLVM's
|
||||||
|
"lit" testing utility. See
|
||||||
|
https://llvm.org/docs/CommandGuide/lit.html#test-status-results
|
||||||
|
|
||||||
|
The `expected_failures` set should contain the names of tests
|
||||||
|
(according to their `unique_name`) which are expected to fail.
|
||||||
|
The overall passing/failing status of the report requires these to fail
|
||||||
|
in order to succeed (this catches cases where things suddenly
|
||||||
|
start working).
|
||||||
|
|
||||||
If `verbose` is True, then provide an explanation of what failed.
|
If `verbose` is True, then provide an explanation of what failed.
|
||||||
|
|
||||||
|
Returns True if the run resulted in any unexpected pass/fail behavior.
|
||||||
|
Otherwise False.
|
||||||
"""
|
"""
|
||||||
|
summary = collections.Counter()
|
||||||
for result in results:
|
for result in results:
|
||||||
report = SingleTestReport(result, ErrorContext.empty())
|
report = SingleTestReport(result, ErrorContext.empty())
|
||||||
|
expected_failure = result.unique_name in expected_failures
|
||||||
|
if expected_failure:
|
||||||
|
if report.failed:
|
||||||
|
error_str = ''
|
||||||
|
if verbose:
|
||||||
|
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
|
||||||
|
print(f'XFAIL - "{result.unique_name}"' + error_str)
|
||||||
|
summary['XFAIL'] += 1
|
||||||
|
else:
|
||||||
|
print(f'XPASS - "{result.unique_name}"')
|
||||||
|
summary['XPASS'] += 1
|
||||||
|
else:
|
||||||
if not report.failed:
|
if not report.failed:
|
||||||
print(f'SUCCESS - "{result.unique_name}"')
|
print(f'PASS - "{result.unique_name}"')
|
||||||
|
summary['PASS'] += 1
|
||||||
else:
|
else:
|
||||||
error_str = ''
|
error_str = ''
|
||||||
if verbose:
|
if verbose:
|
||||||
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
|
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
|
||||||
print(f'FAILURE - "{result.unique_name}"' + error_str)
|
print(f'FAIL - "{result.unique_name}"' + error_str)
|
||||||
|
summary['FAIL'] += 1
|
||||||
|
|
||||||
|
# Print a summary for easy scanning.
|
||||||
|
print('\nSummary:')
|
||||||
|
KEY_MEANINGS = {
|
||||||
|
'PASS': 'Passed',
|
||||||
|
'FAIL': 'Failed',
|
||||||
|
'XFAIL': 'Expectedly Failed',
|
||||||
|
'XPASS': 'Unexpectedly Passed',
|
||||||
|
}
|
||||||
|
for key in ['PASS', 'FAIL', 'XFAIL', 'XPASS']:
|
||||||
|
if summary[key]:
|
||||||
|
print(f' {KEY_MEANINGS[key]}: {summary[key]}')
|
||||||
|
return summary['FAIL'] != 0 or summary['XPASS'] != 0
|
||||||
|
|
|
@ -21,13 +21,13 @@ class MmModule(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
# TODO: Refine messages.
|
# TODO: Refine messages.
|
||||||
# CHECK: SUCCESS - "MmModule_basic"
|
# CHECK: PASS - "MmModule_basic"
|
||||||
@register_test_case(module_factory=lambda: MmModule())
|
@register_test_case(module_factory=lambda: MmModule())
|
||||||
def MmModule_basic(module, tu: TestUtils):
|
def MmModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||||
|
|
||||||
|
|
||||||
# CHECK: SUCCESS - "MmModule_basic2"
|
# CHECK: PASS - "MmModule_basic2"
|
||||||
@register_test_case(module_factory=lambda: MmModule())
|
@register_test_case(module_factory=lambda: MmModule())
|
||||||
def MmModule_basic2(module, tu: TestUtils):
|
def MmModule_basic2(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||||
|
@ -36,7 +36,7 @@ def MmModule_basic2(module, tu: TestUtils):
|
||||||
def main():
|
def main():
|
||||||
config = TorchScriptTestConfig()
|
config = TorchScriptTestConfig()
|
||||||
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||||
report_results(results)
|
report_results(results, set())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -25,7 +25,7 @@ class MmModule(torch.nn.Module):
|
||||||
return 3
|
return 3
|
||||||
|
|
||||||
|
|
||||||
# CHECK: FAILURE - "MmModule_basic"
|
# CHECK: FAIL - "MmModule_basic"
|
||||||
# CHECK: compilation error
|
# CHECK: compilation error
|
||||||
# Assume that the diagnostic from the TorchScript compiler will at least contain
|
# Assume that the diagnostic from the TorchScript compiler will at least contain
|
||||||
# the offending "return 3".
|
# the offending "return 3".
|
||||||
|
@ -38,7 +38,7 @@ def MmModule_basic(module, tu: TestUtils):
|
||||||
def main():
|
def main():
|
||||||
config = TorchScriptTestConfig()
|
config = TorchScriptTestConfig()
|
||||||
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||||
report_results(results, verbose=True)
|
report_results(results, set(), verbose=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -26,7 +26,7 @@ class MmModule(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
# TODO: Refine error messages.
|
# TODO: Refine error messages.
|
||||||
# CHECK: FAILURE - "MmModule_basic"
|
# CHECK: FAIL - "MmModule_basic"
|
||||||
# CHECK: @ trace item #0 - call to "forward"
|
# CHECK: @ trace item #0 - call to "forward"
|
||||||
# CHECK: @ output #0
|
# CHECK: @ output #0
|
||||||
# CHECK: ERROR: values mismatch
|
# CHECK: ERROR: values mismatch
|
||||||
|
@ -40,7 +40,7 @@ def MmModule_basic(module, tu: TestUtils):
|
||||||
def main():
|
def main():
|
||||||
config = TorchScriptTestConfig()
|
config = TorchScriptTestConfig()
|
||||||
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||||
report_results(results, verbose=True)
|
report_results(results, set(), verbose=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
def is_iree_enabled():
|
||||||
|
try:
|
||||||
|
import iree.runtime
|
||||||
|
import iree.compiler
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
return True
|
|
@ -0,0 +1,45 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import abc
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mlir.ir import Module
|
||||||
|
|
||||||
|
# A type shared between the result of `NpcompBackend.compile` and the input
|
||||||
|
# to `NpcompBackend.load`. Each backend will likely have a different definition
|
||||||
|
# of this type.
|
||||||
|
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||||
|
|
||||||
|
# A wrapper around a backend-specific loaded program representation
|
||||||
|
# that uniformly translates the `x.method(...)` interface expected of
|
||||||
|
# Torch modules into appropriate lower-level operations.
|
||||||
|
Invoker = TypeVar('Invoker')
|
||||||
|
|
||||||
|
|
||||||
|
class NpcompBackend(abc.ABC):
|
||||||
|
"""The interface to an npcomp backend.
|
||||||
|
"""
|
||||||
|
@abc.abstractmethod
|
||||||
|
def compile(self, module: Module) -> CompiledArtifact:
|
||||||
|
"""Compile the provided MLIR module into a compiled artifact.
|
||||||
|
|
||||||
|
The module adheres to the npcomp backend contract
|
||||||
|
(see the VerifyBackendContract pass).
|
||||||
|
|
||||||
|
The compiled artifact can be any type, but must be correctly
|
||||||
|
interpreted by the `load` method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load(self, artifact: CompiledArtifact) -> Invoker:
|
||||||
|
"""Load the compiled artifact into a uniformly invokable form.
|
||||||
|
|
||||||
|
The compiled artifact is the result of a previous call to `compile`.
|
||||||
|
|
||||||
|
See the description of `Invoker` for the requirements on the returned
|
||||||
|
type.
|
||||||
|
"""
|
|
@ -5,6 +5,7 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from mlir.ir import *
|
from mlir.ir import *
|
||||||
from mlir.passmanager import *
|
from mlir.passmanager import *
|
||||||
|
@ -12,8 +13,10 @@ from npcomp.compiler.utils import logging
|
||||||
import iree.runtime as ireert
|
import iree.runtime as ireert
|
||||||
import iree.compiler as ireec
|
import iree.compiler as ireec
|
||||||
|
|
||||||
|
from .abc import NpcompBackend
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CompilerBackend",
|
"IreeNpcompBackend",
|
||||||
]
|
]
|
||||||
|
|
||||||
PREPARE_FOR_IREE_PASSES = (
|
PREPARE_FOR_IREE_PASSES = (
|
||||||
|
@ -34,6 +37,8 @@ class IreeModuleInvoker:
|
||||||
|
|
||||||
def invoke(*args):
|
def invoke(*args):
|
||||||
results = self._iree_module[function_name](*args)
|
results = self._iree_module[function_name](*args)
|
||||||
|
if isinstance(results, np.ndarray):
|
||||||
|
return results
|
||||||
if len(results) == 1:
|
if len(results) == 1:
|
||||||
# De-tuple.
|
# De-tuple.
|
||||||
return results[0]
|
return results[0]
|
||||||
|
@ -58,7 +63,7 @@ class TorchIreeModuleInvoker(IreeModuleInvoker):
|
||||||
return invoke
|
return invoke
|
||||||
|
|
||||||
|
|
||||||
class CompilerBackend:
|
class IreeNpcompBackend(NpcompBackend):
|
||||||
"""Main entry-point for the backend."""
|
"""Main entry-point for the backend."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -67,9 +72,8 @@ class CompilerBackend:
|
||||||
|
|
||||||
def compile(self, imported_module: Module):
|
def compile(self, imported_module: Module):
|
||||||
"""Compiles an imported module, with a flat list of functions.
|
"""Compiles an imported module, with a flat list of functions.
|
||||||
The module is expected to be in "TCP + scalar code" form.
|
The module is expected to conform to the npcomp backend contract.
|
||||||
TODO: More clearly define the backend contract. Generally this will
|
See the VerifyBackendContract pass for more details.
|
||||||
extend to support globals, lists, and other stuff.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
imported_module: The MLIR module consisting of funcs in the torch
|
imported_module: The MLIR module consisting of funcs in the torch
|
||||||
|
@ -97,12 +101,13 @@ class CompilerBackend:
|
||||||
# Backend.
|
# Backend.
|
||||||
binary = ireec.compile_str(str(imported_module),
|
binary = ireec.compile_str(str(imported_module),
|
||||||
target_backends=["dylib-llvm-aot"])
|
target_backends=["dylib-llvm-aot"])
|
||||||
iree_config = ireert.Config(driver_name="dylib")
|
return binary
|
||||||
|
|
||||||
iree_module = ireert.load_module(ireert.VmModule.from_flatbuffer(binary),
|
|
||||||
config=iree_config)
|
|
||||||
return iree_module
|
|
||||||
|
|
||||||
def load(self, iree_module) -> TorchIreeModuleInvoker:
|
def load(self, iree_module) -> TorchIreeModuleInvoker:
|
||||||
"""Loads a compiled artifact into the runtime."""
|
"""Loads a compiled artifact into the runtime."""
|
||||||
return TorchIreeModuleInvoker(iree_module)
|
vm_module = ireert.VmModule.from_flatbuffer(iree_module)
|
||||||
|
|
||||||
|
iree_config = ireert.Config(driver_name="dylib")
|
||||||
|
ctx = ireert.SystemContext(config=iree_config)
|
||||||
|
ctx.add_vm_module(vm_module)
|
||||||
|
return TorchIreeModuleInvoker(ctx.modules.module)
|
||||||
|
|
|
@ -10,10 +10,11 @@ from mlir.ir import *
|
||||||
from mlir.passmanager import *
|
from mlir.passmanager import *
|
||||||
from npcomp.compiler.generic.backend import refjit as refjit_backend
|
from npcomp.compiler.generic.backend import refjit as refjit_backend
|
||||||
from npcomp.compiler.utils import logging
|
from npcomp.compiler.utils import logging
|
||||||
|
from .abc import NpcompBackend
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"is_enabled",
|
"is_enabled",
|
||||||
"CompilerBackend",
|
"RefjitNpcompBackend",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Re-export.
|
# Re-export.
|
||||||
|
@ -34,7 +35,7 @@ class TorchJitModuleInvoker(refjit_backend.JitModuleInvoker):
|
||||||
return invoke
|
return invoke
|
||||||
|
|
||||||
|
|
||||||
class CompilerBackend:
|
class RefjitNpcompBackend(NpcompBackend):
|
||||||
"""Main entry-point for the backend."""
|
"""Main entry-point for the backend."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
#!/bin/bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
src_dir="$(realpath $(dirname $0)/..)"
|
||||||
|
|
||||||
|
cd "$src_dir"
|
||||||
|
source .env
|
||||||
|
python -m frontends.pytorch.e2e_testing.torchscript.main "$@"
|
Loading…
Reference in New Issue