mirror of https://github.com/llvm/torch-mlir
Make pytorch/backend/refjit.py a bit tidier
- Print out initial PyTorch IR. - Rename ambiguous "frontend IR" to "TCF IR". - Add newlines to prints - Rename FRONTEND_PASSES to TORCH_TO_TCF_PASSESpull/122/head
parent
32b2dc6ce7
commit
ec1336a8a3
|
@ -14,8 +14,12 @@ __all__ = [
|
|||
"CompilerBackend",
|
||||
]
|
||||
|
||||
FRONTEND_PASSES = ("func(aten-recognize-kernels)", "func(convert-aten-to-tcf)",
|
||||
"numpy-public-functions-to-tensor", "canonicalize")
|
||||
TORCH_TO_TCF_PASSES = (
|
||||
"func(aten-recognize-kernels)",
|
||||
"func(convert-aten-to-tcf)",
|
||||
"numpy-public-functions-to-tensor",
|
||||
"canonicalize",
|
||||
)
|
||||
|
||||
# Re-export.
|
||||
is_enabled = refjit_backend.is_enabled
|
||||
|
@ -43,11 +47,14 @@ class CompilerBackend:
|
|||
"""
|
||||
# TODO: Once transitioned to new Python API, don't reparse the module.
|
||||
with Context() as context:
|
||||
if self._debug:
|
||||
logging.debug("Initial PyTorch IR:\n{}", imported_module)
|
||||
|
||||
# Frontend.
|
||||
pm = PassManager.parse(",".join(FRONTEND_PASSES))
|
||||
pm = PassManager.parse(",".join(TORCH_TO_TCF_PASSES))
|
||||
pm.run(imported_module)
|
||||
if self._debug:
|
||||
logging.debug("Frontend IR:{}", imported_module)
|
||||
logging.debug("TCF IR:\n{}", imported_module)
|
||||
|
||||
# Backend.
|
||||
# Note that this is a separate pass manager purely to aid in debugging.
|
||||
|
@ -55,7 +62,7 @@ class CompilerBackend:
|
|||
self._refjit.build_backend_compilation_pipeline(pm)
|
||||
pm.run(imported_module)
|
||||
if self._debug:
|
||||
logging.debug("Backend IR:{}", imported_module)
|
||||
logging.debug("Backend IR:\n{}", imported_module)
|
||||
|
||||
jit_module = self._refjit.JITModule.from_compiled_module(
|
||||
imported_module, refjit_backend.get_runtime_libs())
|
||||
|
|
Loading…
Reference in New Issue