# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from enum import Enum
from io import StringIO
import os
import sys
import tempfile
from typing import Union, List, Optional

import torch

from torch_mlir.passmanager import PassManager
from torch_mlir.ir import StringAttr


class TensorPlaceholder:
    """A class that represents a formal parameter of a given shape and dtype.

    This class can be constructed explicitly from a shape and dtype:
    ```python
    placeholder = TensorPlaceholder([3, 4], torch.float32)
    ```

    This class can also be constructed from a `torch.Tensor` which is already
    known to be a valid input to the function. In this case, a set of
    dynamic axes are allowed to be specified.
    ```python
    placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1])
    # Equivalent to `TensorPlaceholder([3, -1], torch.float32)`
    ```
    """

    def __init__(self, shape: List[int], dtype: torch.dtype):
        """Create a tensor with shape `shape` and dtype `dtype`.

        Args:
            shape: The shape of the tensor. A size of `-1` indicates that the
              dimension has an unknown size.
            dtype: The dtype of the tensor.
        """
        self.shape = shape
        self.dtype = dtype

    @staticmethod
    def like(tensor: torch.Tensor, dynamic_axes: Optional[List[int]] = None):
        """Create a tensor placeholder that is like the given tensor.

        Args:
            tensor: The tensor to create a placeholder for.
            dynamic_axes: A list of dynamic axes. If specified, the compiled
              module will allow those axes to be any size at runtime.

        Returns:
            A `TensorPlaceholder` with the dtype of `tensor` and its shape,
            except that every axis listed in `dynamic_axes` has size `-1`.
        """
        if dynamic_axes is None:
            dynamic_axes = []
        shape = [
            -1 if axis in dynamic_axes else dim
            for axis, dim in enumerate(tensor.shape)
        ]
        return TensorPlaceholder(shape, tensor.dtype)


def get_module_name_for_debug_dump(module):
    """Gets a name suitable for a debug dump.

    The name is read from the module's `torch.debug_module_name` attribute;
    a fixed fallback name is returned when that attribute is absent.

    The name is not guaranteed to be unique.
    """
    attributes = module.operation.attributes
    if "torch.debug_module_name" not in attributes:
        return "UnnammedModule"
    return StringAttr(attributes["torch.debug_module_name"]).value


class TorchMlirCompilerError(Exception):
    """Raised when a pass pipeline fails; the message carries a repro report."""

    pass


def run_pipeline_with_repro_report(
    module, pipeline: str, description: str, enable_ir_printing: bool = False
):
    """Runs `pipeline` on `module`, with a nice repro report if it fails.

    The module is lowered in place. While the pipeline runs, `sys.stderr` is
    replaced by an in-memory buffer whose contents are included in the report;
    the original `sys.stderr` is always restored before returning or raising.

    Args:
        module: The MLIR module to run the pipeline on.
        pipeline: Textual pass pipeline, as accepted by `PassManager.parse`.
        description: Human-readable description of the step, used as the
          first line of the error report. If empty, a default derived from
          the module's debug name is used.
        enable_ir_printing: If true, disables multithreading on the context
          and enables IR printing on the pass manager.

    Raises:
        TorchMlirCompilerError: If the pipeline fails. The pre-pipeline
          assembly of the module is written to the temp directory and the
          message explains how to reproduce the failure with
          `torch-mlir-opt`.
        Exception: If the module's assembly could not even be captured, the
          original exception is re-raised unchanged, since no repro file can
          be produced.
    """
    module_name = get_module_name_for_debug_dump(module)
    original_stderr = sys.stderr
    # Keep our own handle on the capture buffer so that the report does not
    # depend on `sys.stderr` still pointing at it when the handler runs.
    captured_stderr = StringIO()
    # Pre-initialized so the handler can tell whether the snapshot succeeded.
    asm_for_error_report = None
    try:
        sys.stderr = captured_stderr
        asm_for_error_report = module.operation.get_asm(
            large_elements_limit=10, enable_debug_info=True
        )
        # Lower module in place to make it ready for compiler backends.
        with module.context as ctx:
            # TODO(#3506): Passes can emit errors but not signal failure,
            # which causes a native assert.
            ctx.emit_error_diagnostics = True
            pm = PassManager.parse(pipeline)
            if enable_ir_printing:
                ctx.enable_multithreading(False)
                pm.enable_ir_printing()
            pm.run(module.operation)
    except Exception as e:
        if asm_for_error_report is None:
            # Snapshotting the module itself failed, so there is no IR to
            # write a repro from. Surface the real error instead of masking
            # it with an unbound-variable failure in this handler.
            raise
        # TODO: More robust.
        # - don't arbitrarily clutter up /tmp. When a test suite has many
        #   tests, this can be a big disk cost (also, /tmp/ is frequently a
        #   RAM fs, which increases worries about capacity).
        # - don't have colliding filenames (hard to do without cluttering
        #   up /tmp)
        # - if we do have colliding filenames, writes should at least
        #   avoid being racy.
        filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
        with open(filename, "w") as f:
            f.write(asm_for_error_report)
        debug_options = "-mlir-print-ir-after-all -mlir-disable-threading"
        # Put something descriptive here even if description is empty.
        description = description or f"{module_name} compile"

        message = f"""\
            {description} failed with the following diagnostics:
            {captured_stderr.getvalue()}

            python exception: {e}

            For Torch-MLIR developers, the error can be reproduced with:
            $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
            Add '{debug_options}' to get the IR dump for debugging purpose.
            """
        # Strip the source-level indentation from every line of the report.
        trimmed_message = "\n".join([m.lstrip() for m in message.split("\n")])
        raise TorchMlirCompilerError(trimmed_message) from None
    finally:
        sys.stderr = original_stderr


class OutputType(Enum):
    """The kinds of output that the lowering helpers in this file can produce."""

    # Output torch dialect in backend form. When converting from TorchDynamo,
    # this comes after some decomposition and reduce op variants passes are
    # applied to the raw torch dialect. When converting from TorchScript, this
    # comes after some cleanup passes which attempt to de-alias, decompose and infer shapes.
    # These should be roughly the same level of abstraction since those
    # steps are done within PyTorch itself when coming directly from Dynamo/FX.
    TORCH = "torch"

    # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
    # `arith` ops (and also `math` and `tm_tensor`). It can be thought of
    # as taking the `TORCH` output type and lowering it so that tensor
    # computations are done with `linalg`-on-tensors ops.
    LINALG_ON_TENSORS = "linalg-on-tensors"

    # This output type consists of `tosa` dialect ops. It can be thought of
    # as taking the `TORCH` output type and lowering it to TOSA.
    TOSA = "tosa"

    # This output type consists of `stablehlo` dialect ops. It can be thought of
    # as taking the `TORCH` output type and lowering it to StableHLO.
    STABLEHLO = "stablehlo"

    # Raw output of the JIT IR importer in the TorchScript frontend or that of
    # the FX IR importer in the TorchDynamo frontend. This is not expected to be useful
    # for end-users, but can be convenient for development or reporting bugs.
    RAW = "raw"

    @staticmethod
    def get(spec: Union[str, "OutputType"]) -> "OutputType":
        """Gets an OutputType from any of the allowed ways to specify one.

        Args:
            spec: An OutputType instance or the case-insensitive name of one of the
                enum values. Dashes in the name are treated as underscores.

        Returns:
            An OutputType instance.

        Raises:
            ValueError: If `spec` is a string that does not name a member.
        """
        if isinstance(spec, OutputType):
            return spec
        spec = spec.upper().replace("-", "_")
        if spec not in OutputType.__members__:
            raise ValueError(
                f"For output_type= argument, expected one of: "
                f"{', '.join(OutputType.__members__.keys())}"
            )
        return OutputType[spec]


def lower_mlir_module(verbose, output_type, module):
    """Lowers `module` from Torch backend IR to the requested output type.

    The module is lowered in place by the matching backend pipeline and the
    same module object is returned. For `OutputType.TORCH` no pipeline is run.

    Args:
        verbose: If true, prints the IR before lowering and, when a pipeline
          is run, again after lowering.
        output_type: The `OutputType` to lower to.
        module: The MLIR module holding Torch backend IR.

    Returns:
        `module`, after any lowering has been applied.

    Raises:
        TorchMlirCompilerError: If the lowering pipeline fails.
        Exception: If `output_type` is not one of the handled output types
          (this includes `OutputType.RAW`).
    """
    if verbose:
        print("\n====================")
        print("Torch Backend IR")
        print(module)

    if output_type == OutputType.TORCH:
        return module

    if output_type == OutputType.TOSA:
        run_pipeline_with_repro_report(
            module,
            "builtin.module(torch-backend-to-tosa-backend-pipeline)",
            "Lowering Torch Backend IR -> TOSA Backend IR",
        )
        if verbose:
            print("\n====================")
            print("TOSA Backend IR")
            print(module)
        return module

    if output_type == OutputType.LINALG_ON_TENSORS:
        run_pipeline_with_repro_report(
            module,
            "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
            "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
        )
        if verbose:
            print("\n====================")
            print("LINALG Backend IR")
            print(module)
        return module

    if output_type == OutputType.STABLEHLO:
        run_pipeline_with_repro_report(
            module,
            "builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
            "Lowering Torch Backend IR -> StableHLO Backend IR",
        )
        if verbose:
            print("\n====================")
            print("StableHLO Backend IR")
            print(module)
        return module

    raise Exception(f"Unknown OutputType: {output_type}")