mirror of https://github.com/llvm/torch-mlir
[FxImporter] Add backend lowering to Fx API (#3288)
parent
6f911ba3d7
commit
c3bd850951
|
@ -3,8 +3,6 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from typing import Union, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
|
@ -12,15 +10,6 @@ from torch.export.graph_signature import OutputSpec, OutputKind
|
|||
from torch.export import ExportedProgram
|
||||
|
||||
from torch_mlir import fx
|
||||
from torch_mlir.compiler_utils import (
|
||||
run_pipeline_with_repro_report,
|
||||
lower_mlir_module,
|
||||
OutputType,
|
||||
)
|
||||
from torch_mlir.torchscript import (
|
||||
BACKEND_LEGAL_OPS,
|
||||
_canon_extra_library,
|
||||
)
|
||||
from torch_mlir_e2e_test.configs.utils import (
|
||||
recursively_convert_to_numpy,
|
||||
recursively_convert_from_numpy,
|
||||
|
@ -39,53 +28,6 @@ def refine_result_type(_result):
|
|||
raise ValueError(f"Unhandled return type {type(_result)}")
|
||||
|
||||
|
||||
def jit(
|
||||
prog: ExportedProgram,
|
||||
func_name: str,
|
||||
output_type: Union[str, "OutputType"] = OutputType.TORCH,
|
||||
backend_legal_ops: Optional[Sequence[str]] = None,
|
||||
extra_library=None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
if extra_library is None:
|
||||
extra_library = []
|
||||
mlir_module = None
|
||||
|
||||
extra_library_file_name = _canon_extra_library(extra_library)
|
||||
output_type = OutputType.get(output_type)
|
||||
if backend_legal_ops is not None:
|
||||
if output_type != OutputType.TORCH:
|
||||
raise Exception(
|
||||
"`backend_legal_ops` is only valid with the " "`torch` output type"
|
||||
)
|
||||
backend_legal_ops = list(sorted(set(backend_legal_ops)))
|
||||
else:
|
||||
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
||||
|
||||
option_string = (
|
||||
"{backend-legal-ops="
|
||||
+ ",".join(backend_legal_ops)
|
||||
+ " extra-library="
|
||||
+ extra_library_file_name
|
||||
+ "}"
|
||||
)
|
||||
|
||||
mlir_module = fx.export_and_import(prog, func_name=func_name)
|
||||
assert mlir_module is not None
|
||||
run_pipeline_with_repro_report(
|
||||
mlir_module,
|
||||
f"builtin.module(torch-simplification-pipeline)",
|
||||
"Simplification pipeline for torch dialect",
|
||||
)
|
||||
run_pipeline_with_repro_report(
|
||||
mlir_module,
|
||||
f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
|
||||
"Lowering TorchFX IR -> Torch Backend IR",
|
||||
)
|
||||
|
||||
return lower_mlir_module(verbose, output_type, mlir_module)
|
||||
|
||||
|
||||
class FxImporterTestConfig(TestConfig):
|
||||
"""TestConfig that runs the torch.nn.Module with Fx Importer"""
|
||||
|
||||
|
@ -100,11 +42,11 @@ class FxImporterTestConfig(TestConfig):
|
|||
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
||||
result: Trace = []
|
||||
for item in trace:
|
||||
prog = torch.export.export(artifact, tuple(item.inputs))
|
||||
module = jit(
|
||||
prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))
|
||||
module = fx.export_and_import(
|
||||
prog,
|
||||
func_name=artifact.__class__.__name__,
|
||||
output_type=self._output_type,
|
||||
func_name=artifact.__class__.__name__,
|
||||
)
|
||||
module = self._backend.compile(module)
|
||||
backend_module = self._backend.load(module)
|
||||
|
|
|
@ -16,11 +16,50 @@ from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks
|
|||
from torch_mlir import ir
|
||||
from torch_mlir.dialects import torch as torch_d
|
||||
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
|
||||
from torch_mlir.compiler_utils import (
|
||||
OutputType,
|
||||
run_pipeline_with_repro_report,
|
||||
lower_mlir_module,
|
||||
)
|
||||
|
||||
|
||||
def _module_lowering(
|
||||
verbose,
|
||||
output_type,
|
||||
torch_mod,
|
||||
backend_legal_ops=None,
|
||||
extra_library_file_name=None,
|
||||
):
|
||||
|
||||
if output_type == OutputType.TORCH:
|
||||
if verbose:
|
||||
print(torch_mod)
|
||||
return torch_mod
|
||||
# TODO: pass backend_legal_ops/extra_library_file_name by caller
|
||||
if backend_legal_ops is None:
|
||||
backend_legal_ops = []
|
||||
if extra_library_file_name is None:
|
||||
extra_library_file_name = ""
|
||||
option_string = (
|
||||
"{backend-legal-ops="
|
||||
+ ",".join(backend_legal_ops)
|
||||
+ " extra-library="
|
||||
+ extra_library_file_name
|
||||
+ "}"
|
||||
)
|
||||
run_pipeline_with_repro_report(
|
||||
torch_mod,
|
||||
f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
|
||||
"Lowering TorchFX IR -> Torch Backend IR",
|
||||
enable_ir_printing=verbose,
|
||||
)
|
||||
return lower_mlir_module(verbose, output_type, torch_mod)
|
||||
|
||||
|
||||
def export_and_import(
|
||||
f: Union[nn.Module, ExportedProgram],
|
||||
*args,
|
||||
output_type: Union[str, OutputType] = OutputType.TORCH,
|
||||
fx_importer: Optional[FxImporter] = None,
|
||||
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
||||
experimental_support_mutation: bool = False,
|
||||
|
@ -28,6 +67,7 @@ def export_and_import(
|
|||
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
|
||||
func_name: str = "main",
|
||||
enable_graph_printing: bool = False,
|
||||
enable_ir_printing: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
context = ir.Context()
|
||||
|
@ -52,15 +92,19 @@ def export_and_import(
|
|||
else:
|
||||
fx_importer.import_frozen_program(prog, func_name=func_name)
|
||||
|
||||
return fx_importer.module
|
||||
return _module_lowering(
|
||||
enable_ir_printing, OutputType.get(output_type), fx_importer.module
|
||||
)
|
||||
|
||||
|
||||
def stateless_fx_import(
|
||||
gm: torch.fx.GraphModule,
|
||||
output_type: Union[str, OutputType] = OutputType.TORCH,
|
||||
fx_importer: Optional[FxImporter] = None,
|
||||
hooks: Optional[FxImporterHooks] = None,
|
||||
model_name: str = "main",
|
||||
enable_graph_printing: bool = False,
|
||||
enable_ir_printing: bool = False,
|
||||
):
|
||||
if enable_graph_printing:
|
||||
gm.print_readable()
|
||||
|
@ -69,4 +113,6 @@ def stateless_fx_import(
|
|||
if fx_importer is None:
|
||||
fx_importer = FxImporter(context=context, hooks=hooks)
|
||||
fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
|
||||
return fx_importer.module
|
||||
return _module_lowering(
|
||||
enable_ir_printing, OutputType.get(output_type), fx_importer.module
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue