mirror of https://github.com/llvm/torch-mlir
44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
from typing import Optional
|
|
|
|
import warnings
|
|
|
|
import torch
|
|
import torch.export
|
|
import torch.nn as nn
|
|
|
|
from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks
|
|
from torch_mlir import ir
|
|
from torch_mlir.dialects import torch as torch_d
|
|
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
|
|
|
|
def export_and_import(
|
|
f,
|
|
*args,
|
|
fx_importer: Optional[FxImporter] = None,
|
|
constraints: Optional[torch.export.Constraint] = None,
|
|
experimental_support_mutation: bool = False,
|
|
hooks: Optional[FxImporterHooks] = None,
|
|
**kwargs,
|
|
):
|
|
context = ir.Context()
|
|
torch_d.register_dialect(context)
|
|
|
|
if fx_importer is None:
|
|
fx_importer = FxImporter(context=context, hooks=hooks)
|
|
prog = torch.export.export(f, args, kwargs, constraints=constraints)
|
|
decomp_table = get_decomposition_table()
|
|
prog = prog.run_decompositions(decomp_table)
|
|
if experimental_support_mutation:
|
|
if torch.__version__ < "2.3.0.dev20240207":
|
|
warnings.warn("Mutable program import only supported on PyTorch 2.3+")
|
|
fx_importer.import_program(prog)
|
|
else:
|
|
fx_importer.import_frozen_program(prog)
|
|
|
|
return fx_importer.module_op
|