Add stateless fx graph import (#3036)

pull/3047/head
penguin_wwy 2024-03-22 05:44:54 +08:00 committed by GitHub
parent 3a56714bff
commit 7616d637fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 2 deletions

View File

@ -22,6 +22,7 @@ import pdb
# introducing new concepts or abstractions into the import process.
from typing import Dict, Tuple
from typing_extensions import deprecated
import operator
import re
@ -426,6 +427,7 @@ class _FXGraphImporter:
raise Exception(f"Unsupported literal: {arg}")
@deprecated("Please use fx importer as a replacement to support torchdynamo")
def import_fx_graph_as_func(g: torch.fx.Graph, func_name: str) -> ir.Module:
"""Imports the given FX graph as a function in a new MLIR module.

View File

@ -41,4 +41,18 @@ def export_and_import(
else:
fx_importer.import_frozen_program(prog, func_name=func_name)
return fx_importer.module_op
return fx_importer.module
def stateless_fx_import(
gm: torch.fx.GraphModule,
fx_importer: Optional[FxImporter] = None,
hooks: Optional[FxImporterHooks] = None,
model_name: str = "main",
):
context = ir.Context()
torch_d.register_dialect(context)
if fx_importer is None:
fx_importer = FxImporter(context=context, hooks=hooks)
fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
return fx_importer.module

View File

@ -5,11 +5,13 @@
# RUN: %PYTHON %s | FileCheck %s
from typing import Optional
from typing import List
import torch
import torch.nn as nn
from torch.export import Dim
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.aot_autograd import make_boxed_compiler, get_aot_graph_name, set_model_name
from torch_mlir import fx
@ -93,3 +95,26 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
dynamic_shapes = {"x": {0: batch}}
m = fx.export_and_import(Basic(), torch.randn(3, 4), dynamic_shapes=dynamic_shapes, func_name="test_net")
print(m)
@make_boxed_compiler
def fx_import_aot_autograd_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print(gm.print_readable(False), flush=True)
m = fx.stateless_fx_import(gm, model_name=get_aot_graph_name())
print(m, flush=True)
return gm
@run
# CHECK-LABEL: test_stateless_fx_import
# CHECK: func.func @basic_forward__6_inference_0(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
# CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32>
def test_stateless_fx_import():
fx_import_backend = aot_autograd(fw_compiler=fx_import_aot_autograd_backend)
set_model_name("basic_forward")
@torch._dynamo.optimize(backend=fx_import_backend)
def basic_forward(x):
return torch.tanh(x)
basic_forward(torch.randn(3, 4))