mirror of https://github.com/llvm/torch-mlir
Extract the Python APIs in the pt1 dir back to the root (#3237)
parent
9a12a093a6
commit
944a6df611
|
@ -20,7 +20,6 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
|
|||
SOURCES
|
||||
torchscript.py
|
||||
_dynamo_fx_importer.py
|
||||
compiler_utils.py
|
||||
dynamo.py
|
||||
_version.py
|
||||
)
|
||||
|
|
|
@ -1,75 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from io import StringIO
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from torch_mlir.passmanager import PassManager
|
||||
from torch_mlir.ir import StringAttr
|
||||
|
||||
|
||||
def get_module_name_for_debug_dump(module):
|
||||
"""Gets a name suitable for a debug dump.
|
||||
|
||||
The name is not guaranteed to be unique.
|
||||
"""
|
||||
if not "torch.debug_module_name" in module.operation.attributes:
|
||||
return "UnnammedModule"
|
||||
return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
|
||||
|
||||
|
||||
class TorchMlirCompilerError(Exception):
|
||||
pass
|
||||
|
||||
def run_pipeline_with_repro_report(module,
|
||||
pipeline: str,
|
||||
description: str,
|
||||
enable_ir_printing: bool = False):
|
||||
"""Runs `pipeline` on `module`, with a nice repro report if it fails."""
|
||||
module_name = get_module_name_for_debug_dump(module)
|
||||
try:
|
||||
original_stderr = sys.stderr
|
||||
sys.stderr = StringIO()
|
||||
asm_for_error_report = module.operation.get_asm(
|
||||
large_elements_limit=10, enable_debug_info=True)
|
||||
# Lower module in place to make it ready for compiler backends.
|
||||
with module.context as ctx:
|
||||
pm = PassManager.parse(pipeline)
|
||||
if enable_ir_printing:
|
||||
ctx.enable_multithreading(False)
|
||||
pm.enable_ir_printing()
|
||||
pm.run(module.operation)
|
||||
except Exception as e:
|
||||
# TODO: More robust.
|
||||
# - don't arbitrarily clutter up /tmp. When a test suite has many
|
||||
# tests, this can be a big disk cost (also, /tmp/ is frequently a
|
||||
# RAM fs, which increases worries about capacity).
|
||||
# - don't have colliding filenames (hard to do without cluttering
|
||||
# up /tmp)
|
||||
# - if we do have have colliding filenames, writes should at least
|
||||
# avoid being racy.
|
||||
filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
|
||||
with open(filename, 'w') as f:
|
||||
f.write(asm_for_error_report)
|
||||
debug_options="-mlir-print-ir-after-all -mlir-disable-threading"
|
||||
# Put something descriptive here even if description is empty.
|
||||
description = description or f"{module_name} compile"
|
||||
|
||||
message = f"""\
|
||||
{description} failed with the following diagnostics:
|
||||
{sys.stderr.getvalue()}
|
||||
|
||||
python exception: {e}
|
||||
|
||||
For Torch-MLIR developers, the error can be reproduced with:
|
||||
$ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
|
||||
Add '{debug_options}' to get the IR dump for debugging purpose.
|
||||
"""
|
||||
trimmed_message = '\n'.join([m.lstrip() for m in message.split('\n')])
|
||||
raise TorchMlirCompilerError(trimmed_message) from None
|
||||
finally:
|
||||
sys.stderr = original_stderr
|
|
@ -17,65 +17,15 @@ import torch.fx
|
|||
from torch_mlir.dynamo import _get_decomposition_table
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
from .compiler_utils import run_pipeline_with_repro_report
|
||||
from torch_mlir.compiler_utils import (
|
||||
run_pipeline_with_repro_report,
|
||||
OutputType,
|
||||
lower_mlir_module
|
||||
)
|
||||
from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||
from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library
|
||||
|
||||
|
||||
class OutputType(Enum):
|
||||
"""The kind of output that `torchscript.compile` can produce.
|
||||
|
||||
In MLIR terminology, this describes the mix of dialects that will be
|
||||
produced by the conversion process.
|
||||
|
||||
In user-facing API's, this type can always be passed interchangeably with an
|
||||
appropriate string specifying the output type. The allowed strings are
|
||||
the set of enum vales, allowed to be case insensitive and with `-` allowed
|
||||
in place of `_`. The `OutputType.get` static method can be used to convert
|
||||
from a string to an `OutputType` instance.
|
||||
"""
|
||||
|
||||
# This output type consists of `torch` dialect ops that have been converted
|
||||
# maximally to value semantics, decomposed, and shapes have been inferred.
|
||||
TORCH = "torch"
|
||||
|
||||
# The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
|
||||
# `arith` ops (and also `math` and `tm_tensor`). It can be thought of
|
||||
# as taking the `TORCH` output type and lowering it so that tensor
|
||||
# computations are done with `linalg`-on-tensors ops.
|
||||
LINALG_ON_TENSORS = "linalg-on-tensors"
|
||||
|
||||
# This output type consists of `tosa` dialect ops. It can be thought of
|
||||
# as taking the `TORCH` output type and lowering it to TOSA.
|
||||
TOSA = "tosa"
|
||||
|
||||
# This output type consists of `stablehlo` dialect ops. It can be thought of
|
||||
# as taking the `TORCH` output type and lowering it to StableHLO.
|
||||
STABLEHLO = "stablehlo"
|
||||
|
||||
# Raw output of the JIT IR importer. This is not expected to be useful
|
||||
# for end-users, but can be convenient for development or reporting bugs.
|
||||
RAW = "raw"
|
||||
|
||||
@staticmethod
|
||||
def get(spec: Union[str, "OutputType"]) -> "OutputType":
|
||||
"""Gets an OutputType from allowed way to specify one.
|
||||
|
||||
Args:
|
||||
spec: An OutputType instance or the case-insensitive name of one of the
|
||||
enum values.
|
||||
Returns:
|
||||
An OutputType instance.
|
||||
"""
|
||||
if isinstance(spec, OutputType):
|
||||
return spec
|
||||
spec = spec.upper().replace("-", "_")
|
||||
if spec not in OutputType.__members__:
|
||||
raise ValueError(f"For output_type= argument, expected one of: "
|
||||
f"{', '.join(OutputType.__members__.keys())}")
|
||||
return OutputType[spec]
|
||||
|
||||
|
||||
class TensorPlaceholder:
|
||||
"""A class that represents a formal parameter of a given shape and dtype.
|
||||
|
||||
|
@ -270,49 +220,6 @@ def _canon_extra_library(extra_library, extra_library_file_name="custom_op_extra
|
|||
return ""
|
||||
|
||||
|
||||
def _lower_mlir_module(verbose, output_type, module):
|
||||
if verbose:
|
||||
print("\n====================")
|
||||
print("Torch Backend IR")
|
||||
print(module)
|
||||
|
||||
if output_type == OutputType.TORCH:
|
||||
return module
|
||||
|
||||
if output_type == OutputType.TOSA:
|
||||
run_pipeline_with_repro_report(
|
||||
module, "builtin.module(torch-backend-to-tosa-backend-pipeline)",
|
||||
"Lowering Torch Backend IR -> TOSA Backend IR")
|
||||
if verbose:
|
||||
print("\n====================")
|
||||
print("TOSA Backend IR")
|
||||
print(module)
|
||||
return module
|
||||
|
||||
if output_type == OutputType.LINALG_ON_TENSORS:
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
|
||||
if verbose:
|
||||
print("\n====================")
|
||||
print("LINALG Backend IR")
|
||||
print(module)
|
||||
return module
|
||||
|
||||
elif output_type == OutputType.STABLEHLO:
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
|
||||
"Lowering Torch Backend IR -> StableHLO Backend IR")
|
||||
if verbose:
|
||||
print("\n====================")
|
||||
print("StableHLO Backend IR")
|
||||
print(module)
|
||||
return module
|
||||
raise Exception(f"Unknown OutputType: {output_type}")
|
||||
|
||||
|
||||
def compile(model: torch.nn.Module,
|
||||
example_args: _example_args,
|
||||
output_type: Union[str, "OutputType"] = OutputType.TORCH,
|
||||
|
@ -464,4 +371,4 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
|||
enable_ir_printing=enable_ir_printing,
|
||||
)
|
||||
|
||||
return _lower_mlir_module(verbose, output_type, mb.module)
|
||||
return lower_mlir_module(verbose, output_type, mb.module)
|
||||
|
|
|
@ -12,12 +12,13 @@ from torch.export.graph_signature import OutputSpec, OutputKind
|
|||
from torch.export import ExportedProgram
|
||||
|
||||
from torch_mlir import fx
|
||||
from torch_mlir.torchscript import (
|
||||
_example_args,
|
||||
OutputType,
|
||||
BACKEND_LEGAL_OPS,
|
||||
from torch_mlir.compiler_utils import (
|
||||
run_pipeline_with_repro_report,
|
||||
_lower_mlir_module,
|
||||
lower_mlir_module,
|
||||
OutputType,
|
||||
)
|
||||
from torch_mlir.torchscript import (
|
||||
BACKEND_LEGAL_OPS,
|
||||
_canon_extra_library,
|
||||
)
|
||||
from torch_mlir_e2e_test.configs.utils import (
|
||||
|
@ -76,7 +77,7 @@ def jit(
|
|||
"Lowering TorchFX IR -> Torch Backend IR",
|
||||
)
|
||||
|
||||
return _lower_mlir_module(verbose, output_type, mlir_module)
|
||||
return lower_mlir_module(verbose, output_type, mlir_module)
|
||||
|
||||
|
||||
class FxImporterTestConfig(TestConfig):
|
||||
|
|
|
@ -15,14 +15,16 @@ from torch._functorch.aot_autograd import (
|
|||
set_model_name,
|
||||
)
|
||||
|
||||
from torch_mlir.compiler_utils import (
|
||||
run_pipeline_with_repro_report,
|
||||
lower_mlir_module,
|
||||
OutputType,
|
||||
)
|
||||
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
|
||||
from torch_mlir.dynamo import _get_decomposition_table
|
||||
from torch_mlir.torchscript import (
|
||||
_example_args,
|
||||
OutputType,
|
||||
BACKEND_LEGAL_OPS,
|
||||
run_pipeline_with_repro_report,
|
||||
_lower_mlir_module,
|
||||
_canon_extra_library,
|
||||
)
|
||||
from torch_mlir_e2e_test.configs.utils import (
|
||||
|
@ -148,7 +150,7 @@ def jit(
|
|||
"Lowering TorchFX IR -> Torch Backend IR",
|
||||
)
|
||||
|
||||
return _lower_mlir_module(verbose, output_type, mlir_module)
|
||||
return lower_mlir_module(verbose, output_type, mlir_module)
|
||||
|
||||
|
||||
class TorchDynamoTestConfig(TestConfig):
|
||||
|
|
|
@ -4,11 +4,13 @@
|
|||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
from torch_mlir.compiler_utils import (
|
||||
run_pipeline_with_repro_report,
|
||||
lower_mlir_module,
|
||||
OutputType,
|
||||
)
|
||||
from torch_mlir.ir import *
|
||||
from torch_mlir.passmanager import *
|
||||
from torch_mlir.torchscript import OutputType
|
||||
from torch_mlir.torchscript import _lower_mlir_module
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||
|
||||
|
@ -58,7 +60,7 @@ class LinalgOnTensorsOnnxBackend(OnnxBackend):
|
|||
"Lowering TorchFX IR -> Torch Backend IR",
|
||||
)
|
||||
|
||||
imported_module = _lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module)
|
||||
imported_module = lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module)
|
||||
compiled_module = self.refbackend.compile(imported_module)
|
||||
return compiled_module
|
||||
|
||||
|
|
|
@ -43,6 +43,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI
|
|||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRPythonSources
|
||||
SOURCES
|
||||
compiler_utils.py
|
||||
fx.py
|
||||
extras/fx_decomp_util.py
|
||||
)
|
||||
|
|
|
@ -0,0 +1,166 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
from enum import Enum
|
||||
from io import StringIO
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Union
|
||||
|
||||
from torch_mlir.passmanager import PassManager
|
||||
from torch_mlir.ir import StringAttr
|
||||
|
||||
|
||||
def get_module_name_for_debug_dump(module):
|
||||
"""Gets a name suitable for a debug dump.
|
||||
|
||||
The name is not guaranteed to be unique.
|
||||
"""
|
||||
if not "torch.debug_module_name" in module.operation.attributes:
|
||||
return "UnnammedModule"
|
||||
return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
|
||||
|
||||
|
||||
class TorchMlirCompilerError(Exception):
|
||||
pass
|
||||
|
||||
def run_pipeline_with_repro_report(module,
|
||||
pipeline: str,
|
||||
description: str,
|
||||
enable_ir_printing: bool = False):
|
||||
"""Runs `pipeline` on `module`, with a nice repro report if it fails."""
|
||||
module_name = get_module_name_for_debug_dump(module)
|
||||
original_stderr = sys.stderr
|
||||
try:
|
||||
sys.stderr = StringIO()
|
||||
asm_for_error_report = module.operation.get_asm(
|
||||
large_elements_limit=10, enable_debug_info=True)
|
||||
# Lower module in place to make it ready for compiler backends.
|
||||
with module.context as ctx:
|
||||
pm = PassManager.parse(pipeline)
|
||||
if enable_ir_printing:
|
||||
ctx.enable_multithreading(False)
|
||||
pm.enable_ir_printing()
|
||||
pm.run(module.operation)
|
||||
except Exception as e:
|
||||
# TODO: More robust.
|
||||
# - don't arbitrarily clutter up /tmp. When a test suite has many
|
||||
# tests, this can be a big disk cost (also, /tmp/ is frequently a
|
||||
# RAM fs, which increases worries about capacity).
|
||||
# - don't have colliding filenames (hard to do without cluttering
|
||||
# up /tmp)
|
||||
# - if we do have have colliding filenames, writes should at least
|
||||
# avoid being racy.
|
||||
filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
|
||||
with open(filename, 'w') as f:
|
||||
f.write(asm_for_error_report)
|
||||
debug_options="-mlir-print-ir-after-all -mlir-disable-threading"
|
||||
# Put something descriptive here even if description is empty.
|
||||
description = description or f"{module_name} compile"
|
||||
|
||||
message = f"""\
|
||||
{description} failed with the following diagnostics:
|
||||
{sys.stderr.getvalue()}
|
||||
|
||||
python exception: {e}
|
||||
|
||||
For Torch-MLIR developers, the error can be reproduced with:
|
||||
$ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
|
||||
Add '{debug_options}' to get the IR dump for debugging purpose.
|
||||
"""
|
||||
trimmed_message = '\n'.join([m.lstrip() for m in message.split('\n')])
|
||||
raise TorchMlirCompilerError(trimmed_message) from None
|
||||
finally:
|
||||
sys.stderr = original_stderr
|
||||
|
||||
|
||||
class OutputType(Enum):
|
||||
|
||||
# Output torch dialect. When converting from FX, this will be immediately
|
||||
# after the import from FX to MLIR. When converting from torchscript,
|
||||
# this will come after some cleanup passes which attempt to de-alias,
|
||||
# decompose and infer shapes. These should be roughly the same level of
|
||||
# abstraction since those steps are done within PyTorch itself
|
||||
# when coming directly from Dynamo/FX.
|
||||
TORCH = "torch"
|
||||
|
||||
# The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
|
||||
# `arith` ops (and also `math` and `tm_tensor`). It can be thought of
|
||||
# as taking the `TORCH` output type and lowering it so that tensor
|
||||
# computations are done with `linalg`-on-tensors ops.
|
||||
LINALG_ON_TENSORS = "linalg-on-tensors"
|
||||
|
||||
# This output type consists of `tosa` dialect ops. It can be thought of
|
||||
# as taking the `TORCH` output type and lowering it to TOSA.
|
||||
TOSA = "tosa"
|
||||
|
||||
# This output type consists of `stablehlo` dialect ops. It can be thought of
|
||||
# as taking the `TORCH` output type and lowering it to StableHLO.
|
||||
STABLEHLO = "stablehlo"
|
||||
|
||||
# Raw output of the JIT IR importer. This is not expected to be useful
|
||||
# for end-users, but can be convenient for development or reporting bugs.
|
||||
RAW = "raw"
|
||||
|
||||
@staticmethod
|
||||
def get(spec: Union[str, "OutputType"]) -> "OutputType":
|
||||
"""Gets an OutputType from allowed way to specify one.
|
||||
|
||||
Args:
|
||||
spec: An OutputType instance or the case-insensitive name of one of the
|
||||
enum values.
|
||||
Returns:
|
||||
An OutputType instance.
|
||||
"""
|
||||
if isinstance(spec, OutputType):
|
||||
return spec
|
||||
spec = spec.upper().replace("-", "_")
|
||||
if spec not in OutputType.__members__:
|
||||
raise ValueError(f"For output_type= argument, expected one of: "
|
||||
f"{', '.join(OutputType.__members__.keys())}")
|
||||
return OutputType[spec]
|
||||
|
||||
|
||||
def lower_mlir_module(verbose, output_type, module):
|
||||
if verbose:
|
||||
print("\n====================")
|
||||
print("Torch Backend IR")
|
||||
print(module)
|
||||
|
||||
if output_type == OutputType.TORCH:
|
||||
return module
|
||||
|
||||
if output_type == OutputType.TOSA:
|
||||
run_pipeline_with_repro_report(
|
||||
module, "builtin.module(torch-backend-to-tosa-backend-pipeline)",
|
||||
"Lowering Torch Backend IR -> TOSA Backend IR")
|
||||
if verbose:
|
||||
print("\n====================")
|
||||
print("TOSA Backend IR")
|
||||
print(module)
|
||||
return module
|
||||
|
||||
if output_type == OutputType.LINALG_ON_TENSORS:
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
|
||||
if verbose:
|
||||
print("\n====================")
|
||||
print("LINALG Backend IR")
|
||||
print(module)
|
||||
return module
|
||||
|
||||
elif output_type == OutputType.STABLEHLO:
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
|
||||
"Lowering Torch Backend IR -> StableHLO Backend IR")
|
||||
if verbose:
|
||||
print("\n====================")
|
||||
print("StableHLO Backend IR")
|
||||
print(module)
|
||||
return module
|
||||
raise Exception(f"Unknown OutputType: {output_type}")
|
Loading…
Reference in New Issue