mirror of https://github.com/llvm/torch-mlir
Add IREE support in TorchScript e2e tests.
- Add support for "expected failures" in test reporting. The new error reports look like [this](https://gist.github.com/silvasean/6ffd95e1d55302b699673da201da210d). - We will now be able to put these tests into CI, since the harness understand which tests are expected to pass and fail. - Refactor RefBackendTestConfig to NpcompBackendTestConfig which supports both RefBackend and IREE. - Add instructions for installing IREE dependencies (both from packages and for local builds of IREE) - Add `tools/torchscript_e2e_test.sh` for invoking the e2e test harness (this makes invoking a bit easier, as it doesn't rely on a loose Python invocation).pull/243/head
parent
79928cd2dd
commit
d5108b9dc1
31
README.md
31
README.md
|
@ -164,6 +164,37 @@ cd /src/mlir-npcomp
|
|||
cmake --build /build/npcomp --target check-npcomp check-frontends-pytorch
|
||||
```
|
||||
|
||||
### IREE Backend (from IREE packages)
|
||||
|
||||
```shell
|
||||
# We currently track and require the latest snapshot.
|
||||
pip3 install iree-compiler-snapshot iree-runtime-snapshot -f https://github.com/google/iree/releases
|
||||
|
||||
|
||||
|
||||
# Run TorchScript E2E tests targeting IREE.
|
||||
# Make sure to run "PyTorch Frontend" setup instructions first.
|
||||
python frontends/pytorch/e2e_testing/torchscript/main.py --config=iree
|
||||
```
|
||||
|
||||
### IREE Backend (from local IREE build)
|
||||
|
||||
This configuration is useful for iterating locally, as you can
|
||||
poke/debug/rebuild things in IREE.
|
||||
|
||||
```shell
|
||||
# Locally build IREE.
|
||||
# See https://google.github.io/iree/building-from-source/getting-started/
|
||||
# Make sure IREE is configured with `-DIREE_BUILD_PYTHON_BINDINGS=ON`.
|
||||
|
||||
echo 'PYTHONPATH="${PYTHONPATH}:/path/to/iree-build/bindings/python"' >> .env
|
||||
|
||||
# Run TorchScript E2E tests targeting IREE.
|
||||
# Make sure to run "PyTorch Frontend" setup instructions first.
|
||||
python frontends/pytorch/e2e_testing/torchscript/main.py --config=iree
|
||||
```
|
||||
|
||||
|
||||
### VSCode with a Docker Dev Image
|
||||
|
||||
#### Start a docker dev container based on our image
|
||||
|
|
|
@ -12,31 +12,39 @@ from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
|
|||
|
||||
# Available test configs.
|
||||
from torch_mlir.torchscript.e2e_test.configs import (
|
||||
RefBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
|
||||
NpcompBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
|
||||
)
|
||||
|
||||
from npcomp.compiler.pytorch.backend import is_iree_enabled
|
||||
IREE_ENABLED = is_iree_enabled()
|
||||
if IREE_ENABLED:
|
||||
from npcomp.compiler.pytorch.backend.iree import IreeNpcompBackend
|
||||
from npcomp.compiler.pytorch.backend.refjit import RefjitNpcompBackend
|
||||
|
||||
from .xfail_sets import XFAIL_SETS
|
||||
|
||||
# Import tests to register them in the global registry.
|
||||
# TODO: Use a relative import.
|
||||
# That requires invoking this file as a "package" though, which makes it
|
||||
# not possible to just do `python main.py`. Instead, it requires something
|
||||
# like `python -m tochscript.main` which is annoying because it can only
|
||||
# be run from a specific directory.
|
||||
# TODO: Find out best practices for python "main" files.
|
||||
import basic
|
||||
import vision_models
|
||||
import mlp
|
||||
import batchnorm
|
||||
import quantized_models
|
||||
import elementwise
|
||||
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
|
||||
# this script.
|
||||
from . import basic
|
||||
from . import vision_models
|
||||
from . import mlp
|
||||
from . import batchnorm
|
||||
from . import quantized_models
|
||||
from . import elementwise
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend']
|
||||
if IREE_ENABLED:
|
||||
config_choices += ['iree']
|
||||
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
|
||||
parser.add_argument('--config',
|
||||
choices=['native_torch', 'torchscript', 'refbackend'],
|
||||
choices=config_choices,
|
||||
default='refbackend',
|
||||
help='''
|
||||
help=f'''
|
||||
Meaning of options:
|
||||
"refbackend": run through npcomp's RefBackend.
|
||||
"iree"{'' if IREE_ENABLED else '(disabled)'}: run through npcomp's IREE backend.
|
||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||
''')
|
||||
|
@ -54,7 +62,9 @@ def main():
|
|||
|
||||
# Find the selected config.
|
||||
if args.config == 'refbackend':
|
||||
config = RefBackendTestConfig()
|
||||
config = NpcompBackendTestConfig(RefjitNpcompBackend())
|
||||
elif args.config == 'iree':
|
||||
config = NpcompBackendTestConfig(IreeNpcompBackend())
|
||||
elif args.config == 'native_torch':
|
||||
config = NativeTorchTestConfig()
|
||||
elif args.config == 'torchscript':
|
||||
|
@ -78,7 +88,8 @@ def main():
|
|||
results = run_tests(tests, config)
|
||||
|
||||
# Report the test results.
|
||||
report_results(results, args.verbose)
|
||||
failed = report_results(results, XFAIL_SETS[args.config], args.verbose)
|
||||
sys.exit(1 if failed else 0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# This file describes the sets of tests expected to fail for each config.
|
||||
# This information is deliberately kept in a side table, rather than
|
||||
# in-situ on the test, as a deliberate layering decision: tests should
|
||||
# have unique keys to identify them and enable side tables of various kinds
|
||||
# (this includes down into lower parts of the stack, where a side table
|
||||
# might be used to keep more elaborate sets of testing configurations).
|
||||
|
||||
XFAIL_SETS = {}
|
||||
|
||||
# Lists of tests that fail to even reach the backends.
|
||||
# These represent further work needed in npcomp to lower them properly
|
||||
# to the backend contract.
|
||||
_common_npcomp_lowering_xfails = {
|
||||
'ResNet18Module_basic',
|
||||
'QuantizedMLP_basic',
|
||||
}
|
||||
|
||||
XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails
|
||||
|
||||
XFAIL_SETS['iree'] = _common_npcomp_lowering_xfails | {
|
||||
# https://github.com/google/iree/issues/6368
|
||||
'MmDagModule_basic',
|
||||
'Mlp1LayerModule_basic',
|
||||
'Mlp2LayerModule_basic',
|
||||
}
|
|
@ -21,7 +21,7 @@ with mb.capture_function("cos", [input]) as f:
|
|||
result = torch.cos(input)
|
||||
f.returns([result])
|
||||
|
||||
backend = refjit.CompilerBackend()
|
||||
backend = iree.IreeNpcompBackend()
|
||||
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
||||
|
||||
logging.debug(f"Executing jit_module.cos")
|
||||
|
|
|
@ -25,7 +25,7 @@ mb = torch_mlir.ModuleBuilder()
|
|||
with mb.capture_function("test", [arg0, arg1]) as f:
|
||||
f.returns([fun(arg0, arg1)])
|
||||
|
||||
backend = refjit.CompilerBackend()
|
||||
backend = iree.IreeNpcompBackend()
|
||||
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
||||
|
||||
test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)
|
||||
|
|
|
@ -22,7 +22,7 @@ with mb.capture_function("mm", [lhs, rhs]) as f:
|
|||
result = torch.mm(lhs, rhs)
|
||||
f.returns([result])
|
||||
|
||||
backend = refjit.CompilerBackend()
|
||||
backend = iree.IreeNpcompBackend()
|
||||
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
||||
|
||||
test_utils.compare_outputs(torch.mm, jit_module.mm, lhs, rhs)
|
||||
|
|
|
@ -28,7 +28,7 @@ with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f:
|
|||
result = mul_maximum(lhs, rhs, threshold, bias)
|
||||
f.returns([result])
|
||||
|
||||
backend = refjit.CompilerBackend()
|
||||
backend = iree.IreeNpcompBackend()
|
||||
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
||||
|
||||
test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs, rhs,
|
||||
|
|
|
@ -26,7 +26,7 @@ mb = torch_mlir.ModuleBuilder()
|
|||
with mb.capture_function("test", [arg0]) as f:
|
||||
f.returns([fun(arg0)])
|
||||
|
||||
backend = refjit.CompilerBackend()
|
||||
backend = iree.IreeNpcompBackend()
|
||||
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
|
||||
|
||||
test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)
|
||||
|
|
|
@ -48,7 +48,7 @@ class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
|
|||
mb.import_module(recursivescriptmodule._c, class_annotator)
|
||||
#mb.module.operation.print()
|
||||
|
||||
backend = refjit.CompilerBackend()
|
||||
backend = iree.IreeNpcompBackend()
|
||||
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
|
||||
jit_module = backend.load(compiled)
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
|
|||
mb.import_module(recursivescriptmodule._c, class_annotator)
|
||||
#mb.module.operation.print()
|
||||
|
||||
backend = refjit.CompilerBackend()
|
||||
backend = iree.IreeNpcompBackend()
|
||||
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
|
||||
jit_module = backend.load(compiled)
|
||||
|
||||
|
|
|
@ -34,13 +34,13 @@ class_annotator.exportNone(recursivescriptmodule._c._type())
|
|||
class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward'])
|
||||
class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
|
||||
None,
|
||||
([2, 3, -1], torch.float32)
|
||||
([2, 3, -1], torch.float32, True)
|
||||
])
|
||||
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||
mb.import_module(recursivescriptmodule._c, class_annotator)
|
||||
#mb.module.operation.print()
|
||||
|
||||
backend = iree.CompilerBackend()
|
||||
backend = iree.IreeNpcompBackend()
|
||||
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
|
||||
jit_module = backend.load(compiled)
|
||||
|
||||
|
|
|
@ -2,6 +2,6 @@
|
|||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .ref_backend import RefBackendTestConfig
|
||||
from .npcomp_backend import NpcompBackendTestConfig
|
||||
from .native_torch import NativeTorchTestConfig
|
||||
from .torchscript import TorchScriptTestConfig
|
||||
|
|
|
@ -14,15 +14,32 @@ from mlir.passmanager import PassManager
|
|||
|
||||
import torch_mlir
|
||||
from npcomp.compiler.pytorch.backend import refjit
|
||||
from npcomp.compiler.pytorch.backend.abc import NpcompBackend
|
||||
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
from torch_mlir.torchscript.annotations import extract_annotations
|
||||
|
||||
class PrettyErrorReportForIrOperation(object):
|
||||
def __init__(self, module, module_name_for_ir_dump: str):
|
||||
sys.stderr = StringIO()
|
||||
self.filename_for_ir_dump = os.path.join(tempfile.gettempdir(),
|
||||
module_name_for_ir_dump + '.mlir')
|
||||
self.asm_for_error_report = module.get_asm(
|
||||
large_elements_limit=10, enable_debug_info=True)
|
||||
def __enter__(self):
|
||||
pass
|
||||
def __exit__(self, type, value, traceback):
|
||||
with open(self.filename_for_ir_dump, 'w') as f:
|
||||
f.write(self.asm_for_error_report)
|
||||
|
||||
class RefBackendTestConfig(TestConfig):
|
||||
"""TestConfig that just runs the torch.nn.Module through RefBackend."""
|
||||
def __init__(self):
|
||||
class NpcompBackendTestConfig(TestConfig):
|
||||
"""Base class for TestConfig's that are implemented with npcomp.
|
||||
|
||||
This class handles all the common lowering that npcomp does before reaching
|
||||
its backends.
|
||||
"""
|
||||
def __init__(self, backend: NpcompBackend):
|
||||
super().__init__()
|
||||
self.backend = refjit.CompilerBackend()
|
||||
self.backend = backend
|
||||
|
||||
def compile(self, program: torch.nn.Module) -> Any:
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
@ -79,14 +96,36 @@ $ npcomp-opt -{pipeline_str} {filename}
|
|||
""") from None
|
||||
finally:
|
||||
sys.stderr = sys.__stderr__
|
||||
return self.backend.compile(mb.module)
|
||||
try:
|
||||
sys.stderr = StringIO()
|
||||
asm_for_error_report = mb.module.operation.get_asm(
|
||||
large_elements_limit=10, enable_debug_info=True)
|
||||
return self.backend.compile(mb.module)
|
||||
except Exception as e:
|
||||
filename = os.path.join(tempfile.gettempdir(),
|
||||
scripted.original_name + '.mlir')
|
||||
with open(filename, 'w') as f:
|
||||
f.write(asm_for_error_report)
|
||||
raise Exception(f"""
|
||||
NPCOMP Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
|
||||
## Exception:
|
||||
{e}
|
||||
|
||||
## Stderr:
|
||||
{sys.stderr.getvalue()}
|
||||
|
||||
## Input IR has been saved in {filename}
|
||||
""") from None
|
||||
finally:
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
|
||||
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||
jit_module = self.backend.load(artifact)
|
||||
backend_module = self.backend.load(artifact)
|
||||
result: Trace = []
|
||||
for item in trace:
|
||||
numpy_inputs = [t.numpy() for t in item.inputs]
|
||||
outputs = getattr(jit_module, item.symbol)(*numpy_inputs)
|
||||
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
|
||||
if isinstance(outputs, np.ndarray):
|
||||
outputs = [outputs]
|
||||
torch_outputs = [torch.tensor(ndarray) for ndarray in outputs]
|
|
@ -5,8 +5,9 @@
|
|||
Utilities for reporting the results of the test framework.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Set
|
||||
|
||||
import collections
|
||||
import io
|
||||
import textwrap
|
||||
|
||||
|
@ -70,7 +71,7 @@ class ValueReport:
|
|||
assert self.failed
|
||||
if self.value.size() != self.golden_value.size():
|
||||
return self.context.format_error(
|
||||
f'tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
|
||||
f'tensor shape mismatch: got {self.value.size()!r}, expected {self.golden_value.size()!r}'
|
||||
)
|
||||
f = io.StringIO()
|
||||
p = lambda *x: print(*x, file=f)
|
||||
|
@ -167,17 +168,60 @@ class SingleTestReport:
|
|||
return f.getvalue()
|
||||
|
||||
|
||||
def report_results(results: List[TestResult], verbose: bool = False):
|
||||
"""Provide a basic error report summarizing various TestResult's.
|
||||
def report_results(results: List[TestResult],
|
||||
expected_failures: Set[str],
|
||||
verbose: bool = False):
|
||||
"""Print a basic error report summarizing various TestResult's.
|
||||
|
||||
This report uses the PASS/FAIL/XPASS/XFAIL nomenclature of LLVM's
|
||||
"lit" testing utility. See
|
||||
https://llvm.org/docs/CommandGuide/lit.html#test-status-results
|
||||
|
||||
The `expected_failures` set should contain the names of tests
|
||||
(according to their `unique_name`) which are expected to fail.
|
||||
The overall passing/failing status of the report requires these to fail
|
||||
in order to succeed (this catches cases where things suddenly
|
||||
start working).
|
||||
|
||||
If `verbose` is True, then provide an explanation of what failed.
|
||||
|
||||
Returns True if the run resulted in any unexpected pass/fail behavior.
|
||||
Otherwise False.
|
||||
"""
|
||||
summary = collections.Counter()
|
||||
for result in results:
|
||||
report = SingleTestReport(result, ErrorContext.empty())
|
||||
if not report.failed:
|
||||
print(f'SUCCESS - "{result.unique_name}"')
|
||||
expected_failure = result.unique_name in expected_failures
|
||||
if expected_failure:
|
||||
if report.failed:
|
||||
error_str = ''
|
||||
if verbose:
|
||||
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
|
||||
print(f'XFAIL - "{result.unique_name}"' + error_str)
|
||||
summary['XFAIL'] += 1
|
||||
else:
|
||||
print(f'XPASS - "{result.unique_name}"')
|
||||
summary['XPASS'] += 1
|
||||
else:
|
||||
error_str = ''
|
||||
if verbose:
|
||||
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
|
||||
print(f'FAILURE - "{result.unique_name}"' + error_str)
|
||||
if not report.failed:
|
||||
print(f'PASS - "{result.unique_name}"')
|
||||
summary['PASS'] += 1
|
||||
else:
|
||||
error_str = ''
|
||||
if verbose:
|
||||
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
|
||||
print(f'FAIL - "{result.unique_name}"' + error_str)
|
||||
summary['FAIL'] += 1
|
||||
|
||||
# Print a summary for easy scanning.
|
||||
print('\nSummary:')
|
||||
KEY_MEANINGS = {
|
||||
'PASS': 'Passed',
|
||||
'FAIL': 'Failed',
|
||||
'XFAIL': 'Expectedly Failed',
|
||||
'XPASS': 'Unexpectedly Passed',
|
||||
}
|
||||
for key in ['PASS', 'FAIL', 'XFAIL', 'XPASS']:
|
||||
if summary[key]:
|
||||
print(f' {KEY_MEANINGS[key]}: {summary[key]}')
|
||||
return summary['FAIL'] != 0 or summary['XPASS'] != 0
|
||||
|
|
|
@ -21,13 +21,13 @@ class MmModule(torch.nn.Module):
|
|||
|
||||
|
||||
# TODO: Refine messages.
|
||||
# CHECK: SUCCESS - "MmModule_basic"
|
||||
# CHECK: PASS - "MmModule_basic"
|
||||
@register_test_case(module_factory=lambda: MmModule())
|
||||
def MmModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||
|
||||
|
||||
# CHECK: SUCCESS - "MmModule_basic2"
|
||||
# CHECK: PASS - "MmModule_basic2"
|
||||
@register_test_case(module_factory=lambda: MmModule())
|
||||
def MmModule_basic2(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 4), tu.rand(4, 4))
|
||||
|
@ -36,7 +36,7 @@ def MmModule_basic2(module, tu: TestUtils):
|
|||
def main():
|
||||
config = TorchScriptTestConfig()
|
||||
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||
report_results(results)
|
||||
report_results(results, set())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -25,7 +25,7 @@ class MmModule(torch.nn.Module):
|
|||
return 3
|
||||
|
||||
|
||||
# CHECK: FAILURE - "MmModule_basic"
|
||||
# CHECK: FAIL - "MmModule_basic"
|
||||
# CHECK: compilation error
|
||||
# Assume that the diagnostic from the TorchScript compiler will at least contain
|
||||
# the offending "return 3".
|
||||
|
@ -38,7 +38,7 @@ def MmModule_basic(module, tu: TestUtils):
|
|||
def main():
|
||||
config = TorchScriptTestConfig()
|
||||
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||
report_results(results, verbose=True)
|
||||
report_results(results, set(), verbose=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -26,7 +26,7 @@ class MmModule(torch.nn.Module):
|
|||
|
||||
|
||||
# TODO: Refine error messages.
|
||||
# CHECK: FAILURE - "MmModule_basic"
|
||||
# CHECK: FAIL - "MmModule_basic"
|
||||
# CHECK: @ trace item #0 - call to "forward"
|
||||
# CHECK: @ output #0
|
||||
# CHECK: ERROR: values mismatch
|
||||
|
@ -40,7 +40,7 @@ def MmModule_basic(module, tu: TestUtils):
|
|||
def main():
|
||||
config = TorchScriptTestConfig()
|
||||
results = run_tests(GLOBAL_TEST_REGISTRY, config)
|
||||
report_results(results, verbose=True)
|
||||
report_results(results, set(), verbose=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
def is_iree_enabled():
|
||||
try:
|
||||
import iree.runtime
|
||||
import iree.compiler
|
||||
except:
|
||||
return False
|
||||
return True
|
|
@ -0,0 +1,45 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import abc
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from mlir.ir import Module
|
||||
|
||||
# A type shared between the result of `NpcompBackend.compile` and the input
|
||||
# to `NpcompBackend.load`. Each backend will likely have a different definition
|
||||
# of this type.
|
||||
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||
|
||||
# A wrapper around a backend-specific loaded program representation
|
||||
# that uniformly translates the `x.method(...)` interface expected of
|
||||
# Torch modules into appropriate lower-level operations.
|
||||
Invoker = TypeVar('Invoker')
|
||||
|
||||
|
||||
class NpcompBackend(abc.ABC):
|
||||
"""The interface to an npcomp backend.
|
||||
"""
|
||||
@abc.abstractmethod
|
||||
def compile(self, module: Module) -> CompiledArtifact:
|
||||
"""Compile the provided MLIR module into a compiled artifact.
|
||||
|
||||
The module adheres to the npcomp backend contract
|
||||
(see the VerifyBackendContract pass).
|
||||
|
||||
The compiled artifact can be any type, but must be correctly
|
||||
interpreted by the `load` method.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self, artifact: CompiledArtifact) -> Invoker:
|
||||
"""Load the compiled artifact into a uniformly invokable form.
|
||||
|
||||
The compiled artifact is the result of a previous call to `compile`.
|
||||
|
||||
See the description of `Invoker` for the requirements on the returned
|
||||
type.
|
||||
"""
|
|
@ -5,6 +5,7 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.passmanager import *
|
||||
|
@ -12,8 +13,10 @@ from npcomp.compiler.utils import logging
|
|||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
|
||||
from .abc import NpcompBackend
|
||||
|
||||
__all__ = [
|
||||
"CompilerBackend",
|
||||
"IreeNpcompBackend",
|
||||
]
|
||||
|
||||
PREPARE_FOR_IREE_PASSES = (
|
||||
|
@ -34,6 +37,8 @@ class IreeModuleInvoker:
|
|||
|
||||
def invoke(*args):
|
||||
results = self._iree_module[function_name](*args)
|
||||
if isinstance(results, np.ndarray):
|
||||
return results
|
||||
if len(results) == 1:
|
||||
# De-tuple.
|
||||
return results[0]
|
||||
|
@ -58,7 +63,7 @@ class TorchIreeModuleInvoker(IreeModuleInvoker):
|
|||
return invoke
|
||||
|
||||
|
||||
class CompilerBackend:
|
||||
class IreeNpcompBackend(NpcompBackend):
|
||||
"""Main entry-point for the backend."""
|
||||
|
||||
def __init__(self):
|
||||
|
@ -67,9 +72,8 @@ class CompilerBackend:
|
|||
|
||||
def compile(self, imported_module: Module):
|
||||
"""Compiles an imported module, with a flat list of functions.
|
||||
The module is expected to be in "TCP + scalar code" form.
|
||||
TODO: More clearly define the backend contract. Generally this will
|
||||
extend to support globals, lists, and other stuff.
|
||||
The module is expected to conform to the npcomp backend contract.
|
||||
See the VerifyBackendContract pass for more details.
|
||||
|
||||
Args:
|
||||
imported_module: The MLIR module consisting of funcs in the torch
|
||||
|
@ -97,12 +101,13 @@ class CompilerBackend:
|
|||
# Backend.
|
||||
binary = ireec.compile_str(str(imported_module),
|
||||
target_backends=["dylib-llvm-aot"])
|
||||
iree_config = ireert.Config(driver_name="dylib")
|
||||
|
||||
iree_module = ireert.load_module(ireert.VmModule.from_flatbuffer(binary),
|
||||
config=iree_config)
|
||||
return iree_module
|
||||
return binary
|
||||
|
||||
def load(self, iree_module) -> TorchIreeModuleInvoker:
|
||||
"""Loads a compiled artifact into the runtime."""
|
||||
return TorchIreeModuleInvoker(iree_module)
|
||||
vm_module = ireert.VmModule.from_flatbuffer(iree_module)
|
||||
|
||||
iree_config = ireert.Config(driver_name="dylib")
|
||||
ctx = ireert.SystemContext(config=iree_config)
|
||||
ctx.add_vm_module(vm_module)
|
||||
return TorchIreeModuleInvoker(ctx.modules.module)
|
||||
|
|
|
@ -10,10 +10,11 @@ from mlir.ir import *
|
|||
from mlir.passmanager import *
|
||||
from npcomp.compiler.generic.backend import refjit as refjit_backend
|
||||
from npcomp.compiler.utils import logging
|
||||
from .abc import NpcompBackend
|
||||
|
||||
__all__ = [
|
||||
"is_enabled",
|
||||
"CompilerBackend",
|
||||
"RefjitNpcompBackend",
|
||||
]
|
||||
|
||||
# Re-export.
|
||||
|
@ -34,7 +35,7 @@ class TorchJitModuleInvoker(refjit_backend.JitModuleInvoker):
|
|||
return invoke
|
||||
|
||||
|
||||
class CompilerBackend:
|
||||
class RefjitNpcompBackend(NpcompBackend):
|
||||
"""Main entry-point for the backend."""
|
||||
|
||||
def __init__(self):
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
src_dir="$(realpath $(dirname $0)/..)"
|
||||
|
||||
cd "$src_dir"
|
||||
source .env
|
||||
python -m frontends.pytorch.e2e_testing.torchscript.main "$@"
|
Loading…
Reference in New Issue