Commit Graph

300 Commits (51902ec2dc6df99a87e0fee092e59b492ab04837)

Author SHA1 Message Date
Andrea 🦈 51902ec2dc
Create MLIR functions for ONNX operators that are functions (#3409)
Resolves #3384.

Many ONNX operators are defined by functions and therefore could be
expanded into simpler ONNX operations during importing, avoiding the
need for tools downstream to support these operators directly.

This commit adds this capability to onnx_importer.py. When importing a
node, the schema for the node's operator is retrieved. If the schema
provides a function for the operator, a specialized version for the
node's types and attributes will be created and imported as an MLIR
function with private visibility. An MLIR function call will then be
emitted, instead of a normal operator node. Caching is used to avoid
generating redundant functions within the same module.

In order to avoid a disruptive change to the importer output for a
large number of operators that already have TorchOnnxToTorch support,
an allowlist strategy is used by default. With this commit, only one
operator is allowlisted for expansion, MeanVarianceNormalization.
However, many other operators can be correctly expanded by the current
code, so hopefully the allowlist can be gradually extended. It is
possible to disable the allowlist in the configuration, in which case
all functions are expanded (useful for testing).

Tools downstream of the importer may now need to do inlining when
consuming the output of the importer, e.g.:

  cat imported.mlir | torch-mlir-opt --inline --convert-onnx-to-torch

Explanations for subtle code changes:

- Looking up the correct schema and function for an operator requires
  knowing the opset version. NodeImporter retrieves this from the
  opset imports on the ModelProto retained by the GraphInfo. Previously,
  the model_proto field on GraphInfo was None when importing a subgraph
  in import_regions, but this conflicts with the new need for opset
  version info. Since the apparent purpose of setting it to None was to
  control how GraphInfo generates its input map, a new flag is added to
  GraphInfo (is_subgraph) to control this behavior, so that the actual
  ModelProto can now be provided without breaking this. This also turned
  out to be useful for getting the Config via ModelInfo via GraphInfo.
- Some operators' functions are context-dependent, which means the
  function definition depends on the types of the inputs. Therefore node
  importing now needs to look up the types of a node's inputs, not just
  its outputs as was the case previously. Consequently the operand to
  find_type_proto_for_name() may now be a graph input or initializer in
  some cases, so it has to be updated.
2024-06-14 10:11:26 -07:00
Xinyu Yang 6f94c7b0aa
[Torch] Add support for Meshgrid (#3462) 2024-06-14 23:59:08 +08:00
Wu Yuan a02e14e971
[FxImporter] Add aten._scaled_dot_product_flash_attention_for_cpu to default decomposition table (#3456) 2024-06-14 10:52:09 +08:00
Phaneesh Barwaria 919b599ebe
onnx.MaxPool add atenMaxPool1d lowering support (#3452)
fixes #3422
2024-06-13 15:37:11 +05:30
Vinayak Dev 39d882f7c9
[torch] Add OnnxToTorch lowering for the Col2Im op (#3424)
Adds OnnxToTorch lowering for the `onnx.Col2Im` op.
2024-06-13 08:42:06 +00:00
Chi_Liu ae6f5e8251
[ONNX] Fix AveragePool attributes support (#3235)
Issues was found here https://github.com/nod-ai/SHARK-Turbine/issues/643
    - [ONNX] Fix padding attributes for onnx.AveragePool
    - [Linalg] Add countIncludePad false support for AtenAvgPool1/2dOp
    - [Linalg] Add an avg_pool2d countIncludePad False e2e tests
    - [Linalg] Fix conflict with AtenAvgPool3dOp
    - [Linalg] Fix e2e crash with AtenAvgPool1dOp
    - [Linalg] Add dynamic dim support for AtenAvgPool2dOp
    - [Linalg] Fix AvgPool2dDivisorOverrideModule crash
2024-06-12 12:16:43 -07:00
Sambhav Jain d0a818a03e
Representing Symbolic Shape Expressions in Torch Dialect (#3372)
Torch Dialect with symbolic shape expressions:
```ll
module {                                                                                                                                                                                                     
  func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {                                                                                   
    %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int                                                                                                                                    
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int                                                                                                                                   
    %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int                                                                                                                                    
    
    torch.bind_symbolic_shape %arg0, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    torch.bind_symbolic_shape %arg1, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    
    %3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                                  
    torch.bind_symbolic_shape %3, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                               
    torch.bind_symbolic_shape %4, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>                                               
    %int1 = torch.constant.int 1                                                                                                                                                                             
    %6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>                                                                                                          
    torch.bind_symbolic_shape %6, [%0, %1, %2], #affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>                                                                            
    
    return %6 : !torch.vtensor<[?,?,3],f32>                                                                                                                                                                  
  }                                                                                                                                                                                                          
}              
```

For reference, this is the TorchDynamo exported program with symbolic
shape expressions that the above Torch dialect program is imported from:
```py
ExportedProgram:                                                                                                                                                                                             
    class GraphModule(torch.nn.Module):                                                                                                                                                                      
        def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"):                                                                                                                                         
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x)                                        
            tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x);  x = None                                                                                                                               
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y)                                     
            sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y);  y = None                                                                                                                         
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1)                       
            cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1);  tanh = sigmoid = None                                                                                      
            return (cat,)                                                                                                                                                                                    
                                                                                                                                                                                                             
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])                                               
Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)} 
```

Huge credit to @stellaraccident for the inputs that helped evaluate the
various design options and arrive at the representation of choice.


- [x] Op definitions for symbolic_int and bind_symbolic_shape ops
- [x] fx_importer updates to import range constraints + create
symbolic_int ops
- [x] fx_importer changes for AffineMapAttr building + adding
bind_symbolic_shape ops
- [x] custom printer/parser for inlined AffineMap expressions in mlir
assembly
- [x] Dialect lit test
- [x] fx_importer python lit tests
- [ ] Cleanup pass to remove these ops (can add in a follow-on)
2024-06-07 04:04:03 -07:00
Xinyu Yang 431d98b405
[Stablehlo] Add lowering of GridSampler Op (#3084)
Inspired by PyTorch decompositions.py.
See
ec58f1f74e/torch/_decomp/decompositions.py (L3923-L4086)
Only support paddingMode=0 or 1 and interpolationMode=0 or 1
2024-06-07 16:06:07 +08:00
Vivek Khandelwal 72837fbb3d
build: manually update PyTorch version (#3340)
Set PyTorch and TorchVision version to nightly release 2024-05-14.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-06 22:23:40 +05:30
penguin_wwy d59d0b6e5a
[Linalg] Promote type for compare tensor op (#3416) 2024-06-04 16:05:39 -07:00
Vivek Khandelwal 661be2d5b0
[MLIR][Torch] Add TorchToLinalg lowering for AtenAvgPool3dOp (#3030)
This commit also fixes the average pool op' test failing for
OnnxToLinalg lowering.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-04 22:12:34 +05:30
Vivek Khandelwal 35dd8c52cd
[ONNX] Add OnnxToTorch Lowering for MaxUnpool op (#3413)
This commit also adds the Torch declaration for aten.max_unpool2d and
aten.max_unpool3d op. The TorchToLinalg lowering for the same will be
added in a follow-up commit.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-04 21:09:53 +05:30
Yuanqiang Liu 50f7103098
[Stablehlo] support uint8 (#3367)
Support lowering unsigned integer type to stablehlo as discussed in
https://github.com/llvm/torch-mlir/pull/2184.

The things I do in this PR:
1. create `setupBackendTypeConversionForStablehlo()`,
`createFuncBackendTypeConversionForStablehloPass` and
`createFinalizingBackendTypeConversionForStablehloPass`.
2. remove `InferTypeOpInterface` from `torch_c.to_builtin_tensor`,
because it's different result type between linalg backend and stablehlo
backend:
```
// linalg backend
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
    %c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xi8>
    %0 = tensor.empty() : tensor<3xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<3xi8>) outs(%0 : tensor<3xf32>) {
    ^bb0(%in: i8, %out: f32):
      %2 = arith.uitofp %in : i8 to f32
      linalg.yield %2 : f32
    } -> tensor<3xf32>
    return %1 : tensor<3xf32>
}
// stablehlo backend
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
    %c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xui8>
    %0 = stablehlo.convert %arg0 : (tensor<3xui8> -> tensor<3xf32>
    return %0 : tensor<3xf32>
}
```
3. fix stablehlo and linalg's conversion
2024-06-04 09:04:59 +08:00
zjgarvey 8995c90879
[TorchToLinalg] add support for quantized group conv (#3341)
This addresses 7 of the model failures I'm seeing in the test suite. See
[Shark-Turbine issue
#566](https://github.com/nod-ai/SHARK-Turbine/issues/566).

Need the op ```linalg.conv_2d_ngchw_gfchw_q``` to be added upstream
before merging this. See [llvm-project PR #92136
](https://github.com/llvm/llvm-project/pull/92136).

A small additional expansion to operand quantization is included in this
patch to address a model failure that occurs when unblocking the
quantized group convolutions in one of these onnx models.
2024-06-03 21:57:44 +05:30
Xinyu Yang 285b087a5d
[Torch] Emit rrelu and decompose it (#3250)
as title
2024-06-03 19:25:52 +08:00
Xinyu Yang 267052df2a
[Torch] decompose AtenLerpTensorOp (#3251)
as title
2024-06-03 15:25:09 +08:00
Xinyu Yang 23b53050de
[Torch]Support conv_transpose1d and conv_transpose3d (#3286)
1. Support conv_transpose1d and conv_transpose3d
2. Fix bugs of convertTransposedConv func in
lib/Conversion/TorchToStablehlo/Linear.cpp
2024-06-03 15:11:12 +08:00
zjgarvey 878ba72c65
Bump LLVM to llvm/llvm-project@6127f15 (#3396)
Signed-off-by: zjgarvey <zjgarvey@gmail.com>
2024-05-31 17:49:20 +01:00
Yuanqiang Liu 4e05e2cd1e
[Torch] support recompose of aten.split.with_sizes and aten.tensor_sp… (#3401)
…lit.sections

* support recompose to aten.split.with_sizes and
aten.tensor_split.sections
* fix recompose of aten.chunk
2024-05-31 09:56:47 +08:00
zjgarvey 074098d20c
Modifies onnx resize lowering to fix numerical issues (#3381)
Updates:

- some unsupported modes are now going to report a match failure for
unsupported coordinate transformation modes.
- fixes a bug that was introduced in the last patch for resize (my
bad...)
- uses actual x and y coordinates for computing weights in bilinear
interpolation (rather than eps modified values)
- slightly simplifies the bilinear interpolation payload for readability
and performance
- passes coordinate transformation mode information from an onnx.Resize
op to the mode string for the aten._interpolate op. This allows us to
perform custom logic in the torch->linalg lowering to support
onnx.Resize options without losing the default behaviors of the
interpolate op.
2024-05-30 20:34:37 -04:00
penguin_wwy e4be197efd
[FxImporter] Fix transpose rank zero (#3382) 2024-05-30 14:31:18 +08:00
penguin_wwy a5d3b546f8
[FxImporter] Fix embedding bag (#3387) 2024-05-29 14:46:21 +08:00
Yuanqiang Liu e0a5adb1db
[Torch] fix aten.linear's decomposition (#3391)
* support aten.linear with more rank.
2024-05-27 15:49:50 +08:00
Yuanqiang Liu 05929f9171
enhance verbose option in e2e_testing (#3390)
so that `python3 e2e_testing/main.py -v` would print intermediate IR.
2024-05-27 08:01:07 +08:00
Yuanqiang Liu 28aeb047c1
[Stablehlo] fix crashing on AtenEmbeddingBagSumExample_basic (#3389) 2024-05-26 12:34:56 +08:00
Yuanqiang Liu 5bb1a65ec9
[Stablehlo] refactor reduction lowering and support aten.amin (#3383)
* implement detailed lowering template pattern
`ConvertAtenReduceAllDimsOp` and `ConvertAtenReduceKeepDimOp`
* support `aten.amin`'s lowering.
2024-05-23 20:40:20 +08:00
penguin_wwy d924d0047f
[FxImporter] Fix primitive type in return (#3379) 2024-05-23 09:55:33 +08:00
Yuanqiang Liu f4bfe3f948
Bump llvm and stablehlo (#3377)
* bump llvm to 1e5f29af81a5f6fda308074f6345b9fba4faa71c
* bump stablehlo to c44d9af8d4879adccf1054cb61a53377ae5898cb
2024-05-22 23:28:45 +08:00
penguin_wwy 972d47b586
[FxImporter] Fix constant bool tensor (#3375) 2024-05-22 22:59:01 +08:00
penguin_wwy c2c1c2cfa4
[FxImporter] Fix failed e2e case (#3365) 2024-05-22 00:20:54 +08:00
Vivek Khandelwal b870729efe
[torch] Fix `onnx.MaxPool` lowering (#3133)
This commit fixes the onnx.MaxPool op lowering which was lacking the
indices result support.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-05-21 21:05:32 +05:30
Wu Yuan cc28d566ff
[Stablehlo] Support AtenTrilOp (#3359)
1. lower aten.tril to stablehlo composed by iota, select and so forth
2. add related e2e test cases
2024-05-20 15:49:24 +08:00
Yuanqiang Liu 8814d0ae64
[Torch] emit aten.dot and canonicalize it to aten.matmul (#3361)
* canonicalize `aten.dot` to `aten.matmul`
2024-05-18 22:45:14 +08:00
zjgarvey 6cba93b16e
[ONNX][TorchToLinalg] Add support for dynamic dims in Interpolate lowering (#3351)
Addresses [Shark-Turbine
#196](https://github.com/nod-ai/SHARK-TestSuite/issues/196)

Related tracker [Shark-Turbine
#566](https://github.com/nod-ai/SHARK-Turbine/issues/566)

Related onnx.Resize issues [Shark-Turbine
#616](https://github.com/nod-ai/SHARK-Turbine/issues/616)
2024-05-17 12:18:57 -07:00
Suraj Sudhir cba91a9b96
[ONNX][TOSA] Adds ONNX to TOSA e2e tests (#3358)
- Refactors OnnxBackend to be generic and consume any Torch backend.

---------

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
2024-05-16 21:44:26 -07:00
Xinyu Yang 7faba75696
[Torch] Decompose AtenMaskedScatterOp (#3353)
Co-authored-by: Yuanqiang Liu <liuyuanqiang.yqliu@bytedance.com>
2024-05-16 15:27:25 +08:00
Suraj Sudhir 0ca88028cd
[FxImporter][TOSA] Enable FxImporter to TOSA e2e tests (#3349)
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
2024-05-15 14:37:30 -07:00
Peiming Liu ccb772cd0f
[sparse] propagate sparsity properly when decompose torch operations. (#3318) 2024-05-15 10:09:27 -07:00
Yuanqiang Liu 5928f68e60
[Stablehlo] refactor amax, max, max.dim's lowering to stablehlo (#3348)
* not to decompose `aten.amax` on `stablehlo` backend. Because it could
be lowering to `stablehlo.reduce` directly.
* lowering `aten.max.dim` to `stablehlo.reduce apply max` when
`AtenMaxDimOp.getIndices()` doesn't have users. It's more simple.
2024-05-16 00:05:19 +08:00
penguin_wwy 20d4d16d32
[FxImporter] Add an e2e test example for FxImporter (#3331) 2024-05-14 00:45:19 +08:00
Andreas Falkenberg adafd51823
[onnx] Gridsampler addition of nearest mode (#3320)
Added nearest neighbor selection for onnx.Gridsampler
2024-05-10 11:42:10 -07:00
NeverRaR 1d4859699b
MaxPool1d lowering to linalg (#3295)
Co-authored-by: root <root@i32b01216.sqa.eu95>
2024-05-10 22:05:26 +05:30
penguin_wwy be20db0a0e
[NFC] Delete the deprecated example cases (#3323) 2024-05-11 00:28:58 +08:00
Vivek Khandelwal 10db310460
build: manually update PyTorch version (#3291)
Set PyTorch and TorchVision version to nightly release 2024-05-05.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-05-10 21:45:06 +05:30
penguin_wwy afe87d62b4
[Linalg] [Stablehlo] Promote type for compare scalar op (#3306) 2024-05-10 02:20:06 +08:00
Peiming Liu cff144b3ac
[sparse] fix double free due to incompatibility between buffer-deallo… (#3303)
…cation and sparse tensors.

**NOTE**: This PR _doges_ the issue in buffer-deallocation pass instead
of resolving it. In the future, we need to fix the bug in
buffer-deallocation pass when handling code generated by sparse
compiler.
2024-05-08 21:18:17 -07:00
aldesilv ec6d7aa5d2
OnnxToTorch lowering resize op (#3013)
https://github.com/nod-ai/SHARK-Turbine/issues/358
adds a lowering from onnx to linalg for bilinear and nearest resize with
support for using scales or sizes to get resize shape. uses coordinate
transform half pixel for bilinear mode and asymmetrical for nearest
mode. See
https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize. Added
two passes -- one for bilinear and the other for nearest.
2024-05-08 21:35:03 +00:00
zjgarvey 0abc5868b5
[ONNX] Enables data propogation for onnx shape inference (#3280)
This small change seems to dramatically improve shape inference for
complex models, and consequently, improves onnx importer reliability.
2024-05-08 09:29:23 -07:00
Jiawei Wu 346a536c9f
[Torch Dialect] decompose all index_put-like op to aten.index_put.hacked_twin for stricter semantics (#3071)
This PR decomposes all index_put-like op to aten.index_put.hacked_twin for stricter semantics, i.e., no None index in indices argument.
2024-05-08 22:44:57 +08:00
Xinyu Yang abef114c0c
[torch] emit aten.Softshrink and aten.Hardshrink (#3248)
as title
2024-05-08 15:20:45 +08:00