mirror of https://github.com/llvm/torch-mlir
Add stateless fx graph import (#3036)
parent
3a56714bff
commit
7616d637fd
|
@ -22,6 +22,7 @@ import pdb
|
|||
# introducing new concepts or abstractions into the import process.
|
||||
|
||||
from typing import Dict, Tuple
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import operator
|
||||
import re
|
||||
|
@ -426,6 +427,7 @@ class _FXGraphImporter:
|
|||
raise Exception(f"Unsupported literal: {arg}")
|
||||
|
||||
|
||||
@deprecated("Please use fx importer as a replacement to support torchdynamo")
|
||||
def import_fx_graph_as_func(g: torch.fx.Graph, func_name: str) -> ir.Module:
|
||||
"""Imports the given FX graph as a function in a new MLIR module.
|
||||
|
||||
|
|
|
@ -41,4 +41,18 @@ def export_and_import(
|
|||
else:
|
||||
fx_importer.import_frozen_program(prog, func_name=func_name)
|
||||
|
||||
return fx_importer.module_op
|
||||
return fx_importer.module
|
||||
|
||||
|
||||
def stateless_fx_import(
|
||||
gm: torch.fx.GraphModule,
|
||||
fx_importer: Optional[FxImporter] = None,
|
||||
hooks: Optional[FxImporterHooks] = None,
|
||||
model_name: str = "main",
|
||||
):
|
||||
context = ir.Context()
|
||||
torch_d.register_dialect(context)
|
||||
if fx_importer is None:
|
||||
fx_importer = FxImporter(context=context, hooks=hooks)
|
||||
fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
|
||||
return fx_importer.module
|
||||
|
|
|
@ -5,11 +5,13 @@
|
|||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from typing import Optional
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.export import Dim
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
from torch._functorch.aot_autograd import make_boxed_compiler, get_aot_graph_name, set_model_name
|
||||
|
||||
from torch_mlir import fx
|
||||
|
||||
|
@ -93,3 +95,26 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
|
|||
dynamic_shapes = {"x": {0: batch}}
|
||||
m = fx.export_and_import(Basic(), torch.randn(3, 4), dynamic_shapes=dynamic_shapes, func_name="test_net")
|
||||
print(m)
|
||||
|
||||
|
||||
|
||||
@make_boxed_compiler
|
||||
def fx_import_aot_autograd_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
|
||||
print(gm.print_readable(False), flush=True)
|
||||
m = fx.stateless_fx_import(gm, model_name=get_aot_graph_name())
|
||||
print(m, flush=True)
|
||||
return gm
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_stateless_fx_import
|
||||
# CHECK: func.func @basic_forward__6_inference_0(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
|
||||
# CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
|
||||
# CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32>
|
||||
def test_stateless_fx_import():
|
||||
fx_import_backend = aot_autograd(fw_compiler=fx_import_aot_autograd_backend)
|
||||
set_model_name("basic_forward")
|
||||
@torch._dynamo.optimize(backend=fx_import_backend)
|
||||
def basic_forward(x):
|
||||
return torch.tanh(x)
|
||||
|
||||
basic_forward(torch.randn(3, 4))
|
||||
|
|
Loading…
Reference in New Issue