mirror of https://github.com/llvm/torch-mlir
124 lines
4.2 KiB
Python
124 lines
4.2 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
from typing import Optional, Union, Dict, Tuple, Any, Callable
|
|
from packaging import version
|
|
|
|
import warnings
|
|
|
|
import torch
|
|
import torch.export
|
|
import torch.nn as nn
|
|
from torch.export import ExportedProgram
|
|
|
|
from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks
|
|
from torch_mlir import ir
|
|
from torch_mlir.dialects import torch as torch_d
|
|
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
|
|
from torch_mlir.compiler_utils import (
|
|
OutputType,
|
|
run_pipeline_with_repro_report,
|
|
lower_mlir_module,
|
|
)
|
|
|
|
|
|
def _module_lowering(
    verbose,
    output_type,
    torch_mod,
    extra_library_file_name=None,
):
    """Lower an imported module to the requested output type.

    Args:
        verbose: When true, prints the module on the RAW path and is passed
            as ``enable_ir_printing`` to the pipeline run; it is also
            forwarded to ``lower_mlir_module``.
        output_type: The target ``OutputType``. ``OutputType.RAW`` returns
            ``torch_mod`` without running any pipeline.
        torch_mod: The module produced by the FX importer.
        extra_library_file_name: Value for the pipeline's ``extra-library``
            option; ``None`` is treated as the empty string.

    Returns:
        ``torch_mod`` itself for ``OutputType.RAW``; otherwise whatever
        ``lower_mlir_module`` returns after the backend pipeline has run.
    """
    # RAW output short-circuits: nothing is lowered, the module is handed back.
    if output_type == OutputType.RAW:
        if verbose:
            print(torch_mod)
        return torch_mod

    # TODO: pass extra_library_file_name by caller
    library_name = "" if extra_library_file_name is None else extra_library_file_name
    pipeline = (
        "builtin.module(torchdynamo-export-to-torch-backend-pipeline"
        "{extra-library=" + library_name + "})"
    )
    run_pipeline_with_repro_report(
        torch_mod,
        pipeline,
        "Lowering TorchFX IR -> Torch Backend IR",
        enable_ir_printing=verbose,
    )
    return lower_mlir_module(verbose, output_type, torch_mod)
|
|
|
|
|
|
def export_and_import(
    f: Union[nn.Module, ExportedProgram],
    *args,
    output_type: Union[str, OutputType] = OutputType.RAW,
    fx_importer: Optional[FxImporter] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    experimental_support_mutation: bool = False,
    import_symbolic_shape_expressions: bool = False,
    hooks: Optional[FxImporterHooks] = None,
    decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
    func_name: str = "main",
    enable_graph_printing: bool = False,
    enable_ir_printing: bool = False,
    **kwargs,
):
    """Export a module (or take an existing export), import it, and lower it.

    Args:
        f: Either an ``nn.Module`` to be exported with ``torch.export.export``
            or an already exported ``ExportedProgram``, which is used as is.
        *args: Forwarded to ``torch.export.export`` as its ``args`` tuple.
            Not used when ``f`` is an ``ExportedProgram``.
        output_type: Requested output type; converted with ``OutputType.get``.
        fx_importer: Importer to use. When ``None`` a new ``FxImporter`` is
            built on a fresh context with ``hooks``.
        dynamic_shapes: Forwarded to ``torch.export.export`` on PyTorch
            2.2.0 or newer. Not used when ``f`` is an ``ExportedProgram``.
        experimental_support_mutation: Selects ``import_program`` instead of
            ``import_frozen_program`` on the importer.
        import_symbolic_shape_expressions: Forwarded to the importer call.
        hooks: Only used when this function constructs the ``FxImporter``.
        decomposition_table: Table passed to ``run_decompositions``. ``None``
            selects ``get_decomposition_table()``; a falsy table (from either
            source) skips decomposition.
        func_name: Forwarded to the importer call as ``func_name``.
        enable_graph_printing: When true, calls ``print_readable()`` on the
            program's graph module before importing.
        enable_ir_printing: Passed to ``_module_lowering`` as its verbosity
            flag.
        **kwargs: Forwarded to ``torch.export.export`` as its ``kwargs`` dict.
            Not used when ``f`` is an ``ExportedProgram``.

    Returns:
        The result of ``_module_lowering`` applied to the importer's module.
    """
    context = ir.Context()
    torch_d.register_dialect(context)

    if fx_importer is None:
        fx_importer = FxImporter(context=context, hooks=hooks)

    if isinstance(f, ExportedProgram):
        prog = f
    elif version.Version(torch.__version__) >= version.Version("2.2.0"):
        prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
    else:
        # PyTorch 2.1 or lower doesn't have the `dynamic_shapes` keyword
        # argument in torch.export, so it is left out of the call here.
        prog = torch.export.export(f, args, kwargs)

    if decomposition_table is None:
        decomposition_table = get_decomposition_table()
    if decomposition_table:
        prog = prog.run_decompositions(decomposition_table)

    if enable_graph_printing:
        prog.graph_module.print_readable()

    # Pick the importer entry point first, then make one shared call.
    if experimental_support_mutation:
        if torch.__version__ < "2.3.0.dev20240207":
            warnings.warn("Mutable program import only supported on PyTorch 2.3+")
        import_entry = fx_importer.import_program
    else:
        import_entry = fx_importer.import_frozen_program
    import_entry(
        prog,
        func_name=func_name,
        import_symbolic_shape_expressions=import_symbolic_shape_expressions,
    )

    resolved_output_type = OutputType.get(output_type)
    return _module_lowering(
        enable_ir_printing, resolved_output_type, fx_importer.module
    )
|
|
|
|
|
|
def stateless_fx_import(
    gm: torch.fx.GraphModule,
    output_type: Union[str, OutputType] = OutputType.RAW,
    fx_importer: Optional[FxImporter] = None,
    hooks: Optional[FxImporterHooks] = None,
    model_name: str = "main",
    enable_graph_printing: bool = False,
    enable_ir_printing: bool = False,
):
    """Import the graph of an FX ``GraphModule`` and lower it.

    Args:
        gm: Graph module whose ``graph`` is handed to
            ``import_stateless_graph``.
        output_type: Requested output type; converted with ``OutputType.get``.
        fx_importer: Importer to use. When ``None`` a new ``FxImporter`` is
            built on a fresh context with ``hooks``.
        hooks: Only used when this function constructs the ``FxImporter``.
        model_name: Passed to the importer as ``func_name``.
        enable_graph_printing: When true, calls ``gm.print_readable()`` before
            anything else.
        enable_ir_printing: Passed to ``_module_lowering`` as its verbosity
            flag.

    Returns:
        The result of ``_module_lowering`` applied to the importer's module.
    """
    if enable_graph_printing:
        gm.print_readable()

    # The context is created and the torch dialect registered on every call,
    # whether or not the caller supplied an importer.
    context = ir.Context()
    torch_d.register_dialect(context)

    if fx_importer is None:
        fx_importer = FxImporter(context=context, hooks=hooks)

    fx_importer.import_stateless_graph(gm.graph, func_name=model_name)

    resolved_output_type = OutputType.get(output_type)
    return _module_lowering(
        enable_ir_printing, resolved_output_type, fx_importer.module
    )
|