mirror of https://github.com/llvm/torch-mlir
[FxImporter] Added FxImporter test method to be executed via torch.compile (#3795)
parent
45bb17ebfe
commit
6b289f29f2
|
@ -8,6 +8,7 @@ import torch
|
|||
import torch.utils._pytree as pytree
|
||||
from torch.export.graph_signature import OutputSpec, OutputKind
|
||||
from torch.export import ExportedProgram
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
|
||||
from torch_mlir import fx
|
||||
from torch_mlir_e2e_test.configs.utils import (
|
||||
|
@ -15,6 +16,7 @@ from torch_mlir_e2e_test.configs.utils import (
|
|||
recursively_convert_from_numpy,
|
||||
)
|
||||
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
|
||||
|
||||
|
||||
def refine_result_type(_result):
|
||||
|
@ -31,9 +33,10 @@ def refine_result_type(_result):
|
|||
class FxImporterTestConfig(TestConfig):
|
||||
"""TestConfig that runs the torch.nn.Module with Fx Importer"""
|
||||
|
||||
def __init__(self, backend, output_type="linalg-on-tensors"):
|
||||
def __init__(self, backend, output_type="linalg-on-tensors", torch_compile=False):
|
||||
super().__init__()
|
||||
self._backend = backend
|
||||
self._torch_compile = torch_compile
|
||||
self._output_type = output_type
|
||||
|
||||
def compile(
|
||||
|
@ -41,7 +44,80 @@ class FxImporterTestConfig(TestConfig):
|
|||
) -> torch.nn.Module:
|
||||
return program
|
||||
|
||||
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
||||
def run(self, artifact: torch.nn.Module, trace: Trace):
|
||||
return (
|
||||
self._export_run(artifact, trace)
|
||||
if not self._torch_compile
|
||||
else self._stateless_run(artifact, trace)
|
||||
)
|
||||
|
||||
    def _stateless_run(self, artifact: torch.nn.Module, trace: Trace):
        """Run *trace* through ``torch.compile`` with a stateless FX import.

        Scans the module's torch-mlir argument annotations for a dynamic
        dimension (marked as -1), marks it dynamic for dynamo, then compiles
        *artifact* with an ``aot_autograd`` backend that imports the captured
        FX graph via ``fx.stateless_fx_import`` and executes it on
        ``self._backend``.

        Raises:
            ValueError: if any annotated input lacks value semantics
                (``annotation[2]`` is falsy).

        Returns:
            A new Trace of TraceItems with outputs produced by the backend.
        """
        # Position (among forward's real args) and dim index of the first
        # annotated dynamic dimension, if any.
        dynamic_argument_pos = None
        dynamic_dim_pos = None
        annotations = getattr(artifact.forward, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME)
        for i, annotation in enumerate(annotations):
            if i == 0:  # Skip the "self" annotation.
                continue
            if not annotation[2]:
                raise ValueError(
                    "Can only compile inputs annotated as having value semantics."
                )
            # annotation[0] appears to be the shape tuple; -1 marks a dynamic
            # dim. Only the first dynamic dim found is used.
            for dim_i, dim in enumerate(annotation[0]):
                if dim == -1:
                    # i - 1 accounts for the skipped "self" annotation.
                    dynamic_argument_pos = i - 1
                    dynamic_dim_pos = dim_i
                    break
            if dynamic_argument_pos is not None:
                break
        result: Trace = []
        for item in trace:

            def _base_backend(gm: torch.fx.GraphModule, example_inputs):
                # Drop unused SymInt placeholders (introduced by dynamic
                # shapes) so the importer only sees live tensor inputs.
                for node in gm.graph.nodes:
                    if node.op == "placeholder":
                        if (
                            isinstance(node.meta["val"], torch.SymInt)
                            and not node.users
                        ):
                            gm.graph.erase_node(node)
                module = fx.stateless_fx_import(
                    gm,
                    output_type=self._output_type,
                    model_name=artifact.__class__.__name__,
                )
                # Lower with the configured backend and load it for execution.
                module = self._backend.compile(module)
                backend_module = self._backend.load(module)

                def invoke_func(*torch_inputs):
                    # Keep only tensor arguments; dynamo may also pass
                    # SymInt/scalar arguments the imported module doesn't take.
                    torch_inputs = [
                        x
                        for x in filter(
                            lambda i: isinstance(i, torch.Tensor), torch_inputs
                        )
                    ]
                    with torch.no_grad():
                        # Backend modules consume/produce numpy arrays; the
                        # entry point is named after the model class.
                        numpy_inputs = recursively_convert_to_numpy(torch_inputs)
                        return recursively_convert_from_numpy(
                            getattr(backend_module, artifact.__class__.__name__)(
                                *numpy_inputs
                            )
                        )

                return invoke_func

            fw_compiler = aot_autograd(fw_compiler=_base_backend)
            if dynamic_argument_pos is not None:
                # Tell dynamo this input dim is dynamic so it is not
                # specialized to the example value.
                torch._dynamo.mark_dynamic(
                    item.inputs[dynamic_argument_pos], dynamic_dim_pos
                )
            module = torch.compile(artifact, backend=fw_compiler)
            outputs = module(*item.inputs)
            result.append(
                TraceItem(symbol=item.symbol, inputs=item.inputs, output=outputs)
            )
        return result
|
||||
|
||||
def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
|
||||
result: Trace = []
|
||||
for item in trace:
|
||||
prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))
|
||||
|
|
Loading…
Reference in New Issue