Commit Graph

55 Commits (c7d52f63b482b2c30f4efb435ce0cc2efeab25d9)

Author SHA1 Message Date
Andrea 🦈 51902ec2dc
Create MLIR functions for ONNX operators that are functions (#3409)
Resolves #3384.

Many ONNX operators are defined by functions and therefore could be
expanded into simpler ONNX operations during importing, avoiding the
need for tools downstream to support these operators directly.

This commit adds this capability to onnx_importer.py. When importing a
node, the schema for the node's operator is retrieved. If the schema
provides a function for the operator, a specialized version for the
node's types and attributes will be created and imported as an MLIR
function with private visibility. An MLIR function call will then be
emitted, instead of a normal operator node. Caching is used to avoid
generating redundant functions within the same module.

In order to avoid a disruptive change to the importer output for a
large number of operators that already have TorchOnnxToTorch support,
an allowlist strategy is used by default. With this commit, only one
operator is allowlisted for expansion, MeanVarianceNormalization.
However, many other operators can be correctly expanded by the current
code, so hopefully the allowlist can be gradually extended. It is
possible to disable the allowlist in the configuration, in which case
all functions are expanded (useful for testing).

Tools downstream of the importer may now need to do inlining when
consuming the output of the importer, e.g.:

  cat imported.mlir | torch-mlir-opt --inline --convert-onnx-to-torch

Explanations for subtle code changes:

- Looking up the correct schema and function for an operator requires
  knowing the opset version. NodeImporter retrieves this from the
  opset imports on the ModelProto retained by the GraphInfo. Previously,
  the model_proto field on GraphInfo was None when importing a subgraph
  in import_regions, but this conflicts with the new need for opset
  version info. Since the apparent purpose of setting it to None was to
  control how GraphInfo generates its input map, a new flag is added to
  GraphInfo (is_subgraph) to control this behavior, so that the actual
  ModelProto can now be provided without breaking this. This also turned
  out to be useful for getting the Config via ModelInfo via GraphInfo.
- Some operators' functions are context-dependent, which means the
  function definition depends on the types of the inputs. Therefore node
  importing now needs to look up the types of a node's inputs, not just
  its outputs as was the case previously. Consequently the operand to
  find_type_proto_for_name() may now be a graph input or initializer in
  some cases, so it has to be updated.
2024-06-14 10:11:26 -07:00
Wu Yuan a02e14e971
[FxImporter] Add aten._scaled_dot_product_flash_attention_for_cpu to default decomposition table (#3456) 2024-06-14 10:52:09 +08:00
zjgarvey c0eb6d89c0
[ONNX] add some args to the onnx importer to assist shape_inference (#3445)
Adds the following arguments:
- "--clear-domain": enabling this flag (default False) will delete the
domain attribute from each node in the onnx model before importing.
Shape inference does not seem to work for onnx ops in custom domains. In
the rare case when these ops have a corresponding counterpart in base
onnx, enabling this flag might allow shape inference to work properly.
- "--opset-version": allows setting the opset version manually. This
will cause the importer to attempt to update the opset_version of the
onnx model before importing. Newer opset versions sometimes have more
robust shape inference patterns.
2024-06-12 10:55:14 -05:00
Sambhav Jain 7e0e23c668
Test custom op import with symbolic shapes (#3431)
Tests the basic constructs of registering a custom op and its abstract
implementations (with FakeTensors) in python, going through TorchDynamo
export, followed by importing the shape expressions in the Torch
dialect.

Also fixes the importer were previously the symbolic bind op insertion
was not gated in one place.
2024-06-09 00:32:49 -07:00
Rob Suderman 7f188eb824
Add f8 types to fx importer (#3434)
Missing types for tracing float8 types.
2024-06-07 13:58:18 -07:00
Sambhav Jain d0a818a03e
Representing Symbolic Shape Expressions in Torch Dialect (#3372)
Torch Dialect with symbolic shape expressions:
```ll
module {                                                                                                                                                                                                     
  func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {                                                                                   
    %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int                                                                                                                                    
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int                                                                                                                                   
    %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int                                                                                                                                    
    
    torch.bind_symbolic_shape %arg0, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    torch.bind_symbolic_shape %arg1, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    
    %3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                                  
    torch.bind_symbolic_shape %3, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                               
    torch.bind_symbolic_shape %4, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>                                               
    %int1 = torch.constant.int 1                                                                                                                                                                             
    %6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>                                                                                                          
    torch.bind_symbolic_shape %6, [%0, %1, %2], #affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>                                                                            
    
    return %6 : !torch.vtensor<[?,?,3],f32>                                                                                                                                                                  
  }                                                                                                                                                                                                          
}              
```

For reference, this is the TorchDynamo exported program with symbolic
shape expressions that the above Torch dialect program is imported from:
```py
ExportedProgram:                                                                                                                                                                                             
    class GraphModule(torch.nn.Module):                                                                                                                                                                      
        def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"):                                                                                                                                         
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x)                                        
            tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x);  x = None                                                                                                                               
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y)                                     
            sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y);  y = None                                                                                                                         
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1)                       
            cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1);  tanh = sigmoid = None                                                                                      
            return (cat,)                                                                                                                                                                                    
                                                                                                                                                                                                             
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])                                               
Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)} 
```

Huge credit to @stellaraccident for the inputs that helped evaluate the
various design options and arrive at the representation of choice.


- [x] Op definitions for symbolic_int and bind_symbolic_shape ops
- [x] fx_importer updates to import range constraints + create
symbolic_int ops
- [x] fx_importer changes for AffineMapAttr building + adding
bind_symbolic_shape ops
- [x] custom printer/parser for inlined AffineMap expressions in mlir
assembly
- [x] Dialect lit test
- [x] fx_importer python lit tests
- [ ] Cleanup pass to remove these ops (can add in a follow-on)
2024-06-07 04:04:03 -07:00
Rob Suderman 0a6861b1e8
Add conversion operation for bool resolved_literal (#3410)
Resolving `bool` literals can result in a type change to uint8. This
needs to be converted back to the expected type before returning to the
wrapped `torch` operators.
2024-06-03 14:43:38 -07:00
penguin_wwy a5d3b546f8
[FxImporter] Fix embedding bag (#3387) 2024-05-29 14:46:21 +08:00
penguin_wwy d924d0047f
[FxImporter] Fix primitive type in return (#3379) 2024-05-23 09:55:33 +08:00
penguin_wwy 972d47b586
[FxImporter] Fix constant bool tensor (#3375) 2024-05-22 22:59:01 +08:00
penguin_wwy c2c1c2cfa4
[FxImporter] Fix failed e2e case (#3365) 2024-05-22 00:20:54 +08:00
Stella Laurenzo 00efec0b73
[linalg] Implement strict mode lowering for aten.view. (#3319)
* Enables assume_strict_symbolic_shapes on fx_importer imported
programs, indicating strict shape semantics.
* Reworks the view->reshape lowering to take advantage of strict mode
and do one of:
  * Collapse to 0D
  * Flatten/Unflatten when there is an inferred dim.
  * Fallback to tensor.reshape
* Splits some test cases up and adds an attribute to control the old
pattern (so new corners can be tested in strict mode in isolation).
* Dynamic inferred mode needs upstream work to generalize expand_shape
(so that case is suppressed here).
* Deletes the assert from the existing tensor.reshape lowering if strict
shape mode is enabled (since the condition it is dynamically asserting
cannot happen).
2024-05-10 13:45:50 -07:00
penguin_wwy 64b59c7fc3
[FxImporter] Eliminate the dependency on the refinement pass (#3309) 2024-05-10 02:44:36 +08:00
Xida Ren (Cedar) 33eef15e42
Support onnx.If (#2825)
This is probably a decent PR for learning about blocks and regions.

If you're here to learn about that, consider also looking at
lib/Conversion/TorchToSCF/TorchToSCF.cpp

While this doesn't include an e2e test, it is tested downstream in
https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/onnx/operators/If/model.py

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-04-30 18:36:40 +00:00
penguin_wwy b2185195e8
[NFC] Update black version (#3256)
* Update black version to support 3.11/3.12
* Reformat code
2024-04-29 11:06:01 +08:00
penguin_wwy 9f64748f97
[FxImporter] Synchronize the collection of symbolic torch ops (#3236) 2024-04-29 10:09:00 +08:00
Stella Laurenzo 6877302504
[NFC reformat] Applies pre-commit formatting to Python files. (#3244)
This is a large change because prior to this point, Python files in the
project were not consistently formatted. This reformats them all with
black defaults.

Based on experience with prior projects, if you have a dev/long-term
branch with Python patches, you can minimize merge conflicts prior to
rebasing to include this commit by running `black` on your modified
Python files, squashing, and then rebasing/merging.
2024-04-27 14:16:31 -07:00
penguin_wwy 3aa81f78d8
[FxImporter] Replace local_scalar_dense in fx_importer (#3180) 2024-04-17 22:45:47 +08:00
penguin_wwy 398aeeec87
[FxImporter] Fix kwarg operands in fx importer (#3166)
Remove the `kwarg_only` limitation, for example
```
torch.add(x, 3.0, alpha=2)
```
compiled to
```
%0 = torch.aten.add.Scalar %arg0, %float3.000000e00, %int1
```
fix to
```
%0 = torch.aten.add.Scalar %arg0, %float3.000000e00, %int2
```
2024-04-16 13:17:05 -07:00
penguin_wwy af5509c5d9
[FxImporter] Type conversion to resolve the mismatch between Py type and schema type (#3163) 2024-04-15 23:14:19 -07:00
Stella Laurenzo ffaaf08c31
[fx] Fix type inference for scalar/int types. (#3099)
This was discovered in a downstream test suite and was due to a control
flow nesting merge issue. In-tree test added and fixed.
2024-04-02 13:56:43 -07:00
penguin_wwy 5325d3e6e6
[fx] Fix type hint for fx importer (#3066)
Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
2024-04-01 17:31:43 -07:00
Rob Suderman ec4cb8be44
Bump LLVM to llvm/llvm-project@0030fc4ac7 (#3079)
Co-authored-by: Peiming Liu <peiming@google.com>
2024-04-01 16:34:59 -07:00
Stella Laurenzo 826786bdd0
[fx] Support ExportedProgram buffer mutation. (#3080)
In the prior state when I supported mutation of user inputs by treating
them as mutable-tensor SSA values, I had left the case of buffer
mutation only vaguely implemented until a concrete use emerged.
    
This patch reworks this buffer mutation support by assuming that buffers
must be resolved via the hooks symbolically and treated with load/store
semantics. This is implied in the structure since we have no SSA value
that represents a buffer and we already assume that reading parameters
happens via such a mechanism.
2024-04-01 14:18:12 -07:00
Stella Laurenzo 282e9b0e64
[fx] Fix type determination for multi-return ops and static `None` returns. (#3081)
In practice, this was caught by the way that AOT autograd traces
`convolution_backward`. For the unit test, we just repro it with a
custom op.
2024-04-01 09:39:38 -07:00
Stella Laurenzo e2343cf4ce
[fx] Implement auto_functionalized higher order op. (#3063)
* Also adds the basic scaffolding for handling more of these, which will
be needed for cond, while, etc.
* Refactors some of the support in the generic OpOverload emitter so it
can be shared with these other special forms.

This has been on my list for a while, but it just so happens that as
part of upgrading to PyTorch 2.3 and a pure upstream flow in Turbine, we
were using a feature that required integration with auto_functionalized.
This is perhaps the "weirdest" of the higher-order ops and a poor place
to start, but needs must. We have testing for this in Turbine.

Full support in Turbine has an entire custom ops facility. I've reduced
this down to a unit test in torch-mlir.
2024-03-26 17:06:05 -07:00
Stella Laurenzo 17eeac880a
[fx] Accept `func_visibility=` and return created func op. (#3054)
This is a partial landing of #3046 while waiting for an upstream change
for the rest of it.
2024-03-25 16:48:06 -07:00
Stella Laurenzo 6ea857c644
[fx] Make the lift_fresh_copy -> clone special form use kwargs. (#3045)
At some point, this op became kwarg-only instead of arg/kwarg.
Discovered when upgrading to PyTorch 2.3.

Also adds a test as this was untested in-tree (was caught out of tree).
2024-03-21 15:34:40 -07:00
Aart Bik fe59f1ee0d
[torch-mlir][sparse] higher dimension COO (#3042)
Lift this from 2-dim only to n-dim for n>=2
2024-03-19 15:59:07 -07:00
penguin_wwy f34c187ac4
Normalize type hints to be compatible with multiple Python versions (#3028)
Although we provide a wheel package for Python 3.8, it may actually
throw the following exception:
`TypeError: 'type' object is not subscriptable`
2024-03-15 08:29:48 -07:00
Daniel Garvey 80c7bc3f7a
fximporter: support newer torch versions (#2999)
uses version checking since attributes exist in both versions, the only
thing that changes is what we're receiving as an fx graph
2024-03-08 14:58:50 -06:00
Yuanqiang Liu 4d01b0f1a3
[FxImporter] remove dataclass slots to support python3.9 (#2974)
* that `dataclass`'s `slots` is supported after python 3.10.
2024-03-06 01:04:38 +08:00
Scott Todd e7d90a4b82
[onnx] Fix type on create_module() in onnx_importer.py. (#2968)
The type returned was changed in
https://github.com/llvm/torch-mlir/pull/2795. This led to errors in the
downstream IREE project: https://github.com/openxla/iree/pull/16622.
2024-02-29 13:01:13 -08:00
Peiming Liu e85a2a87c5
[torch-mlir][sparse] support e2e sparse kernels with COO inputs. (#2939) 2024-02-28 16:08:37 -08:00
Rob Suderman e48fe45886
[onnx] Import `onnx` import to pass remaining tests (#2951)
Finish supporting importing the vast majority of `onnx` operations. This
includes:
- region support
- region value inherentance
- `torch.string` support
- `torch.list` support
- `torch.optional` support
2024-02-28 12:18:02 -08:00
Sambhav Jain 3cbe6c98ec
Expose `func_name` to the main fx import API (#2949)
As titled.
2024-02-26 10:08:14 -08:00
Stella Laurenzo 89e02c195b
Make a typing dependency that is not in older PyTorch backwards compatible. (#2948)
This was found in a downstream that is pegged to an older PyTorch
version.
2024-02-23 15:52:27 -08:00
Aart Bik 4147b280ce
[torch-mlir][sparse] add block sparsity to mlir lowering (#2942)
Also note that we are in the process of proposing SparseTensorMetadata
to PyTorch FX graph export (see
https://github.com/pytorch/pytorch/pull/117907). This will hopefully
eventually replace the current data structures in torch-mlir.
2024-02-23 11:57:20 -08:00
Rob Suderman 53f6d06ab8
[onnx] Drop `ConstantOfShape` logic form importer, fix torch lowering (#2930)
There is no reason to treat `ConstantOfShape` as a specialized import
any as there exists a onnx-to-torch equivalent. Dropping the import
coding and adding support for resource conversion substantially
increases test coverage for dynamically shaped tests.
2024-02-21 21:34:43 -08:00
Rob Suderman 13553d49c9
[onnx] Update the importer to create a `none` for missing operands (#2931)
Some operands are optional so we require a placeholder for missing
operands. We invent an `onnx.None` operation as our placeholder.
2024-02-20 09:30:30 -08:00
Stella Laurenzo 5253282c55
[fx] Support mutation in ExportedProgram. (#2916)
As of https://github.com/pytorch/pytorch/pull/118969, `ExportedProgram`
has the long awaited fixes to correctly categorize various things
relating to parameters, buffers, mutated inputs and constants.

With this additional modeling, we are finally able to implement
(safely/soundly) the mutable semantics that were attempted on the
TorchScript path. The difference is that on that path, we had to
conservatively treat everything as mutable and run some dodgy heuristics
(which have been the cause of many bugs relating to
"MaximizeValueSemantics") to try to get back to an immutable state.

The new model supports mutability at the graph edges, allowing both user
inputs and buffers to be mutated (there is some more support than that,
but that is all I fully tracked through to implementation).

Therefore, when we receive programs like this, we now can selectively
enable mutation at the edges. This happens to be the mutability model
that IREE supports, which I expect to be a primary beneficiary. However,
there is nothing stopping anyone else from handling the `!torch.tensor`
types and the existing copy/overwrite ops that will be selectively
added.

Since this relies on API changes that will not release until 2.3, I'm
being a bit cautious about not refactoring existing facilities.
2024-02-16 09:46:30 -08:00
Rob Suderman 074f112d6a
[onnx] Add testing using the `onnx` compilation using torch tests (#2795)
We can route the torch tests via `onnx` using the `torch.onnx.export`
tooling. We can then reimport, lower to torch, and compile to linalg to
validate the onnx path is working correctly.

The current implementation exposes some failures in the `onnx` path so
we cannot enable the onnx test suite yet due to segmentation faults.
2024-02-15 10:17:13 -08:00
saienduri 8e2e5eeae9
add support for decomposition (#2879)
This commit adds decomposition support into the core aten operators
before importing the module from torch.

Also, this commit deals with the lifted tensor constants in
torch.export.export(). We don't want to add unnecessary placeholder
nodes in the graph (extra args in the block module), and should treat
them like the constants that they are. The unnecessary clone is also
removed for max efficiency.
2024-02-14 21:00:52 -08:00
Daniel Garvey 77b7550997
Add support for bfloat16 in fximporter (#2896)
this introduces an additional soft dependency on the python ml_dtypes
python packages in order to support bfloat16

Addresses #2843
2024-02-14 16:24:25 -06:00
Sambhav Jain 3e836d8dad
[fx_importer] Convert non-persistent buffers lifted as tensor constants (#2902)
The investigation is largely recorded in
https://github.com/llvm/torch-mlir/pull/2881, but this change allows us
to capture non-persistent buffers that were lifted as tensor constants
(after https://github.com/pytorch/pytorch/pull/118969 landed in upstream
PyTorch), and propagate them to `Torch` dialect as "frozen"
`torch.vtensor.literal`. I believe this patch should work with both
nightly and stable PyTorch, but will let CI confirm the same. Thanks
@stellaraccident for the valuable pointers and guidance.

---------

Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-02-13 12:38:32 -08:00
Aart Bik b6f4ca512e
[torch-mlir][sparse] sparsity metadata refinement (#2901)
Various improvements on sparsity metadata:

(1) define single data structure for all sparsity related metadata 
(2) handle batched dense dimensions, as well as dense subtensor
dimensions
(3) refine sparsity propagation for deeper networks
2024-02-12 16:10:57 -08:00
Aart Bik be8375d350
[torch-mlir][sparse] implement first sparse_jit end-to-end path (#2894)
This PR introduces a sparse_jit wrapper that can run simple models with
sparse tensor inputs end-to-end. The implementation shows all required
components on modifying sparse tensor types with a 1:N relation on the
call sites. Two tests shows that the JIT runs end-to-end while computing
the correct results.

More details to follow (generalizing to COO and different ranks, as well
as support for *output* sparse tensors), but the general concepts are
all here now.

**_Update: Thanks to Rob, bump to proper LLVM/MLIR hash is done!_**

_**NOTE that all parameter passing changes are nicely done "downstream"
in MLIR, so very little changes are required in torch-mlir code
proper**_

---------

Co-authored-by: Franz Haniel <77495327+frafranz@users.noreply.github.com>
Co-authored-by: Franz Haniel <franz.haniel@amd.com>
2024-02-12 10:04:54 -08:00
Daniel Garvey faf7d4aaa5
[fx_importer] Add support for 0D tensors (#2870)
Adds an escape hatch from creating a DenseResourceElementsAttr for
single value tensors into DenseElementsAttr.

For 0d or 1element, splats are better as DenseElementsAttr. Don't use
DenseResourceElementsAttr for it
2024-02-06 00:19:31 -06:00
Rob Suderman 54e258792c
[onnx] Import `onnx` constants as `onnx.Constant` instead of literals (#2831)
To handle the conversion from raw bytes to `DenseElementsAttr` we need
to handle the endianness conversion during `torch-onnx-to-torch`.
Therefore when importing `onnx.Constant` it is better to represent using
the `onnx` constant operation so that only one location requires the
endianness correction.
2024-01-31 11:41:06 -08:00
Aart Bik 105aad6f57
[torch-mlir] provide FX traced graph importer for sparse tensors (#2817)
Note that we are waiting for actual FX traced graph support for sparse
tensors. For details see

https://github.com/pytorch/pytorch/issues/117188

Until then, however, we provide this clever importer that builds the FX
traced graph for for the dense case and then puts a sparse annotation
back on the parameters.

With import test.
2024-01-30 21:22:12 -08:00