Make pytorch/backend/refjit.py a bit tidier

- Print out initial PyTorch IR.
- Rename ambiguous "frontend IR" to "TCF IR".
- Add newlines to prints
- Rename FRONTEND_PASSES to TORCH_TO_TCF_PASSES
pull/122/head
Sean Silva 2020-11-20 15:07:34 -08:00
parent 32b2dc6ce7
commit ec1336a8a3
1 changed files with 12 additions and 5 deletions

View File

@ -14,8 +14,12 @@ __all__ = [
"CompilerBackend",
]
FRONTEND_PASSES = ("func(aten-recognize-kernels)", "func(convert-aten-to-tcf)",
"numpy-public-functions-to-tensor", "canonicalize")
TORCH_TO_TCF_PASSES = (
"func(aten-recognize-kernels)",
"func(convert-aten-to-tcf)",
"numpy-public-functions-to-tensor",
"canonicalize",
)
# Re-export.
is_enabled = refjit_backend.is_enabled
@ -43,11 +47,14 @@ class CompilerBackend:
"""
# TODO: Once transitioned to new Python API, don't reparse the module.
with Context() as context:
if self._debug:
logging.debug("Initial PyTorch IR:\n{}", imported_module)
# Frontend.
pm = PassManager.parse(",".join(FRONTEND_PASSES))
pm = PassManager.parse(",".join(TORCH_TO_TCF_PASSES))
pm.run(imported_module)
if self._debug:
logging.debug("Frontend IR:{}", imported_module)
logging.debug("TCF IR:\n{}", imported_module)
# Backend.
# Note that this is a separate pass manager purely to aid in debugging.
@ -55,7 +62,7 @@ class CompilerBackend:
self._refjit.build_backend_compilation_pipeline(pm)
pm.run(imported_module)
if self._debug:
logging.debug("Backend IR:{}", imported_module)
logging.debug("Backend IR:\n{}", imported_module)
jit_module = self._refjit.JITModule.from_compiled_module(
imported_module, refjit_backend.get_runtime_libs())