[FxImporter] Added FxImporter test method to be executed via torch.co… (#3795)

pull/3805/head
penguin_wwy 2024-10-16 10:32:52 +08:00 committed by GitHub
parent 45bb17ebfe
commit 6b289f29f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 78 additions and 2 deletions

View File

@ -8,6 +8,7 @@ import torch
import torch.utils._pytree as pytree
from torch.export.graph_signature import OutputSpec, OutputKind
from torch.export import ExportedProgram
from torch._dynamo.backends.common import aot_autograd
from torch_mlir import fx
from torch_mlir_e2e_test.configs.utils import (
@ -15,6 +16,7 @@ from torch_mlir_e2e_test.configs.utils import (
recursively_convert_from_numpy,
)
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
def refine_result_type(_result):
@ -31,9 +33,10 @@ def refine_result_type(_result):
class FxImporterTestConfig(TestConfig):
"""TestConfig that runs the torch.nn.Module with Fx Importer"""
def __init__(self, backend, output_type="linalg-on-tensors"):
def __init__(self, backend, output_type="linalg-on-tensors", torch_compile=False):
super().__init__()
self._backend = backend
self._torch_compile = torch_compile
self._output_type = output_type
def compile(
@ -41,7 +44,80 @@ class FxImporterTestConfig(TestConfig):
) -> torch.nn.Module:
return program
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
def run(self, artifact: torch.nn.Module, trace: Trace):
return (
self._export_run(artifact, trace)
if not self._torch_compile
else self._stateless_run(artifact, trace)
)
def _stateless_run(self, artifact: torch.nn.Module, trace: Trace):
dynamic_argument_pos = None
dynamic_dim_pos = None
annotations = getattr(artifact.forward, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME)
for i, annotation in enumerate(annotations):
if i == 0: # Skip the "self" annotation.
continue
if not annotation[2]:
raise ValueError(
"Can only compile inputs annotated as having value semantics."
)
for dim_i, dim in enumerate(annotation[0]):
if dim == -1:
dynamic_argument_pos = i - 1
dynamic_dim_pos = dim_i
break
if dynamic_argument_pos is not None:
break
result: Trace = []
for item in trace:
def _base_backend(gm: torch.fx.GraphModule, example_inputs):
for node in gm.graph.nodes:
if node.op == "placeholder":
if (
isinstance(node.meta["val"], torch.SymInt)
and not node.users
):
gm.graph.erase_node(node)
module = fx.stateless_fx_import(
gm,
output_type=self._output_type,
model_name=artifact.__class__.__name__,
)
module = self._backend.compile(module)
backend_module = self._backend.load(module)
def invoke_func(*torch_inputs):
torch_inputs = [
x
for x in filter(
lambda i: isinstance(i, torch.Tensor), torch_inputs
)
]
with torch.no_grad():
numpy_inputs = recursively_convert_to_numpy(torch_inputs)
return recursively_convert_from_numpy(
getattr(backend_module, artifact.__class__.__name__)(
*numpy_inputs
)
)
return invoke_func
fw_compiler = aot_autograd(fw_compiler=_base_backend)
if dynamic_argument_pos is not None:
torch._dynamo.mark_dynamic(
item.inputs[dynamic_argument_pos], dynamic_dim_pos
)
module = torch.compile(artifact, backend=fw_compiler)
outputs = module(*item.inputs)
result.append(
TraceItem(symbol=item.symbol, inputs=item.inputs, output=outputs)
)
return result
def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
result: Trace = []
for item in trace:
prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))