Extract the Python APIs in the pt1 dir back to the root (#3237)

penguin_wwy 2024-04-27 18:27:37 +08:00 committed by GitHub
parent 9a12a093a6
commit 944a6df611
8 changed files with 192 additions and 189 deletions
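In short, this commit moves the shared compiler utilities (run_pipeline_with_repro_report, OutputType, and the module-lowering helper, renamed from the private _lower_mlir_module to the public lower_mlir_module) out of the pt1 tree and the torchscript API into a root-level torch_mlir.compiler_utils module. A minimal before/after sketch of the import change for downstream callers, assembled from the hunks below (exact import lists vary per file):

# Before: utilities pulled from the pt1-era torchscript module / private name.
from torch_mlir.torchscript import OutputType, _lower_mlir_module

# After: one public import point at the package root.
from torch_mlir.compiler_utils import (
    run_pipeline_with_repro_report,
    lower_mlir_module,
    OutputType,
)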


@@ -20,7 +20,6 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
  SOURCES
    torchscript.py
    _dynamo_fx_importer.py
    compiler_utils.py
    dynamo.py
    _version.py
)


@@ -1,75 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from io import StringIO
import os
import sys
import tempfile

from torch_mlir.passmanager import PassManager
from torch_mlir.ir import StringAttr


def get_module_name_for_debug_dump(module):
    """Gets a name suitable for a debug dump.

    The name is not guaranteed to be unique.
    """
    if not "torch.debug_module_name" in module.operation.attributes:
        return "UnnammedModule"
    return StringAttr(module.operation.attributes["torch.debug_module_name"]).value


class TorchMlirCompilerError(Exception):
    pass


def run_pipeline_with_repro_report(module,
                                   pipeline: str,
                                   description: str,
                                   enable_ir_printing: bool = False):
    """Runs `pipeline` on `module`, with a nice repro report if it fails."""
    module_name = get_module_name_for_debug_dump(module)
    try:
        original_stderr = sys.stderr
        sys.stderr = StringIO()
        asm_for_error_report = module.operation.get_asm(
            large_elements_limit=10, enable_debug_info=True)
        # Lower module in place to make it ready for compiler backends.
        with module.context as ctx:
            pm = PassManager.parse(pipeline)
            if enable_ir_printing:
                ctx.enable_multithreading(False)
                pm.enable_ir_printing()
            pm.run(module.operation)
    except Exception as e:
        # TODO: More robust.
        # - don't arbitrarily clutter up /tmp. When a test suite has many
        #   tests, this can be a big disk cost (also, /tmp/ is frequently a
        #   RAM fs, which increases worries about capacity).
        # - don't have colliding filenames (hard to do without cluttering
        #   up /tmp)
        # - if we do have have colliding filenames, writes should at least
        #   avoid being racy.
        filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
        with open(filename, 'w') as f:
            f.write(asm_for_error_report)
        debug_options="-mlir-print-ir-after-all -mlir-disable-threading"
        # Put something descriptive here even if description is empty.
        description = description or f"{module_name} compile"
        message = f"""\
            {description} failed with the following diagnostics:
            {sys.stderr.getvalue()}
            python exception: {e}
            For Torch-MLIR developers, the error can be reproduced with:
            $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
            Add '{debug_options}' to get the IR dump for debugging purpose.
            """
        trimmed_message = '\n'.join([m.lstrip() for m in message.split('\n')])
        raise TorchMlirCompilerError(trimmed_message) from None
    finally:
        sys.stderr = original_stderr
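For reference, a minimal sketch of how this helper is called elsewhere in this diff (the pipeline and description strings below are taken verbatim from the torchscript hunk; `module` is assumed to be a torch_mlir.ir.Module holding Torch-dialect IR):

# `module` is assumed to already hold torch-backend-contract IR.
run_pipeline_with_repro_report(
    module,
    "builtin.module(torch-backend-to-tosa-backend-pipeline)",
    "Lowering Torch Backend IR -> TOSA Backend IR",
)
# On failure the helper dumps the IR to $TMPDIR/<module_name>.mlir and raises
# TorchMlirCompilerError with a ready-to-run torch-mlir-opt repro command.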


@@ -17,65 +17,15 @@ import torch.fx
from torch_mlir.dynamo import _get_decomposition_table
from torch.fx.experimental.proxy_tensor import make_fx

from .compiler_utils import run_pipeline_with_repro_report
from torch_mlir.compiler_utils import (
    run_pipeline_with_repro_report,
    OutputType,
    lower_mlir_module
)
from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library


class OutputType(Enum):
    """The kind of output that `torchscript.compile` can produce.

    In MLIR terminology, this describes the mix of dialects that will be
    produced by the conversion process.

    In user-facing API's, this type can always be passed interchangeably with an
    appropriate string specifying the output type. The allowed strings are
    the set of enum vales, allowed to be case insensitive and with `-` allowed
    in place of `_`. The `OutputType.get` static method can be used to convert
    from a string to an `OutputType` instance.
    """

    # This output type consists of `torch` dialect ops that have been converted
    # maximally to value semantics, decomposed, and shapes have been inferred.
    TORCH = "torch"

    # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
    # `arith` ops (and also `math` and `tm_tensor`). It can be thought of
    # as taking the `TORCH` output type and lowering it so that tensor
    # computations are done with `linalg`-on-tensors ops.
    LINALG_ON_TENSORS = "linalg-on-tensors"

    # This output type consists of `tosa` dialect ops. It can be thought of
    # as taking the `TORCH` output type and lowering it to TOSA.
    TOSA = "tosa"

    # This output type consists of `stablehlo` dialect ops. It can be thought of
    # as taking the `TORCH` output type and lowering it to StableHLO.
    STABLEHLO = "stablehlo"

    # Raw output of the JIT IR importer. This is not expected to be useful
    # for end-users, but can be convenient for development or reporting bugs.
    RAW = "raw"

    @staticmethod
    def get(spec: Union[str, "OutputType"]) -> "OutputType":
        """Gets an OutputType from allowed way to specify one.

        Args:
          spec: An OutputType instance or the case-insensitive name of one of the
            enum values.
        Returns:
          An OutputType instance.
        """
        if isinstance(spec, OutputType):
            return spec
        spec = spec.upper().replace("-", "_")
        if spec not in OutputType.__members__:
            raise ValueError(f"For output_type= argument, expected one of: "
                             f"{', '.join(OutputType.__members__.keys())}")
        return OutputType[spec]


class TensorPlaceholder:
    """A class that represents a formal parameter of a given shape and dtype.
@@ -270,49 +220,6 @@ def _canon_extra_library(extra_library, extra_library_file_name="custom_op_extra
    return ""


def _lower_mlir_module(verbose, output_type, module):
    if verbose:
        print("\n====================")
        print("Torch Backend IR")
        print(module)

    if output_type == OutputType.TORCH:
        return module

    if output_type == OutputType.TOSA:
        run_pipeline_with_repro_report(
            module, "builtin.module(torch-backend-to-tosa-backend-pipeline)",
            "Lowering Torch Backend IR -> TOSA Backend IR")
        if verbose:
            print("\n====================")
            print("TOSA Backend IR")
            print(module)
        return module

    if output_type == OutputType.LINALG_ON_TENSORS:
        run_pipeline_with_repro_report(
            module,
            "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
            "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
        if verbose:
            print("\n====================")
            print("LINALG Backend IR")
            print(module)
        return module

    elif output_type == OutputType.STABLEHLO:
        run_pipeline_with_repro_report(
            module,
            "builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
            "Lowering Torch Backend IR -> StableHLO Backend IR")
        if verbose:
            print("\n====================")
            print("StableHLO Backend IR")
            print(module)
        return module

    raise Exception(f"Unknown OutputType: {output_type}")


def compile(model: torch.nn.Module,
            example_args: _example_args,
            output_type: Union[str, "OutputType"] = OutputType.TORCH,
@@ -464,4 +371,4 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
        enable_ir_printing=enable_ir_printing,
    )

    return _lower_mlir_module(verbose, output_type, mb.module)
    return lower_mlir_module(verbose, output_type, mb.module)
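Because OutputType and the lowering helper are re-imported at the top of torchscript.py, the user-facing compile() API is unchanged by this move; output_type still accepts either an enum value or a string resolved through OutputType.get. A minimal, hypothetical usage sketch (the model, input, and call are illustrative only; compile() takes further keyword arguments not shown here):

import torch
from torch_mlir import torchscript

# Hypothetical toy model and example input, purely for illustration.
model = torch.nn.Linear(4, 2)
example_input = torch.randn(1, 4)

# The string form is case-insensitive, with '-' and '_' interchangeable,
# because it is normalized by OutputType.get before lowering.
module = torchscript.compile(model, example_input,
                             output_type="linalg-on-tensors")
print(module)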


@@ -12,12 +12,13 @@ from torch.export.graph_signature import OutputSpec, OutputKind
from torch.export import ExportedProgram

from torch_mlir import fx
from torch_mlir.torchscript import (
    _example_args,
    OutputType,
    BACKEND_LEGAL_OPS,
from torch_mlir.compiler_utils import (
    run_pipeline_with_repro_report,
    _lower_mlir_module,
    lower_mlir_module,
    OutputType,
)
from torch_mlir.torchscript import (
    BACKEND_LEGAL_OPS,
    _canon_extra_library,
)
from torch_mlir_e2e_test.configs.utils import (
@@ -76,7 +77,7 @@ def jit(
        "Lowering TorchFX IR -> Torch Backend IR",
    )

    return _lower_mlir_module(verbose, output_type, mlir_module)
    return lower_mlir_module(verbose, output_type, mlir_module)


class FxImporterTestConfig(TestConfig):


@@ -15,14 +15,16 @@ from torch._functorch.aot_autograd import (
    set_model_name,
)

from torch_mlir.compiler_utils import (
    run_pipeline_with_repro_report,
    lower_mlir_module,
    OutputType,
)
from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
from torch_mlir.dynamo import _get_decomposition_table
from torch_mlir.torchscript import (
    _example_args,
    OutputType,
    BACKEND_LEGAL_OPS,
    run_pipeline_with_repro_report,
    _lower_mlir_module,
    _canon_extra_library,
)
from torch_mlir_e2e_test.configs.utils import (
@@ -148,7 +150,7 @@ def jit(
        "Lowering TorchFX IR -> Torch Backend IR",
    )

    return _lower_mlir_module(verbose, output_type, mlir_module)
    return lower_mlir_module(verbose, output_type, mlir_module)


class TorchDynamoTestConfig(TestConfig):


@@ -4,11 +4,13 @@
# Also available under a BSD-style license. See LICENSE.

from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from torch_mlir.compiler_utils import (
    run_pipeline_with_repro_report,
    lower_mlir_module,
    OutputType,
)
from torch_mlir.ir import *
from torch_mlir.passmanager import *
from torch_mlir.torchscript import OutputType
from torch_mlir.torchscript import _lower_mlir_module
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
@@ -58,7 +60,7 @@ class LinalgOnTensorsOnnxBackend(OnnxBackend):
            "Lowering TorchFX IR -> Torch Backend IR",
        )

        imported_module = _lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module)
        imported_module = lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module)
        compiled_module = self.refbackend.compile(imported_module)
        return compiled_module


@@ -43,6 +43,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI
  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
  ADD_TO_PARENT TorchMLIRPythonSources
  SOURCES
    compiler_utils.py
    fx.py
    extras/fx_decomp_util.py
)


@@ -0,0 +1,166 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from enum import Enum
from io import StringIO
import os
import sys
import tempfile
from typing import Union

from torch_mlir.passmanager import PassManager
from torch_mlir.ir import StringAttr


def get_module_name_for_debug_dump(module):
    """Gets a name suitable for a debug dump.

    The name is not guaranteed to be unique.
    """
    if not "torch.debug_module_name" in module.operation.attributes:
        return "UnnammedModule"
    return StringAttr(module.operation.attributes["torch.debug_module_name"]).value


class TorchMlirCompilerError(Exception):
    pass


def run_pipeline_with_repro_report(module,
                                   pipeline: str,
                                   description: str,
                                   enable_ir_printing: bool = False):
    """Runs `pipeline` on `module`, with a nice repro report if it fails."""
    module_name = get_module_name_for_debug_dump(module)
    original_stderr = sys.stderr
    try:
        sys.stderr = StringIO()
        asm_for_error_report = module.operation.get_asm(
            large_elements_limit=10, enable_debug_info=True)
        # Lower module in place to make it ready for compiler backends.
        with module.context as ctx:
            pm = PassManager.parse(pipeline)
            if enable_ir_printing:
                ctx.enable_multithreading(False)
                pm.enable_ir_printing()
            pm.run(module.operation)
    except Exception as e:
        # TODO: More robust.
        # - don't arbitrarily clutter up /tmp. When a test suite has many
        #   tests, this can be a big disk cost (also, /tmp/ is frequently a
        #   RAM fs, which increases worries about capacity).
        # - don't have colliding filenames (hard to do without cluttering
        #   up /tmp)
        # - if we do have have colliding filenames, writes should at least
        #   avoid being racy.
        filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
        with open(filename, 'w') as f:
            f.write(asm_for_error_report)
        debug_options="-mlir-print-ir-after-all -mlir-disable-threading"
        # Put something descriptive here even if description is empty.
        description = description or f"{module_name} compile"
        message = f"""\
            {description} failed with the following diagnostics:
            {sys.stderr.getvalue()}
            python exception: {e}
            For Torch-MLIR developers, the error can be reproduced with:
            $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
            Add '{debug_options}' to get the IR dump for debugging purpose.
            """
        trimmed_message = '\n'.join([m.lstrip() for m in message.split('\n')])
        raise TorchMlirCompilerError(trimmed_message) from None
    finally:
        sys.stderr = original_stderr


class OutputType(Enum):

    # Output torch dialect. When converting from FX, this will be immediately
    # after the import from FX to MLIR. When converting from torchscript,
    # this will come after some cleanup passes which attempt to de-alias,
    # decompose and infer shapes. These should be roughly the same level of
    # abstraction since those steps are done within PyTorch itself
    # when coming directly from Dynamo/FX.
    TORCH = "torch"

    # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
    # `arith` ops (and also `math` and `tm_tensor`). It can be thought of
    # as taking the `TORCH` output type and lowering it so that tensor
    # computations are done with `linalg`-on-tensors ops.
    LINALG_ON_TENSORS = "linalg-on-tensors"

    # This output type consists of `tosa` dialect ops. It can be thought of
    # as taking the `TORCH` output type and lowering it to TOSA.
    TOSA = "tosa"

    # This output type consists of `stablehlo` dialect ops. It can be thought of
    # as taking the `TORCH` output type and lowering it to StableHLO.
    STABLEHLO = "stablehlo"

    # Raw output of the JIT IR importer. This is not expected to be useful
    # for end-users, but can be convenient for development or reporting bugs.
    RAW = "raw"

    @staticmethod
    def get(spec: Union[str, "OutputType"]) -> "OutputType":
        """Gets an OutputType from allowed way to specify one.

        Args:
          spec: An OutputType instance or the case-insensitive name of one of the
            enum values.
        Returns:
          An OutputType instance.
        """
        if isinstance(spec, OutputType):
            return spec
        spec = spec.upper().replace("-", "_")
        if spec not in OutputType.__members__:
            raise ValueError(f"For output_type= argument, expected one of: "
                             f"{', '.join(OutputType.__members__.keys())}")
        return OutputType[spec]


def lower_mlir_module(verbose, output_type, module):
    if verbose:
        print("\n====================")
        print("Torch Backend IR")
        print(module)

    if output_type == OutputType.TORCH:
        return module

    if output_type == OutputType.TOSA:
        run_pipeline_with_repro_report(
            module, "builtin.module(torch-backend-to-tosa-backend-pipeline)",
            "Lowering Torch Backend IR -> TOSA Backend IR")
        if verbose:
            print("\n====================")
            print("TOSA Backend IR")
            print(module)
        return module

    if output_type == OutputType.LINALG_ON_TENSORS:
        run_pipeline_with_repro_report(
            module,
            "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
            "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
        if verbose:
            print("\n====================")
            print("LINALG Backend IR")
            print(module)
        return module

    elif output_type == OutputType.STABLEHLO:
        run_pipeline_with_repro_report(
            module,
            "builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
            "Lowering Torch Backend IR -> StableHLO Backend IR")
        if verbose:
            print("\n====================")
            print("StableHLO Backend IR")
            print(module)
        return module

    raise Exception(f"Unknown OutputType: {output_type}")