[fx] Accept `func_visibility=` and return created func op. (#3054)

This is a partial landing of #3046 while waiting for an upstream change
for the rest of it.
pull/3061/head
Stella Laurenzo 2024-03-25 16:48:06 -07:00 committed by GitHub
parent 9ae33e482e
commit 17eeac880a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 33 additions and 10 deletions

View File

@ -302,7 +302,9 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
if sparsity.layout is torch.sparse_coo: if sparsity.layout is torch.sparse_coo:
assert sparse_dim >= 2 and blocksize is None assert sparse_dim >= 2 and blocksize is None
trail_dim = batch_dim + sparse_dim - 1 trail_dim = batch_dim + sparse_dim - 1
coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim+1, trail_dim)) coords = ",".join(
f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
)
sep = "," if sparse_dim > 2 else "" sep = "," if sparse_dim > 2 else ""
lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)" lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
elif sparsity.layout is torch.sparse_csr: elif sparsity.layout is torch.sparse_csr:
@ -467,8 +469,12 @@ class FxImporter:
return self._m.operation return self._m.operation
def import_program( def import_program(
self, prog: torch.export.ExportedProgram, *, func_name: str = "main" self,
): prog: torch.export.ExportedProgram,
*,
func_name: str = "main",
func_visibility: Optional[str] = None,
) -> Operation:
"""Imports an ExportedProgram according to our chosen canonical representation. """Imports an ExportedProgram according to our chosen canonical representation.
This mechanism is the fully general solution for handling an ExportedProgram This mechanism is the fully general solution for handling an ExportedProgram
@ -628,7 +634,9 @@ class FxImporter:
# Create the function. # Create the function.
with loc: with loc:
func_op = func_dialect.FuncOp(func_name, ftype, ip=self._m_ip) func_op = func_dialect.FuncOp(
func_name, ftype, ip=self._m_ip, visibility=func_visibility
)
entry_block = Block.create_at_start(func_op.body, ftype.inputs) entry_block = Block.create_at_start(func_op.body, ftype.inputs)
node_importer = GraphNodeImporter( node_importer = GraphNodeImporter(
@ -668,10 +676,15 @@ class FxImporter:
) )
node_importer.return_node_values(loc, user_outputs) node_importer.return_node_values(loc, user_outputs)
self.symbol_table.insert(func_op) self.symbol_table.insert(func_op)
return func_op
def import_frozen_program( def import_frozen_program(
self, prog: torch.export.ExportedProgram, func_name: str = "main" self,
): prog: torch.export.ExportedProgram,
*,
func_name: str = "main",
func_visibility: Optional[str] = None,
) -> Operation:
"""Imports a consolidated torch.export.ExportedProgram instance. """Imports a consolidated torch.export.ExportedProgram instance.
If using the new torch.export path (vs a lower level precursor), then this is If using the new torch.export path (vs a lower level precursor), then this is
@ -750,17 +763,25 @@ class FxImporter:
node.replace_all_uses_with(replacement) node.replace_all_uses_with(replacement)
g.erase_node(node) g.erase_node(node)
self.import_stateless_graph(g, func_name) return self.import_stateless_graph(
g, func_name=func_name, func_visibility=func_visibility
)
def import_graph_module(self, gm: GraphModule): def import_graph_module(self, gm: GraphModule) -> Operation:
"""Low-level import of a GraphModule assuming that it has been functionalized. """Low-level import of a GraphModule assuming that it has been functionalized.
TODO: This mechanism is deprecated by the `import_program` entry-point and TODO: This mechanism is deprecated by the `import_program` entry-point and
it should be removed when no longer required for backwards compatibility. it should be removed when no longer required for backwards compatibility.
""" """
self.import_stateless_graph(gm.graph) return self.import_stateless_graph(gm.graph)
def import_stateless_graph(self, g: Graph, func_name: str = "main"): def import_stateless_graph(
self,
g: Graph,
*,
func_name: str = "main",
func_visibility: Optional[str] = None,
) -> Operation:
"""Low-level import of a functionalized, assumed stateless Graph as a func. """Low-level import of a functionalized, assumed stateless Graph as a func.
TODO: This mechanism is deprecated by the `import_program` entry-point and TODO: This mechanism is deprecated by the `import_program` entry-point and
@ -775,6 +796,7 @@ class FxImporter:
func_name, func_name,
ftype, ftype,
ip=self._m_ip, ip=self._m_ip,
visibility=func_visibility,
) )
entry_block = Block.create_at_start(func.body, ftype.inputs) entry_block = Block.create_at_start(func.body, ftype.inputs)
node_importer = GraphNodeImporter( node_importer = GraphNodeImporter(
@ -785,6 +807,7 @@ class FxImporter:
) )
node_importer.import_nodes(g.nodes) node_importer.import_nodes(g.nodes)
self.symbol_table.insert(func) self.symbol_table.insert(func)
return func
def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]: def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
"""Extracts function metadata from the Graph. """Extracts function metadata from the Graph.