[fx] Support mutation in ExportedProgram. (#2916)
As of https://github.com/pytorch/pytorch/pull/118969, `ExportedProgram`
has the long awaited fixes to correctly categorize various things
relating to parameters, buffers, mutated inputs and constants.
With this additional modeling, we are finally able to implement
(safely/soundly) the mutable semantics that were attempted on the
TorchScript path. The difference is that on that path, we had to
conservatively treat everything as mutable and run some dodgy heuristics
(which have been the cause of many bugs relating to
"MaximizeValueSemantics") to try to get back to an immutable state.
The new model supports mutability at the graph edges, allowing both user
inputs and buffers to be mutated (there is some more support than that,
but that is all I fully tracked through to implementation).
Therefore, when we receive programs like this, we now can selectively
enable mutation at the edges. This happens to be the mutability model
that IREE supports, which I expect to be a primary beneficiary. However,
there is nothing stopping anyone else from handling the `!torch.tensor`
types and the existing copy/overwrite ops that will be selectively
added.
Since this relies on API changes that will not release until 2.3, I'm
being a bit cautious about not refactoring existing facilities.
2024-02-17 01:46:30 +08:00
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from typing import Optional, Union, Dict, Tuple, Any, Callable
import warnings

import torch
import torch.export
import torch.nn as nn
from torch.export import ExportedProgram

from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_decomp_util import get_decomposition_table


def export_and_import(
    f: Union[nn.Module, ExportedProgram],
    *args,
    fx_importer: Optional[FxImporter] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    experimental_support_mutation: bool = False,
    hooks: Optional[FxImporterHooks] = None,
    decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
    func_name: str = "main",
    enable_graph_printing: bool = False,
    **kwargs,
):
    context = ir.Context()
    torch_d.register_dialect(context)

    if fx_importer is None:
        fx_importer = FxImporter(context=context, hooks=hooks)
    if isinstance(f, ExportedProgram):
        prog = f
    else:
        prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
    if decomposition_table is None:
        decomposition_table = get_decomposition_table()
    if decomposition_table:
        prog = prog.run_decompositions(decomposition_table)
    if enable_graph_printing:
        prog.graph_module.print_readable()
    if experimental_support_mutation:
        if torch.__version__ < "2.3.0.dev20240207":
            warnings.warn("Mutable program import only supported on PyTorch 2.3+")
        fx_importer.import_program(prog, func_name=func_name)
    else:
        fx_importer.import_frozen_program(prog, func_name=func_name)

    return fx_importer.module


def stateless_fx_import(
    gm: torch.fx.GraphModule,
    fx_importer: Optional[FxImporter] = None,
    hooks: Optional[FxImporterHooks] = None,
    model_name: str = "main",
    enable_graph_printing: bool = False,
):
    if enable_graph_printing:
        gm.print_readable()
    context = ir.Context()
    torch_d.register_dialect(context)
    if fx_importer is None:
        fx_importer = FxImporter(context=context, hooks=hooks)
    fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
    return fx_importer.module
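

# Illustration only, not part of this module: a minimal sketch of how
# `export_and_import` chooses between `import_program` (mutation at the graph
# edges) and `import_frozen_program`, runnable without PyTorch or torch_mlir.
# `RecordingImporter` and `dispatch_import` are hypothetical names introduced
# for this sketch; the stand-in only records which entry point was called and
# assumes the same keyword (`func_name`) that the code above passes.

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple


@dataclass
class RecordingImporter:
    """Hypothetical stand-in for FxImporter; records the entry point used."""

    calls: List[Tuple[str, str]] = field(default_factory=list)

    def import_program(self, prog: Any, *, func_name: str) -> None:
        # Path taken when mutation support is requested.
        self.calls.append(("import_program", func_name))

    def import_frozen_program(self, prog: Any, *, func_name: str) -> None:
        # Default path: the program is imported as frozen.
        self.calls.append(("import_frozen_program", func_name))


def dispatch_import(
    importer: RecordingImporter,
    prog: Any,
    *,
    experimental_support_mutation: bool = False,
    func_name: str = "main",
) -> Tuple[str, str]:
    # Same branch shape as the end of export_and_import, minus the
    # PyTorch version warning.
    if experimental_support_mutation:
        importer.import_program(prog, func_name=func_name)
    else:
        importer.import_frozen_program(prog, func_name=func_name)
    return importer.calls[-1]
```

# With the real API, the equivalent choice is made by passing
# `experimental_support_mutation=True` to `export_and_import`.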