mirror of https://github.com/llvm/torch-mlir
Add --external-config option to tools/torchscript_e2e_test.sh
This is a simple way for externals to plug their backends into the test suite. They just implement the `TestConfig` class for their backend and write a small script that exposes it. I have a pending PR for iree-samples that successfully integrates this.pull/348/head
parent
98ba255288
commit
f69630255a
|
@ -9,7 +9,7 @@ import pickle
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from torch_mlir_e2e_test.torchscript.framework import run_tests
|
from torch_mlir_e2e_test.torchscript.framework import TestConfig, run_tests
|
||||||
from torch_mlir_e2e_test.torchscript.reporting import report_results
|
from torch_mlir_e2e_test.torchscript.reporting import report_results
|
||||||
from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
|
from torch_mlir_e2e_test.torchscript.registry import GLOBAL_TEST_REGISTRY
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ from torch_mlir_e2e_test.torchscript.configs import (
|
||||||
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||||
|
|
||||||
from .xfail_sets import XFAIL_SETS
|
from .xfail_sets import XFAIL_SETS, COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||||
|
|
||||||
# Import tests to register them in the global registry.
|
# Import tests to register them in the global registry.
|
||||||
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
|
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
|
||||||
|
@ -35,9 +35,7 @@ from . import elementwise
|
||||||
from . import reduction
|
from . import reduction
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
# TODO: Allow pulling in an out-of-tree backend, so downstream can easily
|
config_choices = ['native_torch', 'torchscript', 'refbackend', 'external']
|
||||||
# plug into the e2e tests.
|
|
||||||
config_choices = ['native_torch', 'torchscript', 'refbackend']
|
|
||||||
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||||
parser.add_argument('-c', '--config',
|
parser.add_argument('-c', '--config',
|
||||||
choices=config_choices,
|
choices=config_choices,
|
||||||
|
@ -47,6 +45,17 @@ Meaning of options:
|
||||||
"refbackend": run through torch-mlir's RefBackend.
|
"refbackend": run through torch-mlir's RefBackend.
|
||||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||||
|
"external": use an external backend, specified by the `--external-backend` option.
|
||||||
|
''')
|
||||||
|
parser.add_argument('--external-config',
|
||||||
|
help=f'''
|
||||||
|
Specifies a path to a Python file, which will be `exec`'ed.
|
||||||
|
The file has the following contract:
|
||||||
|
- The global variable `config` should be set to an instance of `TestConfig`.
|
||||||
|
- `xfail_set` should be set to a set of test unique identifiers that are
|
||||||
|
expected to fail. The global `COMMON_TORCH_MLIR_LOWERING_XFAILS` provides
|
||||||
|
a common set of xfails that won't work on backends because torch-mlir
|
||||||
|
itself does not handle them.
|
||||||
''')
|
''')
|
||||||
parser.add_argument('-f', '--filter', default='.*', help='''
|
parser.add_argument('-f', '--filter', default='.*', help='''
|
||||||
Regular expression specifying which tests to include in this run.
|
Regular expression specifying which tests to include in this run.
|
||||||
|
@ -71,10 +80,31 @@ def main():
|
||||||
# Find the selected config.
|
# Find the selected config.
|
||||||
if args.config == 'refbackend':
|
if args.config == 'refbackend':
|
||||||
config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
|
config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
|
||||||
|
xfail_set = XFAIL_SETS['refbackend']
|
||||||
elif args.config == 'native_torch':
|
elif args.config == 'native_torch':
|
||||||
config = NativeTorchTestConfig()
|
config = NativeTorchTestConfig()
|
||||||
|
xfail_set = XFAIL_SETS['native_torch']
|
||||||
elif args.config == 'torchscript':
|
elif args.config == 'torchscript':
|
||||||
config = TorchScriptTestConfig()
|
config = TorchScriptTestConfig()
|
||||||
|
xfail_set = XFAIL_SETS['torchscript']
|
||||||
|
elif args.config == 'external':
|
||||||
|
with open(args.external_config, 'r') as f:
|
||||||
|
code = compile(f.read(), args.external_config, 'exec')
|
||||||
|
exec_globals = {
|
||||||
|
'COMMON_TORCH_MLIR_LOWERING_XFAILS': COMMON_TORCH_MLIR_LOWERING_XFAILS}
|
||||||
|
exec(code, exec_globals)
|
||||||
|
config = exec_globals.get('config')
|
||||||
|
xfail_set = exec_globals.get('xfail_set')
|
||||||
|
if config is None or not isinstance(config, TestConfig):
|
||||||
|
print(
|
||||||
|
f'ERROR: the script {args.external_config} did not set a global variable `config`'
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
if xfail_set is None:
|
||||||
|
print(
|
||||||
|
f'ERROR: the script {args.external_config} did not set a global variable `xfail_set`'
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
all_tests = list(GLOBAL_TEST_REGISTRY)
|
all_tests = list(GLOBAL_TEST_REGISTRY)
|
||||||
if args.serialized_test_dir:
|
if args.serialized_test_dir:
|
||||||
|
@ -101,7 +131,7 @@ def main():
|
||||||
results = run_tests(tests, config)
|
results = run_tests(tests, config)
|
||||||
|
|
||||||
# Report the test results.
|
# Report the test results.
|
||||||
failed = report_results(results, XFAIL_SETS[args.config], args.verbose)
|
failed = report_results(results, xfail_set, args.verbose)
|
||||||
sys.exit(1 if failed else 0)
|
sys.exit(1 if failed else 0)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -15,11 +15,11 @@ XFAIL_SETS = {}
|
||||||
# Lists of tests that fail to even reach the backends.
|
# Lists of tests that fail to even reach the backends.
|
||||||
# These represent further work needed in torch-mlir to lower them properly
|
# These represent further work needed in torch-mlir to lower them properly
|
||||||
# to the backend contract.
|
# to the backend contract.
|
||||||
_common_torch_mlir_lowering_xfails = {
|
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
||||||
'QuantizedMLP_basic',
|
'QuantizedMLP_basic',
|
||||||
}
|
}
|
||||||
|
|
||||||
XFAIL_SETS['refbackend'] = _common_torch_mlir_lowering_xfails
|
XFAIL_SETS['refbackend'] = COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||||
|
|
||||||
XFAIL_SETS['torchscript'] = {}
|
XFAIL_SETS['torchscript'] = {}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue