[FxImporter] Add backend lowering to Fx API (#3288)

pull/3296/head
penguin_wwy 2024-05-07 20:58:50 +08:00 committed by GitHub
parent 6f911ba3d7
commit c3bd850951
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 63 deletions

View File

@ -3,8 +3,6 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE. # Also available under a BSD-style license. See LICENSE.
from typing import Union, Optional, Sequence
import numpy as np import numpy as np
import torch import torch
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
@ -12,15 +10,6 @@ from torch.export.graph_signature import OutputSpec, OutputKind
from torch.export import ExportedProgram from torch.export import ExportedProgram
from torch_mlir import fx from torch_mlir import fx
from torch_mlir.compiler_utils import (
run_pipeline_with_repro_report,
lower_mlir_module,
OutputType,
)
from torch_mlir.torchscript import (
BACKEND_LEGAL_OPS,
_canon_extra_library,
)
from torch_mlir_e2e_test.configs.utils import ( from torch_mlir_e2e_test.configs.utils import (
recursively_convert_to_numpy, recursively_convert_to_numpy,
recursively_convert_from_numpy, recursively_convert_from_numpy,
@ -39,53 +28,6 @@ def refine_result_type(_result):
raise ValueError(f"Unhandled return type {type(_result)}") raise ValueError(f"Unhandled return type {type(_result)}")
def jit(
prog: ExportedProgram,
func_name: str,
output_type: Union[str, "OutputType"] = OutputType.TORCH,
backend_legal_ops: Optional[Sequence[str]] = None,
extra_library=None,
verbose: bool = False,
):
if extra_library is None:
extra_library = []
mlir_module = None
extra_library_file_name = _canon_extra_library(extra_library)
output_type = OutputType.get(output_type)
if backend_legal_ops is not None:
if output_type != OutputType.TORCH:
raise Exception(
"`backend_legal_ops` is only valid with the " "`torch` output type"
)
backend_legal_ops = list(sorted(set(backend_legal_ops)))
else:
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
option_string = (
"{backend-legal-ops="
+ ",".join(backend_legal_ops)
+ " extra-library="
+ extra_library_file_name
+ "}"
)
mlir_module = fx.export_and_import(prog, func_name=func_name)
assert mlir_module is not None
run_pipeline_with_repro_report(
mlir_module,
f"builtin.module(torch-simplification-pipeline)",
"Simplification pipeline for torch dialect",
)
run_pipeline_with_repro_report(
mlir_module,
f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
"Lowering TorchFX IR -> Torch Backend IR",
)
return lower_mlir_module(verbose, output_type, mlir_module)
class FxImporterTestConfig(TestConfig): class FxImporterTestConfig(TestConfig):
"""TestConfig that runs the torch.nn.Module with Fx Importer""" """TestConfig that runs the torch.nn.Module with Fx Importer"""
@ -100,11 +42,11 @@ class FxImporterTestConfig(TestConfig):
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
result: Trace = [] result: Trace = []
for item in trace: for item in trace:
prog = torch.export.export(artifact, tuple(item.inputs)) prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))
module = jit( module = fx.export_and_import(
prog, prog,
func_name=artifact.__class__.__name__,
output_type=self._output_type, output_type=self._output_type,
func_name=artifact.__class__.__name__,
) )
module = self._backend.compile(module) module = self._backend.compile(module)
backend_module = self._backend.load(module) backend_module = self._backend.load(module)

View File

@ -16,11 +16,50 @@ from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks
from torch_mlir import ir from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_decomp_util import get_decomposition_table from torch_mlir.extras.fx_decomp_util import get_decomposition_table
from torch_mlir.compiler_utils import (
OutputType,
run_pipeline_with_repro_report,
lower_mlir_module,
)
def _module_lowering(
verbose,
output_type,
torch_mod,
backend_legal_ops=None,
extra_library_file_name=None,
):
if output_type == OutputType.TORCH:
if verbose:
print(torch_mod)
return torch_mod
# TODO: pass backend_legal_ops/extra_library_file_name by caller
if backend_legal_ops is None:
backend_legal_ops = []
if extra_library_file_name is None:
extra_library_file_name = ""
option_string = (
"{backend-legal-ops="
+ ",".join(backend_legal_ops)
+ " extra-library="
+ extra_library_file_name
+ "}"
)
run_pipeline_with_repro_report(
torch_mod,
f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
"Lowering TorchFX IR -> Torch Backend IR",
enable_ir_printing=verbose,
)
return lower_mlir_module(verbose, output_type, torch_mod)
def export_and_import( def export_and_import(
f: Union[nn.Module, ExportedProgram], f: Union[nn.Module, ExportedProgram],
*args, *args,
output_type: Union[str, OutputType] = OutputType.TORCH,
fx_importer: Optional[FxImporter] = None, fx_importer: Optional[FxImporter] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
experimental_support_mutation: bool = False, experimental_support_mutation: bool = False,
@ -28,6 +67,7 @@ def export_and_import(
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None, decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
func_name: str = "main", func_name: str = "main",
enable_graph_printing: bool = False, enable_graph_printing: bool = False,
enable_ir_printing: bool = False,
**kwargs, **kwargs,
): ):
context = ir.Context() context = ir.Context()
@ -52,15 +92,19 @@ def export_and_import(
else: else:
fx_importer.import_frozen_program(prog, func_name=func_name) fx_importer.import_frozen_program(prog, func_name=func_name)
return fx_importer.module return _module_lowering(
enable_ir_printing, OutputType.get(output_type), fx_importer.module
)
def stateless_fx_import( def stateless_fx_import(
gm: torch.fx.GraphModule, gm: torch.fx.GraphModule,
output_type: Union[str, OutputType] = OutputType.TORCH,
fx_importer: Optional[FxImporter] = None, fx_importer: Optional[FxImporter] = None,
hooks: Optional[FxImporterHooks] = None, hooks: Optional[FxImporterHooks] = None,
model_name: str = "main", model_name: str = "main",
enable_graph_printing: bool = False, enable_graph_printing: bool = False,
enable_ir_printing: bool = False,
): ):
if enable_graph_printing: if enable_graph_printing:
gm.print_readable() gm.print_readable()
@ -69,4 +113,6 @@ def stateless_fx_import(
if fx_importer is None: if fx_importer is None:
fx_importer = FxImporter(context=context, hooks=hooks) fx_importer = FxImporter(context=context, hooks=hooks)
fx_importer.import_stateless_graph(gm.graph, func_name=model_name) fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
return fx_importer.module return _module_lowering(
enable_ir_printing, OutputType.get(output_type), fx_importer.module
)