Commit Graph

1563 Commits (4555629246ab56a5588dbd1f0676830484c65a40)

Author SHA1 Message Date
ptrifunovic98 4555629246
Implement lowering of torch.aten.kthvalue (#3360)
Closes
[nod-ai/SHARK-Turbine#620](https://github.com/nod-ai/SHARK-Turbine/issues/620)
2024-06-15 11:18:39 +05:30
Manupa Karunaratne d2b663ece7
Add onnx op LRN lowering (#3432)
This commit adds support for lowering
Onnx LRN op to aten.
2024-06-14 16:44:43 +00:00
Arham Khan 09c988046c
[ONNX] Add OnnxToTorch lowering for Onnx.NegativeLogLikelihoodLoss Op (#3380)
This implements the Onnx.NegativeLogLikelihoodLoss op using the
signature provided
[here](https://onnx.ai/onnx/operators/onnx__NegativeLogLikelihoodLoss.html)
by replacing it with a `NLLLossForward` op.

Additionally, I included a helper function `get_loss_reduction_enum` to
convert from a string `reduction` parameter to the corresponding
intended integer value since this is an operation that will be reused
for any loss function module. This differs from `get_reduction_enum` in
`TorchUpstream.cpp` which handles the `reduce` parameter from
`scatter_reduce` type operations.
2024-06-14 22:01:11 +05:30
Vivek Khandelwal 2ea2bc3948
[ONNX] Add OnnxToTorch Lowering for GroupNormalization op (#3458)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-14 16:18:53 +00:00
Umang Yadav 04c6479350
[ONNX] Add onnx parser for LpPool operator (#3449)
Similar to https://github.com/llvm/torch-mlir/pull/3435

Solves https://github.com/nod-ai/SHARK-Turbine/issues/728
2024-06-14 21:41:18 +05:30
Xinyu Yang 6f94c7b0aa
[Torch] Add support for Meshgrid (#3462) 2024-06-14 23:59:08 +08:00
Phaneesh Barwaria 919b599ebe
onnx.MaxPool add atenMaxPool1d lowering support (#3452)
fixes #3422
2024-06-13 15:37:11 +05:30
Vinayak Dev 39d882f7c9
[torch] Add OnnxToTorch lowering for the Col2Im op (#3424)
Adds OnnxToTorch lowering for the `onnx.Col2Im` op.
2024-06-13 08:42:06 +00:00
Surya Jasper de7f058a0e
[MLIR][ONNX] Add OnnxToTorch support for MaxRoiPool Op (#3395)
This PR adds OnnxToTorch support for MaxRoiPool op
2024-06-13 10:46:14 +05:30
Umang Yadav 9b76a2e3eb
[ONNX] add onnx lowering for global lp pool operator (#3435)
Solves https://github.com/nod-ai/SHARK-Turbine/issues/727

Uses AvgPool to implement GlobalLpPool similar to this
https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_lp_pool.py

cc: @vivekkhandelwal1
2024-06-13 10:37:08 +05:30
Lei Zhang 77d7f64472
Update to llvm/llvm-proect@27ac46e6be (2024-6-12) (#3454)
This would require to bump stablehlo at the same time.
2024-06-12 19:34:01 -07:00
Chi_Liu ae6f5e8251
[ONNX] Fix AveragePool attributes support (#3235)
Issues was found here https://github.com/nod-ai/SHARK-Turbine/issues/643
    - [ONNX] Fix padding attributes for onnx.AveragePool
    - [Linalg] Add countIncludePad false support for AtenAvgPool1/2dOp
    - [Linalg] Add an avg_pool2d countIncludePad False e2e tests
    - [Linalg] Fix conflict with AtenAvgPool3dOp
    - [Linalg] Fix e2e crash with AtenAvgPool1dOp
    - [Linalg] Add dynamic dim support for AtenAvgPool2dOp
    - [Linalg] Fix AvgPool2dDivisorOverrideModule crash
2024-06-12 12:16:43 -07:00
Suraj Sudhir 41d04a8995
[onnx] Resize supports default-valued attributes (#3450)
Handles onnx exporters emitting default-valued attributes.

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
2024-06-12 09:23:42 -07:00
zjgarvey de28c8540b
[ONNX] add int16 quantization support (#3446)
There is currently no int16 quantization support in torch. This patch
adds a new mlir type to correspond to the missing "torch.qint16" type,
and enables lowering of quantization-related onnx ops using int16 types.

In follow-up patches, custom quantization logic for ops like
aten.matmul/aten.mm/aten.convolution may need to be revisited to allow
support for qint16. The passes in FuseQuantizedOps.cpp may also need
slight modifications.
2024-06-12 10:37:22 +05:30
zjgarvey 7cd3368b20
[ONNX] Fix resize ceil numerics and add half_pixel_symmetric support (#3443)
This patch fixes several failing tests in our [external test
suite](https://github.com/nod-ai/SHARK-TestSuite/tree/main/iree_tests/onnx/node/generated),
and addresses some of the issues discussed in #3420
2024-06-11 22:35:50 -05:00
Matthias Gehre e07a0bfc54
onnx.resize: Add support for coordTfMode "half_pixel" (#3441)
half_pixel is also the default mode used by ONNX, see
https://onnx.ai/onnx/operators/onnx__Resize.html
2024-06-10 20:59:29 +02:00
Aart Bik d77bab37d1
[torch-mlir][sparse] re-enable all sparse tests (#3444)
this fixes the following issue:

https://github.com/llvm/torch-mlir/issues/3418
2024-06-10 11:19:32 -07:00
Vivek Khandelwal 5bc626465b
[ONNX] Lower Onnx.Concat lowering version (#3437)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-09 12:07:20 +05:30
Vivek Khandelwal d35b6b412a
[ONNX] Add OnnxToTorch Lowering for Sequence Ops (#3425)
This commit adds the lowering for SequenceAt, SequenceEmpty,
SequenceInsert, SequenceErase op

Signed-Off By: Vivek Khandelwal<vivekkhandelwal1424@gmail.com>
2024-06-08 09:58:11 +05:30
Yuanqiang Liu 689efc8917
[Torch] fix toBuiltinTensor() (#3415)
* Let `toBuiltinTensor()` reflects the original dtype of
`!torch.vtensor`.
* Backend handles dtype conversion themselves.
2024-06-08 09:36:32 +08:00
Rob Suderman 75af64fc12
[torch] Add support for f8 types for linalg conversion (#3436)
Linalg conversion requires mapping for f8 types
2024-06-07 13:59:38 -07:00
aldesilv f794582b18
add resize nearest mode round_prefer_floor, round_prefer_ceil, ceil (#3421) 2024-06-07 14:04:11 -05:00
Vivek Khandelwal 1a9c0a35a9
[Onnx] Add Onnx->Torch lowering for Onnx.Shrink Op (#3385)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-07 22:47:27 +05:30
Suraj Sudhir 1c2778dd56
[ONNX] Conv op adds support for asymmetric padding. (#3426)
Supports asymmetric padding by performing a torch.nn.functional.pad on
the input before performing the convolution.

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
2024-06-07 09:54:39 -07:00
Sambhav Jain d0a818a03e
Representing Symbolic Shape Expressions in Torch Dialect (#3372)
Torch Dialect with symbolic shape expressions:
```ll
module {                                                                                                                                                                                                     
  func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {                                                                                   
    %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int                                                                                                                                    
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int                                                                                                                                   
    %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int                                                                                                                                    
    
    torch.bind_symbolic_shape %arg0, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    torch.bind_symbolic_shape %arg1, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    
    %3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                                  
    torch.bind_symbolic_shape %3, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                               
    torch.bind_symbolic_shape %4, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>                                               
    %int1 = torch.constant.int 1                                                                                                                                                                             
    %6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>                                                                                                          
    torch.bind_symbolic_shape %6, [%0, %1, %2], #affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>                                                                            
    
    return %6 : !torch.vtensor<[?,?,3],f32>                                                                                                                                                                  
  }                                                                                                                                                                                                          
}              
```

For reference, this is the TorchDynamo exported program with symbolic
shape expressions that the above Torch dialect program is imported from:
```py
ExportedProgram:                                                                                                                                                                                             
    class GraphModule(torch.nn.Module):                                                                                                                                                                      
        def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"):                                                                                                                                         
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x)                                        
            tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x);  x = None                                                                                                                               
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y)                                     
            sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y);  y = None                                                                                                                         
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1)                       
            cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1);  tanh = sigmoid = None                                                                                      
            return (cat,)                                                                                                                                                                                    
                                                                                                                                                                                                             
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])                                               
Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)} 
```

Huge credit to @stellaraccident for the inputs that helped evaluate the
various design options and arrive at the representation of choice.


- [x] Op definitions for symbolic_int and bind_symbolic_shape ops
- [x] fx_importer updates to import range constraints + create
symbolic_int ops
- [x] fx_importer changes for AffineMapAttr building + adding
bind_symbolic_shape ops
- [x] custom printer/parser for inlined AffineMap expressions in mlir
assembly
- [x] Dialect lit test
- [x] fx_importer python lit tests
- [ ] Cleanup pass to remove these ops (can add in a follow-on)
2024-06-07 04:04:03 -07:00
Xinyu Yang 431d98b405
[Stablehlo] Add lowering of GridSampler Op (#3084)
Inspired by PyTorch decompositions.py.
See
ec58f1f74e/torch/_decomp/decompositions.py (L3923-L4086)
Only support paddingMode=0 or 1 and interpolationMode=0 or 1
2024-06-07 16:06:07 +08:00
Vivek Khandelwal 72837fbb3d
build: manually update PyTorch version (#3340)
Set PyTorch and TorchVision version to nightly release 2024-05-14.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-06 22:23:40 +05:30
penguin_wwy d59d0b6e5a
[Linalg] Promote type for compare tensor op (#3416) 2024-06-04 16:05:39 -07:00
Vivek Khandelwal 661be2d5b0
[MLIR][Torch] Add TorchToLinalg lowering for AtenAvgPool3dOp (#3030)
This commit also fixes the average pool op' test failing for
OnnxToLinalg lowering.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-04 22:12:34 +05:30
Vivek Khandelwal 35dd8c52cd
[ONNX] Add OnnxToTorch Lowering for MaxUnpool op (#3413)
This commit also adds the Torch declaration for aten.max_unpool2d and
aten.max_unpool3d op. The TorchToLinalg lowering for the same will be
added in a follow-up commit.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-04 21:09:53 +05:30
Yuanqiang Liu 50f7103098
[Stablehlo] support uint8 (#3367)
Support lowering unsigned integer type to stablehlo as discussed in
https://github.com/llvm/torch-mlir/pull/2184.

The things I do in this PR:
1. create `setupBackendTypeConversionForStablehlo()`,
`createFuncBackendTypeConversionForStablehloPass` and
`createFinalizingBackendTypeConversionForStablehloPass`.
2. remove `InferTypeOpInterface` from `torch_c.to_builtin_tensor`,
because it's different result type between linalg backend and stablehlo
backend:
```
// linalg backend
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
    %c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xi8>
    %0 = tensor.empty() : tensor<3xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<3xi8>) outs(%0 : tensor<3xf32>) {
    ^bb0(%in: i8, %out: f32):
      %2 = arith.uitofp %in : i8 to f32
      linalg.yield %2 : f32
    } -> tensor<3xf32>
    return %1 : tensor<3xf32>
}
// stablehlo backend
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
    %c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xui8>
    %0 = stablehlo.convert %arg0 : (tensor<3xui8> -> tensor<3xf32>
    return %0 : tensor<3xf32>
}
```
3. fix stablehlo and linalg's conversion
2024-06-04 09:04:59 +08:00
zjgarvey 56d21cba62
Link necessary op interface implementations (#3364)
This patch adds two `memref` passes to `torch-mlir-opt`, which already
occur in the pass pipeline
`torch-backend-to-linalg-on-tensors-backend-pipeline`. Additionally,
necessary op interface external models are included to address issue
#3352.
2024-06-03 19:43:28 -05:00
zjgarvey 8995c90879
[TorchToLinalg] add support for quantized group conv (#3341)
This addresses 7 of the model failures I'm seeing in the test suite. See
[Shark-Turbine issue
#566](https://github.com/nod-ai/SHARK-Turbine/issues/566).

Need the op ```linalg.conv_2d_ngchw_gfchw_q``` to be added upstream
before merging this. See [llvm-project PR #92136
](https://github.com/llvm/llvm-project/pull/92136).

A small additional expansion to operand quantization is included in this
patch to address a model failure that occurs when unblocking the
quantized group convolutions in one of these onnx models.
2024-06-03 21:57:44 +05:30
Vivek Khandelwal 6382dbbcc0
[ONNX] Add OnnxToTorch lowering for SpaceToDepth op (#3393)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-03 20:29:39 +05:30
Xinyu Yang 285b087a5d
[Torch] Emit rrelu and decompose it (#3250)
as title
2024-06-03 19:25:52 +08:00
Xinyu Yang 267052df2a
[Torch] decompose AtenLerpTensorOp (#3251)
as title
2024-06-03 15:25:09 +08:00
Xinyu Yang 23b53050de
[Torch]Support conv_transpose1d and conv_transpose3d (#3286)
1. Support conv_transpose1d and conv_transpose3d
2. Fix bugs of convertTransposedConv func in
lib/Conversion/TorchToStablehlo/Linear.cpp
2024-06-03 15:11:12 +08:00
Rob Suderman 617b00b983
[NFC] Fix member cast change to global for landing collision (#3407)
A PR landed when moving away from a deprecated cast function. Updated
the corresponding lines to pass.
2024-05-31 17:31:24 +00:00
zjgarvey 8952377603
[Onnx] reduce MatMul OpsetVersion to 1 (#3403)
Resolves #3324
2024-05-31 22:17:56 +05:30
Surya Jasper fc100a117d
[MLIR][ONNX] Add OnnxToTorch support for Scatter Op (#3400)
This PR adds OnnxToTorch support for Scatter op
2024-05-31 07:36:48 +00:00
Rob Suderman afca88a058
[NFC] Change to *cast instead of .*cast variants (#3405)
Member casts have been deprecated. Changing over a bunch of the member
cast calls to the global templated variants to remove deprecation
warnings.
2024-05-30 23:45:13 -07:00
Yuanqiang Liu 4e05e2cd1e
[Torch] support recompose of aten.split.with_sizes and aten.tensor_sp… (#3401)
…lit.sections

* support recompose to aten.split.with_sizes and
aten.tensor_split.sections
* fix recompose of aten.chunk
2024-05-31 09:56:47 +08:00
zjgarvey 074098d20c
Modifies onnx resize lowering to fix numerical issues (#3381)
Updates:

- some unsupported modes are now going to report a match failure for
unsupported coordinate transformation modes.
- fixes a bug that was introduced in the last patch for resize (my
bad...)
- uses actual x and y coordinates for computing weights in bilinear
interpolation (rather than eps modified values)
- slightly simplifies the bilinear interpolation payload for readability
and performance
- passes coordinate transformation mode information from an onnx.Resize
op to the mode string for the aten._interpolate op. This allows us to
perform custom logic in the torch->linalg lowering to support
onnx.Resize options without losing the default behaviors of the
interpolate op.
2024-05-30 20:34:37 -04:00
Vivek Khandelwal d7b8f00d01
[ONNX] Add OnnxToTorch Lowering for LpNormalization op (#3397)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-05-30 23:05:26 +05:30
penguin_wwy e4be197efd
[FxImporter] Fix transpose rank zero (#3382) 2024-05-30 14:31:18 +08:00
penguin_wwy 1f544c37d0
[NFC] Remove unused header files (#3386) 2024-05-30 14:30:36 +08:00
Xida Ren (Cedar) 23d2d66a59
Fix error when attempting to read elided onnx constants (#3398)
Co-authored-by: zjgarvey <zjgarvey@gmail.com>
2024-05-29 16:56:23 -07:00
Yuanqiang Liu e0a5adb1db
[Torch] fix aten.linear's decomposition (#3391)
* support aten.linear with more rank.
2024-05-27 15:49:50 +08:00
Yuanqiang Liu 28aeb047c1
[Stablehlo] fix crashing on AtenEmbeddingBagSumExample_basic (#3389) 2024-05-26 12:34:56 +08:00
zjgarvey 27169dcda9
Replace some depreciated uses of cast (#3343)
Contributing towards #3299
2024-05-23 09:01:47 -07:00