Allow printing all IR in `torch_mlir.compile` (#2669)
This PR adds the `enable_ir_printing` option to `torch_mlir.compile`, which can be used to print the IR for all intermediate passes. When running the added test file via:

```shell
$ python test/python/compile.py 2> tiny.stderr
```

the resulting `tiny.stderr` file is about 700 KB.
parent 11cc92d4ab
commit 8328998172
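For orientation, here is a minimal sketch of the new option in use (the model, input shape, and script name are illustrative; the call matches the `compile` signature added in the first hunk below):

```python
import torch
import torch_mlir

# An illustrative model; any torch.nn.Module goes through the same path.
model = torch.nn.Sequential(torch.nn.Linear(20, 30))

# With enable_ir_printing=True, the IR surrounding every pass is dumped to
# stderr, so redirect stderr to a file when running the script, e.g.
#   python tinymodel.py 2> tinymodel.stderr
compiled = torch_mlir.compile(model,
                              torch.ones(1, 20),
                              output_type="linalg-on-tensors",
                              enable_ir_printing=True)
print(compiled)  # the final linalg-on-tensors module goes to stdout
```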
```diff
@@ -319,7 +319,8 @@ def compile(model: torch.nn.Module,
             backend_legal_ops: Optional[Sequence[str]] = None,
             extra_library: Iterable[Callable] = [],
             verbose: bool = False,
-            use_make_fx: bool = False):
+            use_make_fx: bool = False,
+            enable_ir_printing: bool = False):
     """Convert a PyTorch model to MLIR.

     Args:
```
```diff
@@ -348,7 +349,13 @@ def compile(model: torch.nn.Module,
             into the abstract interpretation library. See
             `docs/adding_abstract_interpretation_functions.md` for more info
             on the format the functions should have.
-        verbose: If true, print extra information about the conversion.
+        verbose: If true, print extra information about the conversion to
+            stdout.
+        enable_ir_printing: If true, print the IR before and after each pass to
+            stderr. This is equivalent to setting MLIR's `-print-ir-after-all`
+            flag. Note that this can easily generate many gigabytes of text,
+            so make sure to pipe stderr to a file (for example, run
+            `python tinymodel.py 2> tinymodel.stderr` on Linux).

     Returns:
         An MLIR module that contains the converted model in the specified
```
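For reference, `-print-ir-after-all` is the standard upstream MLIR flag the docstring alludes to. A hedged sketch of the command-line equivalent, assuming a built `torch-mlir-opt` binary and an `input.mlir` file (the pipeline here is illustrative):

```shell
$ torch-mlir-opt --mlir-disable-threading --print-ir-after-all \
    --pass-pipeline='builtin.module(canonicalize)' \
    input.mlir 2> dumps.stderr
```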
```diff
@@ -452,6 +459,7 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
         mb.module,
         f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
         "Lowering TorchScript IR -> Torch Backend IR",
+        enable_ir_printing=enable_ir_printing,
     )

     return _lower_mlir_module(verbose, output_type, mb.module)
```
```diff
@@ -27,7 +27,8 @@ class TorchMlirCompilerError(Exception):

 def run_pipeline_with_repro_report(module,
                                    pipeline: str,
-                                   description: str):
+                                   description: str,
+                                   enable_ir_printing: bool = False):
     """Runs `pipeline` on `module`, with a nice repro report if it fails."""
     module_name = get_module_name_for_debug_dump(module)
     try:
```
```diff
@@ -36,8 +37,11 @@ def run_pipeline_with_repro_report(module,
         asm_for_error_report = module.operation.get_asm(
             large_elements_limit=10, enable_debug_info=True)
         # Lower module in place to make it ready for compiler backends.
-        with module.context:
+        with module.context as ctx:
             pm = PassManager.parse(pipeline)
+            if enable_ir_printing:
+                ctx.enable_multithreading(False)
+                pm.enable_ir_printing()
             pm.run(module.operation)
     except Exception as e:
         # TODO: More robust.
```
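The `enable_multithreading(False)` call matters: MLIR's pass manager runs passes over nested operations in parallel, and the IR-printing instrumentation requires a single-threaded context to produce a coherent dump. A minimal standalone sketch of the same pattern, assuming the MLIR Python bindings bundled with torch-mlir and an illustrative module and pipeline:

```python
from torch_mlir.ir import Context, Module
from torch_mlir.passmanager import PassManager

with Context() as ctx:
    # Any parsable MLIR works here; this module is illustrative.
    module = Module.parse("module { func.func @f() { return } }")
    # IR printing needs a single-threaded pass manager, so turn off
    # multithreading on the context before enabling it.
    ctx.enable_multithreading(False)
    pm = PassManager.parse("builtin.module(canonicalize)")
    pm.enable_ir_printing()  # dump the IR around each pass to stderr
    pm.run(module.operation)
```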
```diff
@@ -0,0 +1,34 @@
+# RUN: %PYTHON -s %s 2>&1 | FileCheck %s
+
+import gc
+import sys
+import torch
+import torch_mlir
+
+
+def run_test(f):
+    print("TEST:", f.__name__, file=sys.stderr)
+    f()
+    gc.collect()
+
+
+class TinyModel(torch.nn.Module):
+    def __init__(self):
+        super(TinyModel, self).__init__()
+
+        self.linear = torch.nn.Linear(20, 30)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return x
+
+
+# CHECK-LABEL: TEST: test_enable_ir_printing
+@run_test
+def test_enable_ir_printing():
+    torch_mlir.compile(TinyModel(),
+                       torch.ones(1, 3, 20, 20),
+                       output_type="linalg-on-tensors",
+                       enable_ir_printing=True)
+# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize)
+# CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} {
```