mirror of https://github.com/llvm/torch-mlir
[fx] Use module Operations instead of Module.
This was only used in certain advanced uses of the API that want to build into their own module. The MLIR `Module` class is an awkward/restrictive way to require this to go as only some things will have it. Just switch everything to be based on a module `Operation`.fx_use_module_op
parent
6ea857c644
commit
04685a98e8
|
@ -302,7 +302,9 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
|
||||||
if sparsity.layout is torch.sparse_coo:
|
if sparsity.layout is torch.sparse_coo:
|
||||||
assert sparse_dim >= 2 and blocksize is None
|
assert sparse_dim >= 2 and blocksize is None
|
||||||
trail_dim = batch_dim + sparse_dim - 1
|
trail_dim = batch_dim + sparse_dim - 1
|
||||||
coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim+1, trail_dim))
|
coords = ",".join(
|
||||||
|
f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
|
||||||
|
)
|
||||||
sep = "," if sparse_dim > 2 else ""
|
sep = "," if sparse_dim > 2 else ""
|
||||||
lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
|
lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
|
||||||
elif sparsity.layout is torch.sparse_csr:
|
elif sparsity.layout is torch.sparse_csr:
|
||||||
|
@ -415,7 +417,7 @@ class FxImporter:
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
"_c",
|
"_c",
|
||||||
"_cc",
|
"_cc",
|
||||||
"_m",
|
"_m_op",
|
||||||
"_m_ip",
|
"_m_ip",
|
||||||
"_py_attr_tracker",
|
"_py_attr_tracker",
|
||||||
"_hooks",
|
"_hooks",
|
||||||
|
@ -425,28 +427,31 @@ class FxImporter:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
module: Optional[Module] = None,
|
module_op: Optional[Operation] = None,
|
||||||
context: Optional[Context] = None,
|
context: Optional[Context] = None,
|
||||||
config_check: bool = True,
|
config_check: bool = True,
|
||||||
py_attr_tracker: Optional["RefTracker"] = None,
|
py_attr_tracker: Optional["RefTracker"] = None,
|
||||||
hooks: Optional[FxImporterHooks] = None,
|
hooks: Optional[FxImporterHooks] = None,
|
||||||
):
|
):
|
||||||
if module is not None:
|
if module_op is not None:
|
||||||
assert context is None, "If configuring with a Module, context must be None"
|
assert (
|
||||||
self._m = module
|
context is None
|
||||||
self._c = self.module.context
|
), "If configuring with a module op, context must be None"
|
||||||
|
self._m_op = module_op
|
||||||
|
self._c = self._m_op.context
|
||||||
else:
|
else:
|
||||||
self._c = context if context else Context()
|
self._c = context if context else Context()
|
||||||
self._m = Module.create(Location.unknown(self._c))
|
self._m_op = Module.create(Location.unknown(self._c)).operation
|
||||||
|
body = self._m_op.regions[0].blocks[0]
|
||||||
if config_check:
|
if config_check:
|
||||||
# Production code can disable this for a bit of a boost.
|
# Production code can disable this for a bit of a boost.
|
||||||
self._config_check()
|
self._config_check()
|
||||||
self._py_attr_tracker = py_attr_tracker or RefTracker()
|
self._py_attr_tracker = py_attr_tracker or RefTracker()
|
||||||
self._cc = ContextCache(self._c, py_attr_tracker=self._py_attr_tracker)
|
self._cc = ContextCache(self._c, py_attr_tracker=self._py_attr_tracker)
|
||||||
self._m_ip = InsertionPoint(self._m.body)
|
self._m_ip = InsertionPoint(body)
|
||||||
self._hooks = hooks or FxImporterHooks()
|
self._hooks = hooks or FxImporterHooks()
|
||||||
self.symbol_table = SymbolTable(self._m.operation)
|
self.symbol_table = SymbolTable(self._m_op)
|
||||||
self._hooks.prepare_module(self._m.operation)
|
self._hooks.prepare_module(self._m_op)
|
||||||
|
|
||||||
def _config_check(self):
|
def _config_check(self):
|
||||||
for dname in REQUIRED_DIALCTS:
|
for dname in REQUIRED_DIALCTS:
|
||||||
|
@ -458,17 +463,17 @@ class FxImporter:
|
||||||
f"The MLIR context {self._c} is missing required dialect '{dname}'"
|
f"The MLIR context {self._c} is missing required dialect '{dname}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def module(self) -> Module:
|
|
||||||
return self._m
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def module_op(self) -> Operation:
|
def module_op(self) -> Operation:
|
||||||
return self._m.operation
|
return self._m_op
|
||||||
|
|
||||||
def import_program(
|
def import_program(
|
||||||
self, prog: torch.export.ExportedProgram, *, func_name: str = "main"
|
self,
|
||||||
):
|
prog: torch.export.ExportedProgram,
|
||||||
|
*,
|
||||||
|
func_name: str = "main",
|
||||||
|
func_visibility: Optional[str] = None,
|
||||||
|
) -> Operation:
|
||||||
"""Imports an ExportedProgram according to our chosen canonical representation.
|
"""Imports an ExportedProgram according to our chosen canonical representation.
|
||||||
|
|
||||||
This mechanism is the fully general solution for handling an ExportedProgram
|
This mechanism is the fully general solution for handling an ExportedProgram
|
||||||
|
@ -490,6 +495,8 @@ class FxImporter:
|
||||||
It is recommended that integrators subclass and override the `resolve_literal`
|
It is recommended that integrators subclass and override the `resolve_literal`
|
||||||
method to control access to mutable buffers and parameters. Without that, the
|
method to control access to mutable buffers and parameters. Without that, the
|
||||||
default policy is to capture them as frozen values.
|
default policy is to capture them as frozen values.
|
||||||
|
|
||||||
|
Returns the created entry function as a generic Operation.
|
||||||
"""
|
"""
|
||||||
# Create lookaside table of placeholders/outputs.
|
# Create lookaside table of placeholders/outputs.
|
||||||
placeholder_nodes: Dict[str, Node] = {}
|
placeholder_nodes: Dict[str, Node] = {}
|
||||||
|
@ -628,7 +635,9 @@ class FxImporter:
|
||||||
|
|
||||||
# Create the function.
|
# Create the function.
|
||||||
with loc:
|
with loc:
|
||||||
func_op = func_dialect.FuncOp(func_name, ftype, ip=self._m_ip)
|
func_op = func_dialect.FuncOp(
|
||||||
|
func_name, ftype, ip=self._m_ip, visibility=func_visibility
|
||||||
|
)
|
||||||
entry_block = Block.create_at_start(func_op.body, ftype.inputs)
|
entry_block = Block.create_at_start(func_op.body, ftype.inputs)
|
||||||
|
|
||||||
node_importer = GraphNodeImporter(
|
node_importer = GraphNodeImporter(
|
||||||
|
@ -668,9 +677,13 @@ class FxImporter:
|
||||||
)
|
)
|
||||||
node_importer.return_node_values(loc, user_outputs)
|
node_importer.return_node_values(loc, user_outputs)
|
||||||
self.symbol_table.insert(func_op)
|
self.symbol_table.insert(func_op)
|
||||||
|
return func_op.operation
|
||||||
|
|
||||||
def import_frozen_program(
|
def import_frozen_program(
|
||||||
self, prog: torch.export.ExportedProgram, func_name: str = "main"
|
self,
|
||||||
|
prog: torch.export.ExportedProgram,
|
||||||
|
func_name: str = "main",
|
||||||
|
func_visibility: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Imports a consolidated torch.export.ExportedProgram instance.
|
"""Imports a consolidated torch.export.ExportedProgram instance.
|
||||||
|
|
||||||
|
@ -750,7 +763,7 @@ class FxImporter:
|
||||||
node.replace_all_uses_with(replacement)
|
node.replace_all_uses_with(replacement)
|
||||||
g.erase_node(node)
|
g.erase_node(node)
|
||||||
|
|
||||||
self.import_stateless_graph(g, func_name)
|
self.import_stateless_graph(g, func_name, func_visibility=func_visibility)
|
||||||
|
|
||||||
def import_graph_module(self, gm: GraphModule):
|
def import_graph_module(self, gm: GraphModule):
|
||||||
"""Low-level import of a GraphModule assuming that it has been functionalized.
|
"""Low-level import of a GraphModule assuming that it has been functionalized.
|
||||||
|
@ -760,7 +773,9 @@ class FxImporter:
|
||||||
"""
|
"""
|
||||||
self.import_stateless_graph(gm.graph)
|
self.import_stateless_graph(gm.graph)
|
||||||
|
|
||||||
def import_stateless_graph(self, g: Graph, func_name: str = "main"):
|
def import_stateless_graph(
|
||||||
|
self, g: Graph, func_name: str = "main", func_visibility: Optional[str] = None
|
||||||
|
):
|
||||||
"""Low-level import of a functionalized, assumed stateless Graph as a func.
|
"""Low-level import of a functionalized, assumed stateless Graph as a func.
|
||||||
|
|
||||||
TODO: This mechanism is deprecated by the `import_program` entry-point and
|
TODO: This mechanism is deprecated by the `import_program` entry-point and
|
||||||
|
@ -775,6 +790,7 @@ class FxImporter:
|
||||||
func_name,
|
func_name,
|
||||||
ftype,
|
ftype,
|
||||||
ip=self._m_ip,
|
ip=self._m_ip,
|
||||||
|
func_visibility=func_visibility,
|
||||||
)
|
)
|
||||||
entry_block = Block.create_at_start(func.body, ftype.inputs)
|
entry_block = Block.create_at_start(func.body, ftype.inputs)
|
||||||
node_importer = GraphNodeImporter(
|
node_importer = GraphNodeImporter(
|
||||||
|
|
Loading…
Reference in New Issue