Add IREE support in TorchScript e2e tests.

- Add support for "expected failures" in test reporting. The new error
  reports look like
  [this](https://gist.github.com/silvasean/6ffd95e1d55302b699673da201da210d).
  - We will now be able to put these tests into CI, since the harness
    understand which tests are expected to pass and fail.
- Refactor RefBackendTestConfig to NpcompBackendTestConfig which
  supports both RefBackend and IREE.
- Add instructions for installing IREE dependencies (both from packages
  and for local builds of IREE)
- Add `tools/torchscript_e2e_test.sh` for invoking the e2e test
  harness (this makes invoking a bit easier, as it doesn't rely on a
  loose Python invocation).
pull/243/head
Sean Silva 2021-06-30 14:13:21 -07:00
parent 79928cd2dd
commit d5108b9dc1
22 changed files with 284 additions and 64 deletions

View File

@ -164,6 +164,37 @@ cd /src/mlir-npcomp
cmake --build /build/npcomp --target check-npcomp check-frontends-pytorch
```
### IREE Backend (from IREE packages)
```shell
# We currently track and require the latest snapshot.
pip3 install iree-compiler-snapshot iree-runtime-snapshot -f https://github.com/google/iree/releases
# Run TorchScript E2E tests targeting IREE.
# Make sure to run "PyTorch Frontend" setup instructions first.
python frontends/pytorch/e2e_testing/torchscript/main.py --config=iree
```
### IREE Backend (from local IREE build)
This configuration is useful for iterating locally, as you can
poke/debug/rebuild things in IREE.
```shell
# Locally build IREE.
# See https://google.github.io/iree/building-from-source/getting-started/
# Make sure IREE is configured with `-DIREE_BUILD_PYTHON_BINDINGS=ON`.
echo 'PYTHONPATH="${PYTHONPATH}:/path/to/iree-build/bindings/python"' >> .env
# Run TorchScript E2E tests targeting IREE.
# Make sure to run "PyTorch Frontend" setup instructions first.
python frontends/pytorch/e2e_testing/torchscript/main.py --config=iree
```
### VSCode with a Docker Dev Image
#### Start a docker dev container based on our image

View File

@ -12,31 +12,39 @@ from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
# Available test configs.
from torch_mlir.torchscript.e2e_test.configs import (
RefBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
NpcompBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
)
from npcomp.compiler.pytorch.backend import is_iree_enabled
IREE_ENABLED = is_iree_enabled()
if IREE_ENABLED:
from npcomp.compiler.pytorch.backend.iree import IreeNpcompBackend
from npcomp.compiler.pytorch.backend.refjit import RefjitNpcompBackend
from .xfail_sets import XFAIL_SETS
# Import tests to register them in the global registry.
# TODO: Use a relative import.
# That requires invoking this file as a "package" though, which makes it
# not possible to just do `python main.py`. Instead, it requires something
# like `python -m tochscript.main` which is annoying because it can only
# be run from a specific directory.
# TODO: Find out best practices for python "main" files.
import basic
import vision_models
import mlp
import batchnorm
import quantized_models
import elementwise
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
# this script.
from . import basic
from . import vision_models
from . import mlp
from . import batchnorm
from . import quantized_models
from . import elementwise
def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend']
if IREE_ENABLED:
config_choices += ['iree']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('--config',
choices=['native_torch', 'torchscript', 'refbackend'],
choices=config_choices,
default='refbackend',
help='''
help=f'''
Meaning of options:
"refbackend": run through npcomp's RefBackend.
"iree"{'' if IREE_ENABLED else '(disabled)'}: run through npcomp's IREE backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
''')
@ -54,7 +62,9 @@ def main():
# Find the selected config.
if args.config == 'refbackend':
config = RefBackendTestConfig()
config = NpcompBackendTestConfig(RefjitNpcompBackend())
elif args.config == 'iree':
config = NpcompBackendTestConfig(IreeNpcompBackend())
elif args.config == 'native_torch':
config = NativeTorchTestConfig()
elif args.config == 'torchscript':
@ -78,7 +88,8 @@ def main():
results = run_tests(tests, config)
# Report the test results.
report_results(results, args.verbose)
failed = report_results(results, XFAIL_SETS[args.config], args.verbose)
sys.exit(1 if failed else 0)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,29 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# This file describes the sets of tests expected to fail for each config.
# This information is deliberately kept in a side table, rather than
# in-situ on the test, as a deliberate layering decision: tests should
# have unique keys to identify them and enable side tables of various kinds
# (this includes down into lower parts of the stack, where a side table
# might be used to keep more elaborate sets of testing configurations).
XFAIL_SETS = {}
# Lists of tests that fail to even reach the backends.
# These represent further work needed in npcomp to lower them properly
# to the backend contract.
_common_npcomp_lowering_xfails = {
'ResNet18Module_basic',
'QuantizedMLP_basic',
}
XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails
XFAIL_SETS['iree'] = _common_npcomp_lowering_xfails | {
# https://github.com/google/iree/issues/6368
'MmDagModule_basic',
'Mlp1LayerModule_basic',
'Mlp2LayerModule_basic',
}

View File

@ -21,7 +21,7 @@ with mb.capture_function("cos", [input]) as f:
result = torch.cos(input)
f.returns([result])
backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
logging.debug(f"Executing jit_module.cos")

View File

@ -25,7 +25,7 @@ mb = torch_mlir.ModuleBuilder()
with mb.capture_function("test", [arg0, arg1]) as f:
f.returns([fun(arg0, arg1)])
backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)

View File

@ -22,7 +22,7 @@ with mb.capture_function("mm", [lhs, rhs]) as f:
result = torch.mm(lhs, rhs)
f.returns([result])
backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
test_utils.compare_outputs(torch.mm, jit_module.mm, lhs, rhs)

View File

@ -28,7 +28,7 @@ with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f:
result = mul_maximum(lhs, rhs, threshold, bias)
f.returns([result])
backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs, rhs,

View File

@ -26,7 +26,7 @@ mb = torch_mlir.ModuleBuilder()
with mb.capture_function("test", [arg0]) as f:
f.returns([fun(arg0)])
backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)

View File

@ -48,7 +48,7 @@ class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
mb.import_module(recursivescriptmodule._c, class_annotator)
#mb.module.operation.print()
backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled)

View File

@ -40,7 +40,7 @@ class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
mb.import_module(recursivescriptmodule._c, class_annotator)
#mb.module.operation.print()
backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled)

View File

@ -34,13 +34,13 @@ class_annotator.exportNone(recursivescriptmodule._c._type())
class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward'])
class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
None,
([2, 3, -1], torch.float32)
([2, 3, -1], torch.float32, True)
])
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, class_annotator)
#mb.module.operation.print()
backend = iree.CompilerBackend()
backend = iree.IreeNpcompBackend()
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled)

View File

@ -2,6 +2,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .ref_backend import RefBackendTestConfig
from .npcomp_backend import NpcompBackendTestConfig
from .native_torch import NativeTorchTestConfig
from .torchscript import TorchScriptTestConfig

View File

@ -14,15 +14,32 @@ from mlir.passmanager import PassManager
import torch_mlir
from npcomp.compiler.pytorch.backend import refjit
from npcomp.compiler.pytorch.backend.abc import NpcompBackend
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
from torch_mlir.torchscript.annotations import extract_annotations
class PrettyErrorReportForIrOperation(object):
def __init__(self, module, module_name_for_ir_dump: str):
sys.stderr = StringIO()
self.filename_for_ir_dump = os.path.join(tempfile.gettempdir(),
module_name_for_ir_dump + '.mlir')
self.asm_for_error_report = module.get_asm(
large_elements_limit=10, enable_debug_info=True)
def __enter__(self):
pass
def __exit__(self, type, value, traceback):
with open(self.filename_for_ir_dump, 'w') as f:
f.write(self.asm_for_error_report)
class RefBackendTestConfig(TestConfig):
"""TestConfig that just runs the torch.nn.Module through RefBackend."""
def __init__(self):
class NpcompBackendTestConfig(TestConfig):
"""Base class for TestConfig's that are implemented with npcomp.
This class handles all the common lowering that npcomp does before reaching
its backends.
"""
def __init__(self, backend: NpcompBackend):
super().__init__()
self.backend = refjit.CompilerBackend()
self.backend = backend
def compile(self, program: torch.nn.Module) -> Any:
mb = torch_mlir.ModuleBuilder()
@ -79,14 +96,36 @@ $ npcomp-opt -{pipeline_str} {filename}
""") from None
finally:
sys.stderr = sys.__stderr__
return self.backend.compile(mb.module)
try:
sys.stderr = StringIO()
asm_for_error_report = mb.module.operation.get_asm(
large_elements_limit=10, enable_debug_info=True)
return self.backend.compile(mb.module)
except Exception as e:
filename = os.path.join(tempfile.gettempdir(),
scripted.original_name + '.mlir')
with open(filename, 'w') as f:
f.write(asm_for_error_report)
raise Exception(f"""
NPCOMP Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
## Exception:
{e}
## Stderr:
{sys.stderr.getvalue()}
## Input IR has been saved in {filename}
""") from None
finally:
sys.stderr = sys.__stderr__
def run(self, artifact: Any, trace: Trace) -> Trace:
jit_module = self.backend.load(artifact)
backend_module = self.backend.load(artifact)
result: Trace = []
for item in trace:
numpy_inputs = [t.numpy() for t in item.inputs]
outputs = getattr(jit_module, item.symbol)(*numpy_inputs)
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
if isinstance(outputs, np.ndarray):
outputs = [outputs]
torch_outputs = [torch.tensor(ndarray) for ndarray in outputs]

View File

@ -5,8 +5,9 @@
Utilities for reporting the results of the test framework.
"""
from typing import List, Optional
from typing import List, Optional, Set
import collections
import io
import textwrap
@ -70,7 +71,7 @@ class ValueReport:
assert self.failed
if self.value.size() != self.golden_value.size():
return self.context.format_error(
f'tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
f'tensor shape mismatch: got {self.value.size()!r}, expected {self.golden_value.size()!r}'
)
f = io.StringIO()
p = lambda *x: print(*x, file=f)
@ -167,17 +168,60 @@ class SingleTestReport:
return f.getvalue()
def report_results(results: List[TestResult], verbose: bool = False):
"""Provide a basic error report summarizing various TestResult's.
def report_results(results: List[TestResult],
expected_failures: Set[str],
verbose: bool = False):
"""Print a basic error report summarizing various TestResult's.
This report uses the PASS/FAIL/XPASS/XFAIL nomenclature of LLVM's
"lit" testing utility. See
https://llvm.org/docs/CommandGuide/lit.html#test-status-results
The `expected_failures` set should contain the names of tests
(according to their `unique_name`) which are expected to fail.
The overall passing/failing status of the report requires these to fail
in order to succeed (this catches cases where things suddenly
start working).
If `verbose` is True, then provide an explanation of what failed.
Returns True if the run resulted in any unexpected pass/fail behavior.
Otherwise False.
"""
summary = collections.Counter()
for result in results:
report = SingleTestReport(result, ErrorContext.empty())
if not report.failed:
print(f'SUCCESS - "{result.unique_name}"')
expected_failure = result.unique_name in expected_failures
if expected_failure:
if report.failed:
error_str = ''
if verbose:
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
print(f'XFAIL - "{result.unique_name}"' + error_str)
summary['XFAIL'] += 1
else:
print(f'XPASS - "{result.unique_name}"')
summary['XPASS'] += 1
else:
error_str = ''
if verbose:
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
print(f'FAILURE - "{result.unique_name}"' + error_str)
if not report.failed:
print(f'PASS - "{result.unique_name}"')
summary['PASS'] += 1
else:
error_str = ''
if verbose:
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
print(f'FAIL - "{result.unique_name}"' + error_str)
summary['FAIL'] += 1
# Print a summary for easy scanning.
print('\nSummary:')
KEY_MEANINGS = {
'PASS': 'Passed',
'FAIL': 'Failed',
'XFAIL': 'Expectedly Failed',
'XPASS': 'Unexpectedly Passed',
}
for key in ['PASS', 'FAIL', 'XFAIL', 'XPASS']:
if summary[key]:
print(f' {KEY_MEANINGS[key]}: {summary[key]}')
return summary['FAIL'] != 0 or summary['XPASS'] != 0

View File

@ -21,13 +21,13 @@ class MmModule(torch.nn.Module):
# TODO: Refine messages.
# CHECK: SUCCESS - "MmModule_basic"
# CHECK: PASS - "MmModule_basic"
@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4))
# CHECK: SUCCESS - "MmModule_basic2"
# CHECK: PASS - "MmModule_basic2"
@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic2(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4))
@ -36,7 +36,7 @@ def MmModule_basic2(module, tu: TestUtils):
def main():
config = TorchScriptTestConfig()
results = run_tests(GLOBAL_TEST_REGISTRY, config)
report_results(results)
report_results(results, set())
if __name__ == '__main__':

View File

@ -25,7 +25,7 @@ class MmModule(torch.nn.Module):
return 3
# CHECK: FAILURE - "MmModule_basic"
# CHECK: FAIL - "MmModule_basic"
# CHECK: compilation error
# Assume that the diagnostic from the TorchScript compiler will at least contain
# the offending "return 3".
@ -38,7 +38,7 @@ def MmModule_basic(module, tu: TestUtils):
def main():
config = TorchScriptTestConfig()
results = run_tests(GLOBAL_TEST_REGISTRY, config)
report_results(results, verbose=True)
report_results(results, set(), verbose=True)
if __name__ == '__main__':

View File

@ -26,7 +26,7 @@ class MmModule(torch.nn.Module):
# TODO: Refine error messages.
# CHECK: FAILURE - "MmModule_basic"
# CHECK: FAIL - "MmModule_basic"
# CHECK: @ trace item #0 - call to "forward"
# CHECK: @ output #0
# CHECK: ERROR: values mismatch
@ -40,7 +40,7 @@ def MmModule_basic(module, tu: TestUtils):
def main():
config = TorchScriptTestConfig()
results = run_tests(GLOBAL_TEST_REGISTRY, config)
report_results(results, verbose=True)
report_results(results, set(), verbose=True)
if __name__ == '__main__':

View File

@ -0,0 +1,7 @@
def is_iree_enabled():
try:
import iree.runtime
import iree.compiler
except:
return False
return True

View File

@ -0,0 +1,45 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import abc
from typing import TypeVar
import torch
from mlir.ir import Module
# A type shared between the result of `NpcompBackend.compile` and the input
# to `NpcompBackend.load`. Each backend will likely have a different definition
# of this type.
CompiledArtifact = TypeVar('CompiledArtifact')
# A wrapper around a backend-specific loaded program representation
# that uniformly translates the `x.method(...)` interface expected of
# Torch modules into appropriate lower-level operations.
Invoker = TypeVar('Invoker')
class NpcompBackend(abc.ABC):
"""The interface to an npcomp backend.
"""
@abc.abstractmethod
def compile(self, module: Module) -> CompiledArtifact:
"""Compile the provided MLIR module into a compiled artifact.
The module adheres to the npcomp backend contract
(see the VerifyBackendContract pass).
The compiled artifact can be any type, but must be correctly
interpreted by the `load` method.
"""
@abc.abstractmethod
def load(self, artifact: CompiledArtifact) -> Invoker:
"""Load the compiled artifact into a uniformly invokable form.
The compiled artifact is the result of a previous call to `compile`.
See the description of `Invoker` for the requirements on the returned
type.
"""

View File

@ -5,6 +5,7 @@
import os
import torch
import numpy as np
from mlir.ir import *
from mlir.passmanager import *
@ -12,8 +13,10 @@ from npcomp.compiler.utils import logging
import iree.runtime as ireert
import iree.compiler as ireec
from .abc import NpcompBackend
__all__ = [
"CompilerBackend",
"IreeNpcompBackend",
]
PREPARE_FOR_IREE_PASSES = (
@ -34,6 +37,8 @@ class IreeModuleInvoker:
def invoke(*args):
results = self._iree_module[function_name](*args)
if isinstance(results, np.ndarray):
return results
if len(results) == 1:
# De-tuple.
return results[0]
@ -58,7 +63,7 @@ class TorchIreeModuleInvoker(IreeModuleInvoker):
return invoke
class CompilerBackend:
class IreeNpcompBackend(NpcompBackend):
"""Main entry-point for the backend."""
def __init__(self):
@ -67,9 +72,8 @@ class CompilerBackend:
def compile(self, imported_module: Module):
"""Compiles an imported module, with a flat list of functions.
The module is expected to be in "TCP + scalar code" form.
TODO: More clearly define the backend contract. Generally this will
extend to support globals, lists, and other stuff.
The module is expected to conform to the npcomp backend contract.
See the VerifyBackendContract pass for more details.
Args:
imported_module: The MLIR module consisting of funcs in the torch
@ -97,12 +101,13 @@ class CompilerBackend:
# Backend.
binary = ireec.compile_str(str(imported_module),
target_backends=["dylib-llvm-aot"])
iree_config = ireert.Config(driver_name="dylib")
iree_module = ireert.load_module(ireert.VmModule.from_flatbuffer(binary),
config=iree_config)
return iree_module
return binary
def load(self, iree_module) -> TorchIreeModuleInvoker:
"""Loads a compiled artifact into the runtime."""
return TorchIreeModuleInvoker(iree_module)
vm_module = ireert.VmModule.from_flatbuffer(iree_module)
iree_config = ireert.Config(driver_name="dylib")
ctx = ireert.SystemContext(config=iree_config)
ctx.add_vm_module(vm_module)
return TorchIreeModuleInvoker(ctx.modules.module)

View File

@ -10,10 +10,11 @@ from mlir.ir import *
from mlir.passmanager import *
from npcomp.compiler.generic.backend import refjit as refjit_backend
from npcomp.compiler.utils import logging
from .abc import NpcompBackend
__all__ = [
"is_enabled",
"CompilerBackend",
"RefjitNpcompBackend",
]
# Re-export.
@ -34,7 +35,7 @@ class TorchJitModuleInvoker(refjit_backend.JitModuleInvoker):
return invoke
class CompilerBackend:
class RefjitNpcompBackend(NpcompBackend):
"""Main entry-point for the backend."""
def __init__(self):

View File

@ -0,0 +1,8 @@
#!/bin/bash
set -euo pipefail
src_dir="$(realpath $(dirname $0)/..)"
cd "$src_dir"
source .env
python -m frontends.pytorch.e2e_testing.torchscript.main "$@"