Add IREE support in TorchScript e2e tests.

- Add support for "expected failures" in test reporting. The new error
  reports look like
  [this](https://gist.github.com/silvasean/6ffd95e1d55302b699673da201da210d).
  - We will now be able to put these tests into CI, since the harness
    understand which tests are expected to pass and fail.
- Refactor RefBackendTestConfig to NpcompBackendTestConfig which
  supports both RefBackend and IREE.
- Add instructions for installing IREE dependencies (both from packages
  and for local builds of IREE)
- Add `tools/torchscript_e2e_test.sh` for invoking the e2e test
  harness (this makes invoking a bit easier, as it doesn't rely on a
  loose Python invocation).
pull/243/head
Sean Silva 2021-06-30 14:13:21 -07:00
parent 79928cd2dd
commit d5108b9dc1
22 changed files with 284 additions and 64 deletions

View File

@ -164,6 +164,37 @@ cd /src/mlir-npcomp
cmake --build /build/npcomp --target check-npcomp check-frontends-pytorch cmake --build /build/npcomp --target check-npcomp check-frontends-pytorch
``` ```
### IREE Backend (from IREE packages)
```shell
# We currently track and require the latest snapshot.
pip3 install iree-compiler-snapshot iree-runtime-snapshot -f https://github.com/google/iree/releases
# Run TorchScript E2E tests targeting IREE.
# Make sure to run "PyTorch Frontend" setup instructions first.
python frontends/pytorch/e2e_testing/torchscript/main.py --config=iree
```
### IREE Backend (from local IREE build)
This configuration is useful for iterating locally, as you can
poke/debug/rebuild things in IREE.
```shell
# Locally build IREE.
# See https://google.github.io/iree/building-from-source/getting-started/
# Make sure IREE is configured with `-DIREE_BUILD_PYTHON_BINDINGS=ON`.
echo 'PYTHONPATH="${PYTHONPATH}:/path/to/iree-build/bindings/python"' >> .env
# Run TorchScript E2E tests targeting IREE.
# Make sure to run "PyTorch Frontend" setup instructions first.
python frontends/pytorch/e2e_testing/torchscript/main.py --config=iree
```
### VSCode with a Docker Dev Image ### VSCode with a Docker Dev Image
#### Start a docker dev container based on our image #### Start a docker dev container based on our image

View File

@ -12,31 +12,39 @@ from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
# Available test configs. # Available test configs.
from torch_mlir.torchscript.e2e_test.configs import ( from torch_mlir.torchscript.e2e_test.configs import (
RefBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig NpcompBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
) )
from npcomp.compiler.pytorch.backend import is_iree_enabled
IREE_ENABLED = is_iree_enabled()
if IREE_ENABLED:
from npcomp.compiler.pytorch.backend.iree import IreeNpcompBackend
from npcomp.compiler.pytorch.backend.refjit import RefjitNpcompBackend
from .xfail_sets import XFAIL_SETS
# Import tests to register them in the global registry. # Import tests to register them in the global registry.
# TODO: Use a relative import. # Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
# That requires invoking this file as a "package" though, which makes it # this script.
# not possible to just do `python main.py`. Instead, it requires something from . import basic
# like `python -m tochscript.main` which is annoying because it can only from . import vision_models
# be run from a specific directory. from . import mlp
# TODO: Find out best practices for python "main" files. from . import batchnorm
import basic from . import quantized_models
import vision_models from . import elementwise
import mlp
import batchnorm
import quantized_models
import elementwise
def _get_argparse(): def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend']
if IREE_ENABLED:
config_choices += ['iree']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('--config', parser.add_argument('--config',
choices=['native_torch', 'torchscript', 'refbackend'], choices=config_choices,
default='refbackend', default='refbackend',
help=''' help=f'''
Meaning of options: Meaning of options:
"refbackend": run through npcomp's RefBackend. "refbackend": run through npcomp's RefBackend.
"iree"{'' if IREE_ENABLED else '(disabled)'}: run through npcomp's IREE backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
''') ''')
@ -54,7 +62,9 @@ def main():
# Find the selected config. # Find the selected config.
if args.config == 'refbackend': if args.config == 'refbackend':
config = RefBackendTestConfig() config = NpcompBackendTestConfig(RefjitNpcompBackend())
elif args.config == 'iree':
config = NpcompBackendTestConfig(IreeNpcompBackend())
elif args.config == 'native_torch': elif args.config == 'native_torch':
config = NativeTorchTestConfig() config = NativeTorchTestConfig()
elif args.config == 'torchscript': elif args.config == 'torchscript':
@ -78,7 +88,8 @@ def main():
results = run_tests(tests, config) results = run_tests(tests, config)
# Report the test results. # Report the test results.
report_results(results, args.verbose) failed = report_results(results, XFAIL_SETS[args.config], args.verbose)
sys.exit(1 if failed else 0)
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -0,0 +1,29 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# This file describes the sets of tests expected to fail for each config.
# This information is deliberately kept in a side table, rather than
# in-situ on the test, as a deliberate layering decision: tests should
# have unique keys to identify them and enable side tables of various kinds
# (this includes down into lower parts of the stack, where a side table
# might be used to keep more elaborate sets of testing configurations).
XFAIL_SETS = {}
# Lists of tests that fail to even reach the backends.
# These represent further work needed in npcomp to lower them properly
# to the backend contract.
_common_npcomp_lowering_xfails = {
'ResNet18Module_basic',
'QuantizedMLP_basic',
}
XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails
XFAIL_SETS['iree'] = _common_npcomp_lowering_xfails | {
# https://github.com/google/iree/issues/6368
'MmDagModule_basic',
'Mlp1LayerModule_basic',
'Mlp2LayerModule_basic',
}

View File

@ -21,7 +21,7 @@ with mb.capture_function("cos", [input]) as f:
result = torch.cos(input) result = torch.cos(input)
f.returns([result]) f.returns([result])
backend = refjit.CompilerBackend() backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module))) jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
logging.debug(f"Executing jit_module.cos") logging.debug(f"Executing jit_module.cos")

View File

@ -25,7 +25,7 @@ mb = torch_mlir.ModuleBuilder()
with mb.capture_function("test", [arg0, arg1]) as f: with mb.capture_function("test", [arg0, arg1]) as f:
f.returns([fun(arg0, arg1)]) f.returns([fun(arg0, arg1)])
backend = refjit.CompilerBackend() backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module))) jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1) test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)

View File

@ -22,7 +22,7 @@ with mb.capture_function("mm", [lhs, rhs]) as f:
result = torch.mm(lhs, rhs) result = torch.mm(lhs, rhs)
f.returns([result]) f.returns([result])
backend = refjit.CompilerBackend() backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module))) jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
test_utils.compare_outputs(torch.mm, jit_module.mm, lhs, rhs) test_utils.compare_outputs(torch.mm, jit_module.mm, lhs, rhs)

View File

@ -28,7 +28,7 @@ with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f:
result = mul_maximum(lhs, rhs, threshold, bias) result = mul_maximum(lhs, rhs, threshold, bias)
f.returns([result]) f.returns([result])
backend = refjit.CompilerBackend() backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module))) jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs, rhs, test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs, rhs,

View File

@ -26,7 +26,7 @@ mb = torch_mlir.ModuleBuilder()
with mb.capture_function("test", [arg0]) as f: with mb.capture_function("test", [arg0]) as f:
f.returns([fun(arg0)]) f.returns([fun(arg0)])
backend = refjit.CompilerBackend() backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module))) jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1) test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)

View File

@ -48,7 +48,7 @@ class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
mb.import_module(recursivescriptmodule._c, class_annotator) mb.import_module(recursivescriptmodule._c, class_annotator)
#mb.module.operation.print() #mb.module.operation.print()
backend = refjit.CompilerBackend() backend = iree.IreeNpcompBackend()
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module)) compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled) jit_module = backend.load(compiled)

View File

@ -40,7 +40,7 @@ class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
mb.import_module(recursivescriptmodule._c, class_annotator) mb.import_module(recursivescriptmodule._c, class_annotator)
#mb.module.operation.print() #mb.module.operation.print()
backend = refjit.CompilerBackend() backend = iree.IreeNpcompBackend()
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module)) compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled) jit_module = backend.load(compiled)

View File

@ -34,13 +34,13 @@ class_annotator.exportNone(recursivescriptmodule._c._type())
class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward']) class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward'])
class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [ class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
None, None,
([2, 3, -1], torch.float32) ([2, 3, -1], torch.float32, True)
]) ])
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, class_annotator) mb.import_module(recursivescriptmodule._c, class_annotator)
#mb.module.operation.print() #mb.module.operation.print()
backend = iree.CompilerBackend() backend = iree.IreeNpcompBackend()
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module)) compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled) jit_module = backend.load(compiled)

View File

@ -2,6 +2,6 @@
# See https://llvm.org/LICENSE.txt for license information. # See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .ref_backend import RefBackendTestConfig from .npcomp_backend import NpcompBackendTestConfig
from .native_torch import NativeTorchTestConfig from .native_torch import NativeTorchTestConfig
from .torchscript import TorchScriptTestConfig from .torchscript import TorchScriptTestConfig

View File

@ -14,15 +14,32 @@ from mlir.passmanager import PassManager
import torch_mlir import torch_mlir
from npcomp.compiler.pytorch.backend import refjit from npcomp.compiler.pytorch.backend import refjit
from npcomp.compiler.pytorch.backend.abc import NpcompBackend
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
from torch_mlir.torchscript.annotations import extract_annotations from torch_mlir.torchscript.annotations import extract_annotations
class PrettyErrorReportForIrOperation(object):
def __init__(self, module, module_name_for_ir_dump: str):
sys.stderr = StringIO()
self.filename_for_ir_dump = os.path.join(tempfile.gettempdir(),
module_name_for_ir_dump + '.mlir')
self.asm_for_error_report = module.get_asm(
large_elements_limit=10, enable_debug_info=True)
def __enter__(self):
pass
def __exit__(self, type, value, traceback):
with open(self.filename_for_ir_dump, 'w') as f:
f.write(self.asm_for_error_report)
class RefBackendTestConfig(TestConfig): class NpcompBackendTestConfig(TestConfig):
"""TestConfig that just runs the torch.nn.Module through RefBackend.""" """Base class for TestConfig's that are implemented with npcomp.
def __init__(self):
This class handles all the common lowering that npcomp does before reaching
its backends.
"""
def __init__(self, backend: NpcompBackend):
super().__init__() super().__init__()
self.backend = refjit.CompilerBackend() self.backend = backend
def compile(self, program: torch.nn.Module) -> Any: def compile(self, program: torch.nn.Module) -> Any:
mb = torch_mlir.ModuleBuilder() mb = torch_mlir.ModuleBuilder()
@ -79,14 +96,36 @@ $ npcomp-opt -{pipeline_str} {filename}
""") from None """) from None
finally: finally:
sys.stderr = sys.__stderr__ sys.stderr = sys.__stderr__
try:
sys.stderr = StringIO()
asm_for_error_report = mb.module.operation.get_asm(
large_elements_limit=10, enable_debug_info=True)
return self.backend.compile(mb.module) return self.backend.compile(mb.module)
except Exception as e:
filename = os.path.join(tempfile.gettempdir(),
scripted.original_name + '.mlir')
with open(filename, 'w') as f:
f.write(asm_for_error_report)
raise Exception(f"""
NPCOMP Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
## Exception:
{e}
## Stderr:
{sys.stderr.getvalue()}
## Input IR has been saved in {filename}
""") from None
finally:
sys.stderr = sys.__stderr__
def run(self, artifact: Any, trace: Trace) -> Trace: def run(self, artifact: Any, trace: Trace) -> Trace:
jit_module = self.backend.load(artifact) backend_module = self.backend.load(artifact)
result: Trace = [] result: Trace = []
for item in trace: for item in trace:
numpy_inputs = [t.numpy() for t in item.inputs] numpy_inputs = [t.numpy() for t in item.inputs]
outputs = getattr(jit_module, item.symbol)(*numpy_inputs) outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
if isinstance(outputs, np.ndarray): if isinstance(outputs, np.ndarray):
outputs = [outputs] outputs = [outputs]
torch_outputs = [torch.tensor(ndarray) for ndarray in outputs] torch_outputs = [torch.tensor(ndarray) for ndarray in outputs]

View File

@ -5,8 +5,9 @@
Utilities for reporting the results of the test framework. Utilities for reporting the results of the test framework.
""" """
from typing import List, Optional from typing import List, Optional, Set
import collections
import io import io
import textwrap import textwrap
@ -70,7 +71,7 @@ class ValueReport:
assert self.failed assert self.failed
if self.value.size() != self.golden_value.size(): if self.value.size() != self.golden_value.size():
return self.context.format_error( return self.context.format_error(
f'tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}' f'tensor shape mismatch: got {self.value.size()!r}, expected {self.golden_value.size()!r}'
) )
f = io.StringIO() f = io.StringIO()
p = lambda *x: print(*x, file=f) p = lambda *x: print(*x, file=f)
@ -167,17 +168,60 @@ class SingleTestReport:
return f.getvalue() return f.getvalue()
def report_results(results: List[TestResult], verbose: bool = False): def report_results(results: List[TestResult],
"""Provide a basic error report summarizing various TestResult's. expected_failures: Set[str],
verbose: bool = False):
"""Print a basic error report summarizing various TestResult's.
This report uses the PASS/FAIL/XPASS/XFAIL nomenclature of LLVM's
"lit" testing utility. See
https://llvm.org/docs/CommandGuide/lit.html#test-status-results
The `expected_failures` set should contain the names of tests
(according to their `unique_name`) which are expected to fail.
The overall passing/failing status of the report requires these to fail
in order to succeed (this catches cases where things suddenly
start working).
If `verbose` is True, then provide an explanation of what failed. If `verbose` is True, then provide an explanation of what failed.
Returns True if the run resulted in any unexpected pass/fail behavior.
Otherwise False.
""" """
summary = collections.Counter()
for result in results: for result in results:
report = SingleTestReport(result, ErrorContext.empty()) report = SingleTestReport(result, ErrorContext.empty())
expected_failure = result.unique_name in expected_failures
if expected_failure:
if report.failed:
error_str = ''
if verbose:
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
print(f'XFAIL - "{result.unique_name}"' + error_str)
summary['XFAIL'] += 1
else:
print(f'XPASS - "{result.unique_name}"')
summary['XPASS'] += 1
else:
if not report.failed: if not report.failed:
print(f'SUCCESS - "{result.unique_name}"') print(f'PASS - "{result.unique_name}"')
summary['PASS'] += 1
else: else:
error_str = '' error_str = ''
if verbose: if verbose:
error_str = '\n' + textwrap.indent(report.error_str(), ' ') error_str = '\n' + textwrap.indent(report.error_str(), ' ')
print(f'FAILURE - "{result.unique_name}"' + error_str) print(f'FAIL - "{result.unique_name}"' + error_str)
summary['FAIL'] += 1
# Print a summary for easy scanning.
print('\nSummary:')
KEY_MEANINGS = {
'PASS': 'Passed',
'FAIL': 'Failed',
'XFAIL': 'Expectedly Failed',
'XPASS': 'Unexpectedly Passed',
}
for key in ['PASS', 'FAIL', 'XFAIL', 'XPASS']:
if summary[key]:
print(f' {KEY_MEANINGS[key]}: {summary[key]}')
return summary['FAIL'] != 0 or summary['XPASS'] != 0

View File

@ -21,13 +21,13 @@ class MmModule(torch.nn.Module):
# TODO: Refine messages. # TODO: Refine messages.
# CHECK: SUCCESS - "MmModule_basic" # CHECK: PASS - "MmModule_basic"
@register_test_case(module_factory=lambda: MmModule()) @register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils): def MmModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4)) module.forward(tu.rand(4, 4), tu.rand(4, 4))
# CHECK: SUCCESS - "MmModule_basic2" # CHECK: PASS - "MmModule_basic2"
@register_test_case(module_factory=lambda: MmModule()) @register_test_case(module_factory=lambda: MmModule())
def MmModule_basic2(module, tu: TestUtils): def MmModule_basic2(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4)) module.forward(tu.rand(4, 4), tu.rand(4, 4))
@ -36,7 +36,7 @@ def MmModule_basic2(module, tu: TestUtils):
def main(): def main():
config = TorchScriptTestConfig() config = TorchScriptTestConfig()
results = run_tests(GLOBAL_TEST_REGISTRY, config) results = run_tests(GLOBAL_TEST_REGISTRY, config)
report_results(results) report_results(results, set())
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -25,7 +25,7 @@ class MmModule(torch.nn.Module):
return 3 return 3
# CHECK: FAILURE - "MmModule_basic" # CHECK: FAIL - "MmModule_basic"
# CHECK: compilation error # CHECK: compilation error
# Assume that the diagnostic from the TorchScript compiler will at least contain # Assume that the diagnostic from the TorchScript compiler will at least contain
# the offending "return 3". # the offending "return 3".
@ -38,7 +38,7 @@ def MmModule_basic(module, tu: TestUtils):
def main(): def main():
config = TorchScriptTestConfig() config = TorchScriptTestConfig()
results = run_tests(GLOBAL_TEST_REGISTRY, config) results = run_tests(GLOBAL_TEST_REGISTRY, config)
report_results(results, verbose=True) report_results(results, set(), verbose=True)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -26,7 +26,7 @@ class MmModule(torch.nn.Module):
# TODO: Refine error messages. # TODO: Refine error messages.
# CHECK: FAILURE - "MmModule_basic" # CHECK: FAIL - "MmModule_basic"
# CHECK: @ trace item #0 - call to "forward" # CHECK: @ trace item #0 - call to "forward"
# CHECK: @ output #0 # CHECK: @ output #0
# CHECK: ERROR: values mismatch # CHECK: ERROR: values mismatch
@ -40,7 +40,7 @@ def MmModule_basic(module, tu: TestUtils):
def main(): def main():
config = TorchScriptTestConfig() config = TorchScriptTestConfig()
results = run_tests(GLOBAL_TEST_REGISTRY, config) results = run_tests(GLOBAL_TEST_REGISTRY, config)
report_results(results, verbose=True) report_results(results, set(), verbose=True)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -0,0 +1,7 @@
def is_iree_enabled():
try:
import iree.runtime
import iree.compiler
except:
return False
return True

View File

@ -0,0 +1,45 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import abc
from typing import TypeVar
import torch
from mlir.ir import Module
# A type shared between the result of `NpcompBackend.compile` and the input
# to `NpcompBackend.load`. Each backend will likely have a different definition
# of this type.
CompiledArtifact = TypeVar('CompiledArtifact')
# A wrapper around a backend-specific loaded program representation
# that uniformly translates the `x.method(...)` interface expected of
# Torch modules into appropriate lower-level operations.
Invoker = TypeVar('Invoker')
class NpcompBackend(abc.ABC):
"""The interface to an npcomp backend.
"""
@abc.abstractmethod
def compile(self, module: Module) -> CompiledArtifact:
"""Compile the provided MLIR module into a compiled artifact.
The module adheres to the npcomp backend contract
(see the VerifyBackendContract pass).
The compiled artifact can be any type, but must be correctly
interpreted by the `load` method.
"""
@abc.abstractmethod
def load(self, artifact: CompiledArtifact) -> Invoker:
"""Load the compiled artifact into a uniformly invokable form.
The compiled artifact is the result of a previous call to `compile`.
See the description of `Invoker` for the requirements on the returned
type.
"""

View File

@ -5,6 +5,7 @@
import os import os
import torch import torch
import numpy as np
from mlir.ir import * from mlir.ir import *
from mlir.passmanager import * from mlir.passmanager import *
@ -12,8 +13,10 @@ from npcomp.compiler.utils import logging
import iree.runtime as ireert import iree.runtime as ireert
import iree.compiler as ireec import iree.compiler as ireec
from .abc import NpcompBackend
__all__ = [ __all__ = [
"CompilerBackend", "IreeNpcompBackend",
] ]
PREPARE_FOR_IREE_PASSES = ( PREPARE_FOR_IREE_PASSES = (
@ -34,6 +37,8 @@ class IreeModuleInvoker:
def invoke(*args): def invoke(*args):
results = self._iree_module[function_name](*args) results = self._iree_module[function_name](*args)
if isinstance(results, np.ndarray):
return results
if len(results) == 1: if len(results) == 1:
# De-tuple. # De-tuple.
return results[0] return results[0]
@ -58,7 +63,7 @@ class TorchIreeModuleInvoker(IreeModuleInvoker):
return invoke return invoke
class CompilerBackend: class IreeNpcompBackend(NpcompBackend):
"""Main entry-point for the backend.""" """Main entry-point for the backend."""
def __init__(self): def __init__(self):
@ -67,9 +72,8 @@ class CompilerBackend:
def compile(self, imported_module: Module): def compile(self, imported_module: Module):
"""Compiles an imported module, with a flat list of functions. """Compiles an imported module, with a flat list of functions.
The module is expected to be in "TCP + scalar code" form. The module is expected to conform to the npcomp backend contract.
TODO: More clearly define the backend contract. Generally this will See the VerifyBackendContract pass for more details.
extend to support globals, lists, and other stuff.
Args: Args:
imported_module: The MLIR module consisting of funcs in the torch imported_module: The MLIR module consisting of funcs in the torch
@ -97,12 +101,13 @@ class CompilerBackend:
# Backend. # Backend.
binary = ireec.compile_str(str(imported_module), binary = ireec.compile_str(str(imported_module),
target_backends=["dylib-llvm-aot"]) target_backends=["dylib-llvm-aot"])
iree_config = ireert.Config(driver_name="dylib") return binary
iree_module = ireert.load_module(ireert.VmModule.from_flatbuffer(binary),
config=iree_config)
return iree_module
def load(self, iree_module) -> TorchIreeModuleInvoker: def load(self, iree_module) -> TorchIreeModuleInvoker:
"""Loads a compiled artifact into the runtime.""" """Loads a compiled artifact into the runtime."""
return TorchIreeModuleInvoker(iree_module) vm_module = ireert.VmModule.from_flatbuffer(iree_module)
iree_config = ireert.Config(driver_name="dylib")
ctx = ireert.SystemContext(config=iree_config)
ctx.add_vm_module(vm_module)
return TorchIreeModuleInvoker(ctx.modules.module)

View File

@ -10,10 +10,11 @@ from mlir.ir import *
from mlir.passmanager import * from mlir.passmanager import *
from npcomp.compiler.generic.backend import refjit as refjit_backend from npcomp.compiler.generic.backend import refjit as refjit_backend
from npcomp.compiler.utils import logging from npcomp.compiler.utils import logging
from .abc import NpcompBackend
__all__ = [ __all__ = [
"is_enabled", "is_enabled",
"CompilerBackend", "RefjitNpcompBackend",
] ]
# Re-export. # Re-export.
@ -34,7 +35,7 @@ class TorchJitModuleInvoker(refjit_backend.JitModuleInvoker):
return invoke return invoke
class CompilerBackend: class RefjitNpcompBackend(NpcompBackend):
"""Main entry-point for the backend.""" """Main entry-point for the backend."""
def __init__(self): def __init__(self):

View File

@ -0,0 +1,8 @@
#!/bin/bash
set -euo pipefail
src_dir="$(realpath $(dirname $0)/..)"
cd "$src_dir"
source .env
python -m frontends.pytorch.e2e_testing.torchscript.main "$@"