mirror of https://github.com/llvm/torch-mlir
[FxImporter] Add backend lowering to Fx API (#3288)
parent
6f911ba3d7
commit
c3bd850951
|
@ -3,8 +3,6 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
# Also available under a BSD-style license. See LICENSE.
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
from typing import Union, Optional, Sequence
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.utils._pytree as pytree
|
import torch.utils._pytree as pytree
|
||||||
|
@ -12,15 +10,6 @@ from torch.export.graph_signature import OutputSpec, OutputKind
|
||||||
from torch.export import ExportedProgram
|
from torch.export import ExportedProgram
|
||||||
|
|
||||||
from torch_mlir import fx
|
from torch_mlir import fx
|
||||||
from torch_mlir.compiler_utils import (
|
|
||||||
run_pipeline_with_repro_report,
|
|
||||||
lower_mlir_module,
|
|
||||||
OutputType,
|
|
||||||
)
|
|
||||||
from torch_mlir.torchscript import (
|
|
||||||
BACKEND_LEGAL_OPS,
|
|
||||||
_canon_extra_library,
|
|
||||||
)
|
|
||||||
from torch_mlir_e2e_test.configs.utils import (
|
from torch_mlir_e2e_test.configs.utils import (
|
||||||
recursively_convert_to_numpy,
|
recursively_convert_to_numpy,
|
||||||
recursively_convert_from_numpy,
|
recursively_convert_from_numpy,
|
||||||
|
@ -39,53 +28,6 @@ def refine_result_type(_result):
|
||||||
raise ValueError(f"Unhandled return type {type(_result)}")
|
raise ValueError(f"Unhandled return type {type(_result)}")
|
||||||
|
|
||||||
|
|
||||||
def jit(
|
|
||||||
prog: ExportedProgram,
|
|
||||||
func_name: str,
|
|
||||||
output_type: Union[str, "OutputType"] = OutputType.TORCH,
|
|
||||||
backend_legal_ops: Optional[Sequence[str]] = None,
|
|
||||||
extra_library=None,
|
|
||||||
verbose: bool = False,
|
|
||||||
):
|
|
||||||
if extra_library is None:
|
|
||||||
extra_library = []
|
|
||||||
mlir_module = None
|
|
||||||
|
|
||||||
extra_library_file_name = _canon_extra_library(extra_library)
|
|
||||||
output_type = OutputType.get(output_type)
|
|
||||||
if backend_legal_ops is not None:
|
|
||||||
if output_type != OutputType.TORCH:
|
|
||||||
raise Exception(
|
|
||||||
"`backend_legal_ops` is only valid with the " "`torch` output type"
|
|
||||||
)
|
|
||||||
backend_legal_ops = list(sorted(set(backend_legal_ops)))
|
|
||||||
else:
|
|
||||||
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
|
|
||||||
|
|
||||||
option_string = (
|
|
||||||
"{backend-legal-ops="
|
|
||||||
+ ",".join(backend_legal_ops)
|
|
||||||
+ " extra-library="
|
|
||||||
+ extra_library_file_name
|
|
||||||
+ "}"
|
|
||||||
)
|
|
||||||
|
|
||||||
mlir_module = fx.export_and_import(prog, func_name=func_name)
|
|
||||||
assert mlir_module is not None
|
|
||||||
run_pipeline_with_repro_report(
|
|
||||||
mlir_module,
|
|
||||||
f"builtin.module(torch-simplification-pipeline)",
|
|
||||||
"Simplification pipeline for torch dialect",
|
|
||||||
)
|
|
||||||
run_pipeline_with_repro_report(
|
|
||||||
mlir_module,
|
|
||||||
f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
|
|
||||||
"Lowering TorchFX IR -> Torch Backend IR",
|
|
||||||
)
|
|
||||||
|
|
||||||
return lower_mlir_module(verbose, output_type, mlir_module)
|
|
||||||
|
|
||||||
|
|
||||||
class FxImporterTestConfig(TestConfig):
|
class FxImporterTestConfig(TestConfig):
|
||||||
"""TestConfig that runs the torch.nn.Module with Fx Importer"""
|
"""TestConfig that runs the torch.nn.Module with Fx Importer"""
|
||||||
|
|
||||||
|
@ -100,11 +42,11 @@ class FxImporterTestConfig(TestConfig):
|
||||||
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
||||||
result: Trace = []
|
result: Trace = []
|
||||||
for item in trace:
|
for item in trace:
|
||||||
prog = torch.export.export(artifact, tuple(item.inputs))
|
prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))
|
||||||
module = jit(
|
module = fx.export_and_import(
|
||||||
prog,
|
prog,
|
||||||
func_name=artifact.__class__.__name__,
|
|
||||||
output_type=self._output_type,
|
output_type=self._output_type,
|
||||||
|
func_name=artifact.__class__.__name__,
|
||||||
)
|
)
|
||||||
module = self._backend.compile(module)
|
module = self._backend.compile(module)
|
||||||
backend_module = self._backend.load(module)
|
backend_module = self._backend.load(module)
|
||||||
|
|
|
@ -16,11 +16,50 @@ from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks
|
||||||
from torch_mlir import ir
|
from torch_mlir import ir
|
||||||
from torch_mlir.dialects import torch as torch_d
|
from torch_mlir.dialects import torch as torch_d
|
||||||
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
|
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
|
||||||
|
from torch_mlir.compiler_utils import (
|
||||||
|
OutputType,
|
||||||
|
run_pipeline_with_repro_report,
|
||||||
|
lower_mlir_module,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _module_lowering(
|
||||||
|
verbose,
|
||||||
|
output_type,
|
||||||
|
torch_mod,
|
||||||
|
backend_legal_ops=None,
|
||||||
|
extra_library_file_name=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
if output_type == OutputType.TORCH:
|
||||||
|
if verbose:
|
||||||
|
print(torch_mod)
|
||||||
|
return torch_mod
|
||||||
|
# TODO: pass backend_legal_ops/extra_library_file_name by caller
|
||||||
|
if backend_legal_ops is None:
|
||||||
|
backend_legal_ops = []
|
||||||
|
if extra_library_file_name is None:
|
||||||
|
extra_library_file_name = ""
|
||||||
|
option_string = (
|
||||||
|
"{backend-legal-ops="
|
||||||
|
+ ",".join(backend_legal_ops)
|
||||||
|
+ " extra-library="
|
||||||
|
+ extra_library_file_name
|
||||||
|
+ "}"
|
||||||
|
)
|
||||||
|
run_pipeline_with_repro_report(
|
||||||
|
torch_mod,
|
||||||
|
f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
|
||||||
|
"Lowering TorchFX IR -> Torch Backend IR",
|
||||||
|
enable_ir_printing=verbose,
|
||||||
|
)
|
||||||
|
return lower_mlir_module(verbose, output_type, torch_mod)
|
||||||
|
|
||||||
|
|
||||||
def export_and_import(
|
def export_and_import(
|
||||||
f: Union[nn.Module, ExportedProgram],
|
f: Union[nn.Module, ExportedProgram],
|
||||||
*args,
|
*args,
|
||||||
|
output_type: Union[str, OutputType] = OutputType.TORCH,
|
||||||
fx_importer: Optional[FxImporter] = None,
|
fx_importer: Optional[FxImporter] = None,
|
||||||
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
||||||
experimental_support_mutation: bool = False,
|
experimental_support_mutation: bool = False,
|
||||||
|
@ -28,6 +67,7 @@ def export_and_import(
|
||||||
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
|
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
|
||||||
func_name: str = "main",
|
func_name: str = "main",
|
||||||
enable_graph_printing: bool = False,
|
enable_graph_printing: bool = False,
|
||||||
|
enable_ir_printing: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
context = ir.Context()
|
context = ir.Context()
|
||||||
|
@ -52,15 +92,19 @@ def export_and_import(
|
||||||
else:
|
else:
|
||||||
fx_importer.import_frozen_program(prog, func_name=func_name)
|
fx_importer.import_frozen_program(prog, func_name=func_name)
|
||||||
|
|
||||||
return fx_importer.module
|
return _module_lowering(
|
||||||
|
enable_ir_printing, OutputType.get(output_type), fx_importer.module
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def stateless_fx_import(
|
def stateless_fx_import(
|
||||||
gm: torch.fx.GraphModule,
|
gm: torch.fx.GraphModule,
|
||||||
|
output_type: Union[str, OutputType] = OutputType.TORCH,
|
||||||
fx_importer: Optional[FxImporter] = None,
|
fx_importer: Optional[FxImporter] = None,
|
||||||
hooks: Optional[FxImporterHooks] = None,
|
hooks: Optional[FxImporterHooks] = None,
|
||||||
model_name: str = "main",
|
model_name: str = "main",
|
||||||
enable_graph_printing: bool = False,
|
enable_graph_printing: bool = False,
|
||||||
|
enable_ir_printing: bool = False,
|
||||||
):
|
):
|
||||||
if enable_graph_printing:
|
if enable_graph_printing:
|
||||||
gm.print_readable()
|
gm.print_readable()
|
||||||
|
@ -69,4 +113,6 @@ def stateless_fx_import(
|
||||||
if fx_importer is None:
|
if fx_importer is None:
|
||||||
fx_importer = FxImporter(context=context, hooks=hooks)
|
fx_importer = FxImporter(context=context, hooks=hooks)
|
||||||
fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
|
fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
|
||||||
return fx_importer.module
|
return _module_lowering(
|
||||||
|
enable_ir_printing, OutputType.get(output_type), fx_importer.module
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue