[fx] Use module Operations instead of Module.

This was only used in certain advanced uses of the API that want to build into their own module. The MLIR `Module` class is an awkward/restrictive way to require this to go as only some things will have it. Just switch everything to be based on a module `Operation`.
fx_use_module_op
Stella Laurenzo 2024-03-21 19:57:53 -07:00
parent 6ea857c644
commit 04685a98e8
1 changed files with 38 additions and 22 deletions

View File

@ -302,7 +302,9 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
if sparsity.layout is torch.sparse_coo: if sparsity.layout is torch.sparse_coo:
assert sparse_dim >= 2 and blocksize is None assert sparse_dim >= 2 and blocksize is None
trail_dim = batch_dim + sparse_dim - 1 trail_dim = batch_dim + sparse_dim - 1
coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim+1, trail_dim)) coords = ",".join(
f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
)
sep = "," if sparse_dim > 2 else "" sep = "," if sparse_dim > 2 else ""
lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)" lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
elif sparsity.layout is torch.sparse_csr: elif sparsity.layout is torch.sparse_csr:
@ -415,7 +417,7 @@ class FxImporter:
__slots__ = [ __slots__ = [
"_c", "_c",
"_cc", "_cc",
"_m", "_m_op",
"_m_ip", "_m_ip",
"_py_attr_tracker", "_py_attr_tracker",
"_hooks", "_hooks",
@ -425,28 +427,31 @@ class FxImporter:
def __init__( def __init__(
self, self,
*, *,
module: Optional[Module] = None, module_op: Optional[Operation] = None,
context: Optional[Context] = None, context: Optional[Context] = None,
config_check: bool = True, config_check: bool = True,
py_attr_tracker: Optional["RefTracker"] = None, py_attr_tracker: Optional["RefTracker"] = None,
hooks: Optional[FxImporterHooks] = None, hooks: Optional[FxImporterHooks] = None,
): ):
if module is not None: if module_op is not None:
assert context is None, "If configuring with a Module, context must be None" assert (
self._m = module context is None
self._c = self.module.context ), "If configuring with a module op, context must be None"
self._m_op = module_op
self._c = self._m_op.context
else: else:
self._c = context if context else Context() self._c = context if context else Context()
self._m = Module.create(Location.unknown(self._c)) self._m_op = Module.create(Location.unknown(self._c)).operation
body = self._m_op.regions[0].blocks[0]
if config_check: if config_check:
# Production code can disable this for a bit of a boost. # Production code can disable this for a bit of a boost.
self._config_check() self._config_check()
self._py_attr_tracker = py_attr_tracker or RefTracker() self._py_attr_tracker = py_attr_tracker or RefTracker()
self._cc = ContextCache(self._c, py_attr_tracker=self._py_attr_tracker) self._cc = ContextCache(self._c, py_attr_tracker=self._py_attr_tracker)
self._m_ip = InsertionPoint(self._m.body) self._m_ip = InsertionPoint(body)
self._hooks = hooks or FxImporterHooks() self._hooks = hooks or FxImporterHooks()
self.symbol_table = SymbolTable(self._m.operation) self.symbol_table = SymbolTable(self._m_op)
self._hooks.prepare_module(self._m.operation) self._hooks.prepare_module(self._m_op)
def _config_check(self): def _config_check(self):
for dname in REQUIRED_DIALCTS: for dname in REQUIRED_DIALCTS:
@ -458,17 +463,17 @@ class FxImporter:
f"The MLIR context {self._c} is missing required dialect '{dname}'" f"The MLIR context {self._c} is missing required dialect '{dname}'"
) )
@property
def module(self) -> Module:
return self._m
@property @property
def module_op(self) -> Operation: def module_op(self) -> Operation:
return self._m.operation return self._m_op
def import_program( def import_program(
self, prog: torch.export.ExportedProgram, *, func_name: str = "main" self,
): prog: torch.export.ExportedProgram,
*,
func_name: str = "main",
func_visibility: Optional[str] = None,
) -> Operation:
"""Imports an ExportedProgram according to our chosen canonical representation. """Imports an ExportedProgram according to our chosen canonical representation.
This mechanism is the fully general solution for handling an ExportedProgram This mechanism is the fully general solution for handling an ExportedProgram
@ -490,6 +495,8 @@ class FxImporter:
It is recommended that integrators subclass and override the `resolve_literal` It is recommended that integrators subclass and override the `resolve_literal`
method to control access to mutable buffers and parameters. Without that, the method to control access to mutable buffers and parameters. Without that, the
default policy is to capture them as frozen values. default policy is to capture them as frozen values.
Returns the created entry function as a generic Operation.
""" """
# Create lookaside table of placeholders/outputs. # Create lookaside table of placeholders/outputs.
placeholder_nodes: Dict[str, Node] = {} placeholder_nodes: Dict[str, Node] = {}
@ -628,7 +635,9 @@ class FxImporter:
# Create the function. # Create the function.
with loc: with loc:
func_op = func_dialect.FuncOp(func_name, ftype, ip=self._m_ip) func_op = func_dialect.FuncOp(
func_name, ftype, ip=self._m_ip, visibility=func_visibility
)
entry_block = Block.create_at_start(func_op.body, ftype.inputs) entry_block = Block.create_at_start(func_op.body, ftype.inputs)
node_importer = GraphNodeImporter( node_importer = GraphNodeImporter(
@ -668,9 +677,13 @@ class FxImporter:
) )
node_importer.return_node_values(loc, user_outputs) node_importer.return_node_values(loc, user_outputs)
self.symbol_table.insert(func_op) self.symbol_table.insert(func_op)
return func_op.operation
def import_frozen_program( def import_frozen_program(
self, prog: torch.export.ExportedProgram, func_name: str = "main" self,
prog: torch.export.ExportedProgram,
func_name: str = "main",
func_visibility: Optional[str] = None,
): ):
"""Imports a consolidated torch.export.ExportedProgram instance. """Imports a consolidated torch.export.ExportedProgram instance.
@ -750,7 +763,7 @@ class FxImporter:
node.replace_all_uses_with(replacement) node.replace_all_uses_with(replacement)
g.erase_node(node) g.erase_node(node)
self.import_stateless_graph(g, func_name) self.import_stateless_graph(g, func_name, func_visibility=func_visibility)
def import_graph_module(self, gm: GraphModule): def import_graph_module(self, gm: GraphModule):
"""Low-level import of a GraphModule assuming that it has been functionalized. """Low-level import of a GraphModule assuming that it has been functionalized.
@ -760,7 +773,9 @@ class FxImporter:
""" """
self.import_stateless_graph(gm.graph) self.import_stateless_graph(gm.graph)
def import_stateless_graph(self, g: Graph, func_name: str = "main"): def import_stateless_graph(
self, g: Graph, func_name: str = "main", func_visibility: Optional[str] = None
):
"""Low-level import of a functionalized, assumed stateless Graph as a func. """Low-level import of a functionalized, assumed stateless Graph as a func.
TODO: This mechanism is deprecated by the `import_program` entry-point and TODO: This mechanism is deprecated by the `import_program` entry-point and
@ -775,6 +790,7 @@ class FxImporter:
func_name, func_name,
ftype, ftype,
ip=self._m_ip, ip=self._m_ip,
func_visibility=func_visibility,
) )
entry_block = Block.create_at_start(func.body, ftype.inputs) entry_block = Block.create_at_start(func.body, ftype.inputs)
node_importer = GraphNodeImporter( node_importer = GraphNodeImporter(