From 832899817210ce506e9be9888cb2f7d2a5b59630 Mon Sep 17 00:00:00 2001 From: Rik Huijzer Date: Wed, 20 Dec 2023 22:08:21 +0100 Subject: [PATCH] Allow printing all IR in `torch_mlir.compile` (#2669) This PR adds the `enable_ir_printing` option to `torch_mlir.compile`, which can be used to print the IR for all intermediate passes. When running the added test file via: ```shell $ python test/python/compile.py 2> tiny.stderr ``` the file `tiny.stderr` is about 700 KB. --- projects/pt1/python/torch_mlir/__init__.py | 12 +++++-- .../pt1/python/torch_mlir/compiler_utils.py | 8 +++-- test/python/compile.py | 34 +++++++++++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) create mode 100644 test/python/compile.py diff --git a/projects/pt1/python/torch_mlir/__init__.py b/projects/pt1/python/torch_mlir/__init__.py index 8bbcce994..1cf1aa0e0 100644 --- a/projects/pt1/python/torch_mlir/__init__.py +++ b/projects/pt1/python/torch_mlir/__init__.py @@ -319,7 +319,8 @@ def compile(model: torch.nn.Module, backend_legal_ops: Optional[Sequence[str]] = None, extra_library: Iterable[Callable] = [], verbose: bool = False, - use_make_fx: bool = False): + use_make_fx: bool = False, + enable_ir_printing: bool = False): """Convert a PyTorch model to MLIR. Args: @@ -348,7 +349,13 @@ def compile(model: torch.nn.Module, into the abstract interpretation library. See `docs/adding_abstract_interpretation_functions.md` for more info on the format the functions should have. - verbose: If true, print extra information about the conversion. + verbose: If true, print extra information about the conversion to + stdout. + enable_ir_printing: If true, print the IR before and after each pass to + stderr. This is equivalent to setting MLIR's `-print-ir-after-all` + flag. Note that this can easily generate many gigabytes of text, + so make sure to pipe stderr to a file (for example, run + `python tinymodel.py 2> tinymodel.stderr` on Linux). Returns: An MLIR module that contains the converted model in the specified @@ -452,6 +459,7 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: mb.module, f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})", "Lowering TorchScript IR -> Torch Backend IR", + enable_ir_printing=enable_ir_printing, ) return _lower_mlir_module(verbose, output_type, mb.module) diff --git a/projects/pt1/python/torch_mlir/compiler_utils.py b/projects/pt1/python/torch_mlir/compiler_utils.py index 56e250e16..3a64473de 100644 --- a/projects/pt1/python/torch_mlir/compiler_utils.py +++ b/projects/pt1/python/torch_mlir/compiler_utils.py @@ -27,7 +27,8 @@ class TorchMlirCompilerError(Exception): def run_pipeline_with_repro_report(module, pipeline: str, - description: str): + description: str, + enable_ir_printing: bool = False): """Runs `pipeline` on `module`, with a nice repro report if it fails.""" module_name = get_module_name_for_debug_dump(module) try: @@ -36,8 +37,11 @@ def run_pipeline_with_repro_report(module, asm_for_error_report = module.operation.get_asm( large_elements_limit=10, enable_debug_info=True) # Lower module in place to make it ready for compiler backends. - with module.context: + with module.context as ctx: pm = PassManager.parse(pipeline) + if enable_ir_printing: + ctx.enable_multithreading(False) + pm.enable_ir_printing() pm.run(module.operation) except Exception as e: # TODO: More robust. diff --git a/test/python/compile.py b/test/python/compile.py new file mode 100644 index 000000000..fc2917e9c --- /dev/null +++ b/test/python/compile.py @@ -0,0 +1,34 @@ +# RUN: %PYTHON -s %s 2>&1 | FileCheck %s + +import gc +import sys +import torch +import torch_mlir + + +def run_test(f): + print("TEST:", f.__name__, file=sys.stderr) + f() + gc.collect() + + +class TinyModel(torch.nn.Module): + def __init__(self): + super(TinyModel, self).__init__() + + self.linear = torch.nn.Linear(20, 30) + + def forward(self, x): + x = self.linear(x) + return x + + +# CHECK-LABEL: TEST: test_enable_ir_printing +@run_test +def test_enable_ir_printing(): + torch_mlir.compile(TinyModel(), + torch.ones(1, 3, 20, 20), + output_type="linalg-on-tensors", + enable_ir_printing=True) +# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) +# CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} {