[fx] Accept `func_visibility=` and return created func op. (#3054)

This is a partial landing of #3046 while waiting for an upstream change
for the rest of it.
pull/3061/head
Stella Laurenzo 2024-03-25 16:48:06 -07:00 committed by GitHub
parent 9ae33e482e
commit 17eeac880a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 33 additions and 10 deletions

View File

@ -302,7 +302,9 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
if sparsity.layout is torch.sparse_coo:
assert sparse_dim >= 2 and blocksize is None
trail_dim = batch_dim + sparse_dim - 1
coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim+1, trail_dim))
coords = ",".join(
f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
)
sep = "," if sparse_dim > 2 else ""
lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
elif sparsity.layout is torch.sparse_csr:
@ -467,8 +469,12 @@ class FxImporter:
return self._m.operation
def import_program(
self, prog: torch.export.ExportedProgram, *, func_name: str = "main"
):
self,
prog: torch.export.ExportedProgram,
*,
func_name: str = "main",
func_visibility: Optional[str] = None,
) -> Operation:
"""Imports an ExportedProgram according to our chosen canonical representation.
This mechanism is the fully general solution for handling an ExportedProgram
@ -628,7 +634,9 @@ class FxImporter:
# Create the function.
with loc:
func_op = func_dialect.FuncOp(func_name, ftype, ip=self._m_ip)
func_op = func_dialect.FuncOp(
func_name, ftype, ip=self._m_ip, visibility=func_visibility
)
entry_block = Block.create_at_start(func_op.body, ftype.inputs)
node_importer = GraphNodeImporter(
@ -668,10 +676,15 @@ class FxImporter:
)
node_importer.return_node_values(loc, user_outputs)
self.symbol_table.insert(func_op)
return func_op
def import_frozen_program(
self, prog: torch.export.ExportedProgram, func_name: str = "main"
):
self,
prog: torch.export.ExportedProgram,
*,
func_name: str = "main",
func_visibility: Optional[str] = None,
) -> Operation:
"""Imports a consolidated torch.export.ExportedProgram instance.
If using the new torch.export path (vs a lower level precursor), then this is
@ -750,17 +763,25 @@ class FxImporter:
node.replace_all_uses_with(replacement)
g.erase_node(node)
self.import_stateless_graph(g, func_name)
return self.import_stateless_graph(
g, func_name=func_name, func_visibility=func_visibility
)
def import_graph_module(self, gm: GraphModule):
def import_graph_module(self, gm: GraphModule) -> Operation:
"""Low-level import of a GraphModule assuming that it has been functionalized.
TODO: This mechanism is deprecated by the `import_program` entry-point and
it should be removed when no longer required for backwards compatibility.
"""
self.import_stateless_graph(gm.graph)
return self.import_stateless_graph(gm.graph)
def import_stateless_graph(self, g: Graph, func_name: str = "main"):
def import_stateless_graph(
self,
g: Graph,
*,
func_name: str = "main",
func_visibility: Optional[str] = None,
) -> Operation:
"""Low-level import of a functionalized, assumed stateless Graph as a func.
TODO: This mechanism is deprecated by the `import_program` entry-point and
@ -775,6 +796,7 @@ class FxImporter:
func_name,
ftype,
ip=self._m_ip,
visibility=func_visibility,
)
entry_block = Block.create_at_start(func.body, ftype.inputs)
node_importer = GraphNodeImporter(
@ -785,6 +807,7 @@ class FxImporter:
)
node_importer.import_nodes(g.nodes)
self.symbol_table.insert(func)
return func
def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
"""Extracts function metadata from the Graph.