Allow printing all IR in `torch_mlir.compile` (#2669)

This PR adds the `enable_ir_printing` option to `torch_mlir.compile`,
which can be used to print the IR for all intermediate passes.

When running the added test file via:
```shell
$ python test/python/compile.py 2> tiny.stderr
```
the file `tiny.stderr` is about 700 KB.
pull/2411/merge
Rik Huijzer 2023-12-20 22:08:21 +01:00 committed by GitHub
parent 11cc92d4ab
commit 8328998172
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 4 deletions

View File

@ -319,7 +319,8 @@ def compile(model: torch.nn.Module,
backend_legal_ops: Optional[Sequence[str]] = None,
extra_library: Iterable[Callable] = [],
verbose: bool = False,
use_make_fx: bool = False):
use_make_fx: bool = False,
enable_ir_printing: bool = False):
"""Convert a PyTorch model to MLIR.
Args:
@ -348,7 +349,13 @@ def compile(model: torch.nn.Module,
into the abstract interpretation library. See
`docs/adding_abstract_interpretation_functions.md` for more info
on the format the functions should have.
verbose: If true, print extra information about the conversion.
verbose: If true, print extra information about the conversion to
stdout.
enable_ir_printing: If true, print the IR before and after each pass to
stderr. This is equivalent to setting MLIR's `-print-ir-after-all`
flag. Note that this can easily generate many gigabytes of text,
so make sure to pipe stderr to a file (for example, run
`python tinymodel.py 2> tinymodel.stderr` on Linux).
Returns:
An MLIR module that contains the converted model in the specified
@ -452,6 +459,7 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
mb.module,
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
"Lowering TorchScript IR -> Torch Backend IR",
enable_ir_printing=enable_ir_printing,
)
return _lower_mlir_module(verbose, output_type, mb.module)

View File

@ -27,7 +27,8 @@ class TorchMlirCompilerError(Exception):
def run_pipeline_with_repro_report(module,
pipeline: str,
description: str):
description: str,
enable_ir_printing: bool = False):
"""Runs `pipeline` on `module`, with a nice repro report if it fails."""
module_name = get_module_name_for_debug_dump(module)
try:
@ -36,8 +37,11 @@ def run_pipeline_with_repro_report(module,
asm_for_error_report = module.operation.get_asm(
large_elements_limit=10, enable_debug_info=True)
# Lower module in place to make it ready for compiler backends.
with module.context:
with module.context as ctx:
pm = PassManager.parse(pipeline)
if enable_ir_printing:
ctx.enable_multithreading(False)
pm.enable_ir_printing()
pm.run(module.operation)
except Exception as e:
# TODO: More robust.

View File

@ -0,0 +1,34 @@
# RUN: %PYTHON -s %s 2>&1 | FileCheck %s
import gc
import sys
import torch
import torch_mlir
def run_test(f):
print("TEST:", f.__name__, file=sys.stderr)
f()
gc.collect()
class TinyModel(torch.nn.Module):
def __init__(self):
super(TinyModel, self).__init__()
self.linear = torch.nn.Linear(20, 30)
def forward(self, x):
x = self.linear(x)
return x
# CHECK-LABEL: TEST: test_enable_ir_printing
@run_test
def test_enable_ir_printing():
torch_mlir.compile(TinyModel(),
torch.ones(1, 3, 20, 20),
output_type="linalg-on-tensors",
enable_ir_printing=True)
# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize)
# CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} {