# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from typing import Optional, Union, Dict, Tuple, Any, Callable
from packaging import version

import warnings

import torch
import torch.export
import torch.nn as nn
from torch.export import ExportedProgram

from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_decomp_util import get_decomposition_table
from torch_mlir.compiler_utils import (
    OutputType,
    run_pipeline_with_repro_report,
    lower_mlir_module,
)


def _module_lowering(
    verbose,
    output_type,
    torch_mod,
    extra_library_file_name=None,
):

    # RAW output: return the imported module without running any pipeline.
    if output_type == OutputType.RAW:
        if verbose:
            print(torch_mod)
        return torch_mod
    # TODO: pass extra_library_file_name by caller
    if extra_library_file_name is None:
        extra_library_file_name = ""
    option_string = "{extra-library=" + extra_library_file_name + "}"
    # Lower with the dedicated TorchDynamo export pipeline rather than the
    # TorchScript-era simplification pipeline.
    run_pipeline_with_repro_report(
        torch_mod,
        f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
        "Lowering TorchFX IR -> Torch Backend IR",
        enable_ir_printing=verbose,
    )
    return lower_mlir_module(verbose, output_type, torch_mod)


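# The pipeline string that _module_lowering hands to
# run_pipeline_with_repro_report can be reproduced in isolation, which may help
# when checking what the "extra-library" option expands to. This is only a
# sketch: `build_pipeline`, `MATCH_QUANTIZED` and `BACKEND_PIPELINE` are
# hypothetical names that are not part of this module, and the example path is
# made up. The pass and pipeline names are copied from _module_lowering.

```python
from typing import Optional

# Pass and pipeline names copied verbatim from _module_lowering.
MATCH_QUANTIZED = "func.func(torch-match-quantized-custom-ops)"
BACKEND_PIPELINE = "torchdynamo-export-to-torch-backend-pipeline"


def build_pipeline(extra_library_file_name: Optional[str] = None) -> str:
    """Hypothetical helper: build the same pipeline string as _module_lowering."""
    # A missing library name becomes an empty option value, as in the module.
    if extra_library_file_name is None:
        extra_library_file_name = ""
    option_string = "{extra-library=" + extra_library_file_name + "}"
    return f"builtin.module({MATCH_QUANTIZED}, {BACKEND_PIPELINE}{option_string})"


print(build_pipeline())
print(build_pipeline("/tmp/extra_lib.mlir"))  # illustrative path only
```

# Keeping the string construction in one small function would make the
# TODO about passing extra_library_file_name from the caller easier to address.

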
def export_and_import(
    f: Union[nn.Module, ExportedProgram],
    *args,
    output_type: Union[str, OutputType] = OutputType.RAW,
    fx_importer: Optional[FxImporter] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    experimental_support_mutation: bool = False,
    import_symbolic_shape_expressions: bool = False,
    hooks: Optional[FxImporterHooks] = None,
    decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
    func_name: str = "main",
    enable_graph_printing: bool = False,
    enable_ir_printing: bool = False,
    **kwargs,
):
    """Export `f` with torch.export (unless it already is an ExportedProgram),
    import it into the Torch dialect, and lower it to `output_type`."""
    context = ir.Context()
    torch_d.register_dialect(context)

    if fx_importer is None:
        fx_importer = FxImporter(context=context, hooks=hooks)
    if isinstance(f, ExportedProgram):
        prog = f
    else:
        # PyTorch 2.1 or lower doesn't have the `dynamic_shapes` keyword argument in torch.export
        if version.Version(torch.__version__) >= version.Version("2.2.0"):
            prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
        else:
            prog = torch.export.export(f, args, kwargs)
    # Default to torch-mlir's decomposition table; an empty table skips
    # decompositions entirely.
    if decomposition_table is None:
        decomposition_table = get_decomposition_table()
    if decomposition_table:
        prog = prog.run_decompositions(decomposition_table)
    if enable_graph_printing:
        prog.graph_module.print_readable()
    if experimental_support_mutation:
        # Mutation is modeled at the graph edges: user inputs and buffers may
        # be mutated. This relies on ExportedProgram changes from PyTorch 2.3.
        if torch.__version__ < "2.3.0.dev20240207":
            warnings.warn("Mutable program import only supported on PyTorch 2.3+")
        fx_importer.import_program(
            prog,
            func_name=func_name,
            import_symbolic_shape_expressions=import_symbolic_shape_expressions,
        )
    else:
        fx_importer.import_frozen_program(
            prog,
            func_name=func_name,
            import_symbolic_shape_expressions=import_symbolic_shape_expressions,
        )

    return _module_lowering(
        enable_ir_printing, OutputType.get(output_type), fx_importer.module
    )


def stateless_fx_import(
    gm: torch.fx.GraphModule,
    output_type: Union[str, OutputType] = OutputType.RAW,
    fx_importer: Optional[FxImporter] = None,
    hooks: Optional[FxImporterHooks] = None,
    model_name: str = "main",
    enable_graph_printing: bool = False,
    enable_ir_printing: bool = False,
):
    """Import a stateless torch.fx.GraphModule into the Torch dialect and
    lower it to `output_type`."""
    if enable_graph_printing:
        gm.print_readable()
    context = ir.Context()
    torch_d.register_dialect(context)
    if fx_importer is None:
        fx_importer = FxImporter(context=context, hooks=hooks)
    fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
    return _module_lowering(
        enable_ir_printing, OutputType.get(output_type), fx_importer.module
    )
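

# A minimal usage sketch for export_and_import follows. It assumes this file is
# importable as `torch_mlir.fx` and that both `torch` and `torch_mlir` are
# installed; the `Tanh` module, the input shape, the function name
# "tanh_example" and the names `HAVE_DEPS` and `run_example` are illustrative
# and not part of this module. The default output_type (OutputType.RAW) is
# kept, so the imported module is returned without running the lowering
# pipeline.

```python
import importlib.util

# Only run the example when both packages can be found (assumption: the
# module above is installed as torch_mlir.fx).
HAVE_DEPS = all(
    importlib.util.find_spec(name) is not None for name in ("torch", "torch_mlir")
)


def run_example():
    """Export a tiny module and import it into the Torch dialect."""
    import torch
    from torch_mlir import fx

    class Tanh(torch.nn.Module):  # illustrative module, not from the source
        def forward(self, x):
            return torch.tanh(x)

    # Positional arguments after the module are the example inputs that are
    # forwarded to torch.export.export.
    return fx.export_and_import(Tanh(), torch.randn(3, 4), func_name="tanh_example")


if HAVE_DEPS:
    print(run_example())
```

# Passing an already exported ExportedProgram instead of an nn.Module skips the
# torch.export.export call, as the isinstance check in export_and_import shows.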