# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from io import StringIO
import os
import sys
import tempfile

from torch_mlir.passmanager import PassManager
from torch_mlir.ir import StringAttr


def get_module_name_for_debug_dump(module):
    """Gets a name suitable for a debug dump.

    The name is not guaranteed to be unique.

    Args:
        module: The module whose top-level operation attributes are inspected.

    Returns:
        The string value of the module's "torch.debug_module_name" attribute,
        or a fixed placeholder name when that attribute is not present.
    """
    attributes = module.operation.attributes
    if "torch.debug_module_name" not in attributes:
        # NOTE: The spelling of this placeholder is kept as-is on purpose; it
        # ends up in the repro filename that existing tooling may look for.
        return "UnnammedModule"
    return StringAttr(attributes["torch.debug_module_name"]).value


def run_pipeline_with_repro_report(module, pipeline: str, description: str):
    """Runs `pipeline` on `module`, with a nice repro report if it fails.

    While the pipeline runs, `sys.stderr` is temporarily replaced by an
    in-memory buffer so that anything written to it can be quoted in the
    error report. The original `sys.stderr` is always restored on exit.

    Args:
        module: The module to run the pipeline on. It is modified in place.
        pipeline: Textual pass pipeline, handed to `PassManager.parse`.
        description: Human-readable name of this step, used as the first
            line of the error report.

    Raises:
        Exception: If the pipeline fails. The message contains the captured
            stderr output and a `torch-mlir-opt` command line that replays
            the pipeline on a dump of the input IR written to the system
            temporary directory. If the input IR could not even be printed,
            the original exception is propagated unchanged instead.
    """
    module_name = get_module_name_for_debug_dump(module)
    # Save the real stderr *before* entering the `try`, so the `finally`
    # clause can never refer to an unbound name.
    original_stderr = sys.stderr
    captured_stderr = StringIO()
    # Stays None until the pre-pipeline IR snapshot has been taken.
    asm_for_error_report = None
    try:
        sys.stderr = captured_stderr
        # Snapshot the IR before running the pipeline: the passes mutate the
        # module in place, and the repro file must contain the *input* IR.
        asm_for_error_report = module.operation.get_asm(
            large_elements_limit=10, enable_debug_info=True)
        # Lower module in place to make it ready for compiler backends.
        with module.context:
            pm = PassManager.parse(pipeline)
            pm.run(module)
    except Exception:
        if asm_for_error_report is None:
            # Printing the module itself failed, so there is nothing to put
            # in a repro file. Surface the real error rather than failing
            # below while trying to write a snapshot that does not exist.
            raise
        # TODO: More robust.
        # - don't arbitrarily clutter up /tmp. When a test suite has many
        #   tests, this can be a big disk cost (also, /tmp/ is frequently a
        #   RAM fs, which increases worries about capacity).
        # - don't have colliding filenames (hard to do without cluttering
        #   up /tmp)
        # - if we do have colliding filenames, writes should at least
        #   avoid being racy.
        filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
        with open(filename, 'w') as f:
            f.write(asm_for_error_report)
        debug_options = "-mlir-print-ir-after-all -mlir-disable-threading"
        # `from None` drops the original traceback; the captured diagnostics
        # in the message are what is reported instead.
        raise Exception(f"""
{description} failed with the following diagnostics:
{captured_stderr.getvalue()}

Error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
Add '{debug_options}' to get the IR dump for debugging purpose.
""") from None
    finally:
        sys.stderr = original_stderr