Commit Graph

1691 Commits (3180704b1470c047faca5fb64d285cda2a287818)

Author SHA1 Message Date
zjgarvey d0933b0eb6
[TorchToLinalg] Fix possible OOB access in Interpolate lowering (#3570)
Following up from the discussion in
<https://github.com/llvm/torch-mlir/pull/3550>, I've edited the lowering
to prevent OOB extracts in a more direct fashion (i.e., just clamping
directly).

I don't think this affects the lit tests at all, but I've tested the
changes in our external test suite at
<https://github.com/nod-ai/SHARK-TestSuite/tree/main/>. I found the
issue when I was unexpectedly getting `nan`'s along the output image
border for a resize test there.
2024-08-02 13:55:37 -05:00
zjgarvey 79ae0afc2f
[TorchToLinalg] Simplify QuantizePerTensor lowering (#3576)
Uses arith::MaximumFOp and arith::MinimumFOp instead of comparison and
select ops to improve readability of IR.
2024-08-02 13:40:52 -05:00
Rob Suderman f7b5c13870
Change linalg.matmul_unsigned to linalg.matmul with unsigned type_fn (#3587)
Change linalg.matmul_unsigned to linalg.matmul with unsigned type_fn

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
Co-authored-by: Max Dawkins <max.dawkins@gmail.com>
2024-08-02 11:32:24 -07:00
Rob Suderman d273bdfabf
[onnx] Fix default `alpha` for `onnx.Elu` (#3583)
We were defaulting to `0.0` for `onnx.Elu` when it is supposed to be
`1.0`.
2024-08-02 09:29:17 -07:00
Rob Suderman 3d33c5a206
[onnx] Fix `onnx.ScatterElements` for negative indices (#3582)
We need to adjust for negative scatter indice values. Added
materializing out the inbounds adjustment.
2024-08-02 09:01:10 -07:00
Rob Suderman 306ed62edd
[onnx][torch] Fix `onnx.SoftmaxCrossEntropyLoss` for ignore index (#3585)
There were two issues related to `ignore_index` being set

(1) the onnx-to-linalg pass as not reading the value correctly (2) the
mean pass was not considering the `ignore_index` value

For (2) when taking the mean we need to know how many of the values were
considered in the sum and therefore we cannot divide by the total number
of elements. Adding a summation across the total number should correct
this issue.
2024-08-02 09:00:56 -07:00
yyp0 22cd4441e7
[Torch] Add support for static uneven divisible AdaptiveAvgPool2d (#3566)
The static uneven divisible AdaptiveAvgPool2d means that although the
input size is not an integer multiple of ouput size, but the kernel and
stride size can also be fixed (not dynamic). The derivation logic of
kernel and stride size is consistent with
torch/_decomp/decomposations.py:adaptive_avg_pool2d as described in the
following:

1. Stride Size
Firstly , derive the start index in each reduce operation according to
the output size (`n`), `start_index = ([0, 1, ..., n - 1] * input_size)
// output_size`. For each index `k`, if `k * (input_size % output_size)
< output_size`, then the current and previous stride keeps the same as
`input_size // output_size`. So suppose `(n-1) * (input_size %
output_size) < output_size`, the stride in the whole AdaptiveAvgPool2d
process keeps static, as `input_size // output_size`.

2. Kernel Size
torch/_decomp/decomposations.py:adaptive_avg_pool2d calculates a static
kernel size when the input/output sizes satisfy either of the two
conditions, `input_size % output_size == 0` or `output_size %
(input_size % output_size) == 0`. Here if `input_size % output_size ==
0`, then the kernel size equals `input_size // output_size`, otherwise
`input_size // output_size + 1.`
2024-08-01 11:37:53 +08:00
Jiawei Wu edc87fc577
[stablehlo] support dynamic-shaped index in stablehlo conversion for aten.index-like ops (#3322)
For now, at most one dynamic dim of index tensors in
aten.index/aten.index_put-like op is supported.
2024-08-01 10:41:09 +08:00
Rob Suderman 7f475e174e
Add extf-trunc f32-f64-f32 ellision (#3579)
Torch has all scalars represented as i64 and f64 types which results in
extraneous trunc-extf commands. We can rework this by elliding
widen-narrow cases away.
2024-07-31 16:50:00 -07:00
Jiawei Wu 7b2902f6e2
[stablehlo]: fix aten.index_put_hacked_twin lowering to StableHlo (#3572)
Current StableHlo lowering strategy works well when `src` tensor's rank
is no bigger than `dst` tensor's. The new patch make it succeed in other
cases. The following is an example.
```
%190 = torch.prim.ListConstruct %arg4 : (!torch.vtensor<[1,1024],si64>) -> !torch.list<vtensor>
%191 = torch.aten.index_put.hacked_twin %189, %190, %186, %true : !torch.vtensor<[1024,768],f32>, !torch.list<vtensor>, !torch.vtensor<[1,1024,768],f32>, !torch.bool -> !torch.vtensor<[1024,768],f32>
```
2024-07-31 22:33:57 +08:00
yyp0 f49b9c14f1
[Torch] Add support for Aten__Or__BoolOp (#3574) 2024-07-31 17:23:53 +08:00
Suraj Sudhir d3efab984b
[TOSA] Fix Tensor.hacked_twin to support diff size indexes (#3547)
- Broadcasts index list tensors
- Adds torch.nn.Unfold test

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
2024-07-30 14:32:05 -07:00
Ivan Butygin 8bd1b9751f
`max_unpool3d` linalg lowering (#3536)
An attempt of  `aten.max_unpool3d` to linalg lowering.
There are known issues with this implementation (see comment in code).
2024-07-30 20:59:17 +03:00
zjgarvey f1c74e1431
[TorchToLinalg] add support for depthwise qconv (#3564)
- Adds support for lowering depthwise + quantized convolution ops to
linalg::DepthwiseConv2DNhwcHwcQOp
- Changed the variable name for groupSize (which is really C/G) to the
more appropriate numGroups (G).
- Discovered in e2e testing that linalg does not accept (Cin = groups &&
Cout = K*groups for K>1) as a "depthwise" conv, so this also updates the
case-checking to reflect this issue.
2024-07-29 12:25:07 -07:00
zjgarvey 50d6ce225f
Align Quantization Rounding Scheme with ONNX/Pytorch (#3569)
Pytorch and ONNX apparently round to nearest, ties go to nearest even,
but we were using `math::round` for the torch-to-linalg conversion of
`quantize_per_tensor`, which rounds away from zero on ties.
2024-07-29 12:24:46 -07:00
Vinayak Dev 30c4d2f2b8
[torch] Add OnnxToTorch lowering for Onnx.Unique op (#3523)
Adds OnnxToTorch Lowering for the `Onnx.Unique` op.
2024-07-29 17:32:44 +05:30
pdhirajkumarprasad a211ccbcff
Implementation of SplitToSequence ops lowering (#3509)
Added support for splitToSequence ops lowering
Added test case with filecheck
2024-07-29 15:44:22 +05:30
Vivek Khandelwal b6e4725259
[ONNX] Add OnnxToTorch lowering for NonMaxSuppression op (#3501)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-07-26 21:01:27 +05:30
yyp0 ea60d72489
[Torch] Add AtenMaskedFillTensorOp support (#3561) 2024-07-26 15:32:13 +08:00
Vivek Khandelwal 15cf7106c4
[ONNX] Reduce Onnx.Flatten op version (#3560)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-07-24 21:27:20 +05:30
Yuanqiang Liu 003b06dfa1
[Torch] enhance naryFolderHelper to support mixed dtypes (#3559)
* so that it could support like `i64 + f64 => f64`.
* also unify `aten.log`'s folder code to use `naryFolderHelper`.
2024-07-24 17:54:59 +08:00
Yuanqiang Liu aad1604046
[Torch] enhance fold of aten.squeeze.dim (#3558) 2024-07-24 14:13:48 +08:00
Ze Zhang d1e172f418
Register fake_quantize_cachemask ops and add their decompose patterns (#3556)
Test:

`cmake --build build --target check-torch-mlir-all`
2024-07-23 11:33:12 -07:00
Yuanqiang Liu 21ad890009
[Torch] enhance fold of aten.slice.Tensor (#3557)
so that it could support folding slice with any static shape.
2024-07-23 22:53:03 +08:00
Yuanqiang Liu 78846425e2
[Torch] add constriants when decompose aten.split_with_sizes (#3555) 2024-07-23 10:34:29 +08:00
Vivek Khandelwal 22c9008bb9
build: Update Roll PyTorch version (#3548)
This commit also updates the PyTorch and Torchvision nightly links since
they are now moved to a different location.

PyTorch Nightly: https://download.pytorch.org/whl/nightly/cpu/torch/
Torchvision Nightly:
https://download.pytorch.org/whl/nightly/cpu/torchvision/

Disables dtype checks for some ops, tracked by https://github.com/llvm/torch-mlir/issues/3552

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-07-19 21:38:57 +05:30
bosko-syrmia 2cdf3deae3
implement lowering of torch.aten._linalg_slogdet (#3524) 2024-07-19 11:24:43 +05:30
Branko Trifkovic c7d972ed58
Implement lowering of torch.aten.tril_indices (#3517) 2024-07-18 18:38:12 +05:30
jinchen f0ce1e94ce
[ONNX] Add OnnxToTorch support for SequenceMap (#3535) 2024-07-17 14:25:09 -07:00
pkapris-syrmia fde286f491
Implement lowering for torch.aten.hann_window.periodic (#3502) 2024-07-17 18:21:23 +05:30
pkapris-syrmia b59efc75f3
Implement lowering of torch.aten.atleast_1d (#3498)
This operator is necessary in order to implement torch.aten.vstack.
Which will be added in a future PR.
2024-07-17 18:20:30 +05:30
Arham Khan 574143448b
[E2E][ONNX] torch.multinomial (#3404)
This PR adds a conversion in the TorchOnnxToTorch pass for the ONNX
Multinomial operation. It also adds a TorchToLinalg lowering for the
`aten.Multinomial` op and does a light refactor of some repeated code
that generates random floating point numbers in
`TorchToLinalg/Random.cpp`.
2024-07-16 23:09:39 +05:30
rohan-tan-bhowmik 0791a8860c
[Torch] Implements TorchToLinalg lowering of torch.ops.aten._weight_norm_interface (#3538)
Resolves https://github.com/nod-ai/SHARK-Turbine/issues/757.

Adds TorchToLinalg lowering for `Aten_WeightNormInterfaceOp`.

---------

Co-authored-by: Ubuntu <rbhowmik@RohanBhowmikVM.judsoscro3wupi0qm4bjlj5m3b.bx.internal.cloudapp.net>
2024-07-16 23:09:12 +05:30
Yuanqiang Liu 714270a922
[Stablehlo] legalize deprecated ops to stablehlo ops (#3543) 2024-07-17 00:05:11 +08:00
Xinyu Yang e5d1677894
[Torch] Eliminate getWithLeastStaticInformation in DecomposeAtenLinspaceOp and DecomposeAtenFakeQuantizePerTensorAffineOp (#3539)
as title
2024-07-15 10:02:36 +08:00
Yuanqiang Liu 5e4f00acb1
[Torch] add support for aten.scatter_add (#3534) 2024-07-12 09:15:42 +08:00
zjgarvey 0fb8b017d8
Adds misc fixes for some padding related issues (#3528)
This patch adds a few misc pad op related changes:

1. Addresses issue <https://github.com/llvm/torch-mlir/issues/3457>
2. Addresses issue <https://github.com/llvm/torch-mlir/issues/3442>
3. Fixes the padding order for asymmetrically padded onnx.Conv ops
4. Enables passing quantization through those onnx.Conv op pre-paddings
5. Modifies the torch-to-linalg lowering of AtenReplicationPad2d op to
enable support for input rank != 4

Unfortunately, even with all of these changes, the e2e tests for the
ReplicationPad2d still fail the onnx config, since the torch export
procedure for rearranging the pad order is complicated enough that the
padding ints end up not being able to fold back to constants.
2024-07-11 20:01:45 -05:00
Yuanqiang Liu b38585e077
[Torch Dialect] fix aten.nan_to_num's decomposition when inf=None (#3530)
also add shape infer in decomposition, see
https://github.com/llvm/torch-mlir/issues/3312
2024-07-11 08:46:40 +08:00
Xida Ren (Cedar) 5342aa70cf
Support onnx.GRU and onnx.RNN (#3447) 2024-07-10 14:04:17 -04:00
Yuanqiang Liu 5bee9aac63
[Stablehlo] simplify promoteType (#3525)
only provide `outElementType` when promoteType
2024-07-10 10:52:19 +08:00
zjgarvey dcb48dd46c
[ONNX] Fix LpNormalization Lowering (#3521)
The LpNormalization lowering was previously just computing the norm,
which is incorrect. This computes the norm then divides the input tensor
by it's norm.

I've tested this against some simple onnx models locally. I'll look into
adding a test case for this in an external test suite.
2024-07-09 15:42:26 -05:00
Gaurav Shukla 0b46d1110a
[MLIR][ONNX] Add support for onnx.ScatterND (#3479)
This commit adds support for onnx.ScatterND op in the onnx pipeline.

Signed-off-by: Gaurav Shukla <gaurav.shukla@amd.com>
2024-07-08 13:27:14 +05:30
Matthias Gehre 6ea6a6c2fe
TorchOnnxToTorch: Fix stack-use-after-free (#3480)
We used to move the SmallVector into an ArrayRef and then the
SmallVector left the scope.

Found by asan.
2024-07-08 09:20:09 +02:00
Yuanqiang Liu 3225f20ab1
[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)
For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %0 = arith.index_cast %dim : index to i64
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %1 = arith.index_cast %dim_0 : index to i64
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %2 = arith.index_cast %dim_1 : index to i64
    %from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
    %3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
    %4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %4 : tensor<?x?x?xf32>
  }
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
  func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
    %dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
    %from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
    %0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
    %1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    return %1 : tensor<?x?x?xf32>
  }
}
```

The benefits of using IndexType on shape tensor:
* simplify the IR, avoid to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
2024-07-07 18:03:03 +08:00
Ze Zhang d466d5b809
Register fake_quantize related ops (#3522)
Register `aten.fake_quantize_per_channel_affine` and
`aten.fake_quantize_per_tensor_affine.tensor_qparams` ops

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-07-05 11:02:03 -07:00
Sagar Kulkarni 0fe74845da
[ONNX] Fix bug in ONNXToTorch PadOp's pads tensor rearrangement (#3485)
Fix the pad tensor rearrangement such that we change the representation
from [x1_begin, x2_begin, ..., x1_end, x2_end,...] to [xn_begin, xn_end,
...., x2_begin, x2_end, x1_begin, x1_end] where x1, x2 .. xn are the
dimensions of the pads tensor argument.

---------

Co-authored-by: zjgarvey <zjgarvey@gmail.com>
Co-authored-by: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
2024-07-03 15:02:49 -05:00
Scott Todd ca0e906675
Fix `uint64_t` type. (#3519)
`u_int64_t` is nonstandard and does not exist in MSVC.
2024-07-02 16:06:20 +00:00
Yuanqiang Liu f1e3701caf
[Stablehlo] fix compareOp with scalar's lowering (#3518)
* use lhs tensor's element type as compute type when rhs is scalar.
* previously `a != 1.0`(a is a fp32 tensor) will lowering to `%6 =
stablehlo.compare EQ, %4, %5, FLOAT : (tensor<2x5xf64>, tensor<2x5xf64>)
-> tensor<2x5xi1>`
* now it will lowering to `%6 = stablehlo.compare EQ, %4, %5, FLOAT :
(tensor<2x5xf32>, tensor<2x5xf32>) -> tensor<2x5xi1>`
2024-07-02 15:31:06 +08:00
Yuanqiang Liu e2fbded49c
[Torch Dialect] improve argmax/argmin's decomposition to support keep… (#3514)
…dim=True when dim=None
2024-07-02 09:08:57 +08:00
Yuanqiang Liu 0e71a192d8
[Torch] support decomposition of aten.aminmax (#3513)
* unify decompisition of `aten.amax` and `aten.amin`
* support `aten.amax` with `dim=()`
2024-06-29 21:44:05 +08:00
Yuanqiang Liu f9fc741eef
[Stablehlo] support aten.any.dim, aten.min.dim (#3500)
* refactor `TorchToStablehlo/Reduction.cpp`
* add `ConvertAtenReduceWithIndicesOp` patterns
2024-06-29 16:53:33 +08:00
jinchen 3915db0a86
[ONNX] Add OnnxToTorch support for CenterCropPad (#3496) 2024-06-28 12:47:29 -07:00
zjgarvey af236dab66
Add support for multiple dynamic reassociation dims for unflatten.int (#3504)
Addresses an issue with onnx.Gather lowering to linalg:
<https://github.com/nod-ai/SHARK-Turbine/issues/242>

The builder for tensor.expand_shape, without an explicitly provided
output shape, fails to infer an output shape in the case of multiple
dynamic reassociation dims. I tried adding the output shape explicitly
for tensor.expand_shape, but ran into compilation issues later on (see
<https://github.com/iree-org/iree/issues/17760>).

This PR adds support by lowering this op to tensor.reshape when multiple
dynamic reassociation dims are provided.
2024-06-28 09:59:51 -07:00
Max191 a1c4089e71
Fix unused variable warning from assertion variable (#3512)
Inlines a variable into an assertion that is not used elsewhere to fix
build warnings.
2024-06-28 12:20:29 -04:00
Jiawei Wu f75cbb4df9
[torch dialect] emit aten.fmax/fmin and add decomposition patterns (#3510) 2024-06-29 00:07:55 +08:00
Phaneesh Barwaria 5a627c46b7
onnx.DFT basic support (#3463)
- adds support for DFT v20 on the FFT and IFFT path
- adds required skeleton code for IFFT ops to be recognised in TMlir
2024-06-28 20:08:43 +05:30
Christopher McGirr 7e6d76e997
[Torch] Fix torch.constant.int operation parsing (#3476)
Due to the custom operation parser, the print and parser were expecting
two different forms.

One having the dictionary before the value and the other after.
Following the format of the other constants ops, the constant.int will
follow the `value attr-dict` format. Updated the parser accordingly.
2024-06-28 16:06:52 +02:00
Aart Bik 1f73895f93
[torch-mlir] bump to llvm/llvm-project@9b78ddf3b2 (#3491)
This bump triggered an upstream assert. Includes a WAR for #3506.

Also includes several things I needed to do to repro:

* When TORCH_MLIR_TEST_CONCURRENCY=1, test runs will be printed.
* Added TORCH_MLIR_TEST_VERBOSE=1 handling to enable verbose mode
(useful on CI).

---------

Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
2024-06-27 19:28:02 -07:00
jinchen 6d0ca499e6
[ONNX] Add OnnxToTorch support for ReverseSequence (#3495) 2024-06-27 14:33:41 -07:00
Phaneesh Barwaria 39d1332008
add onnx loop support (#3408)
- Adds limited support for lowering onnx.Loop to primLoopOp
- lower in the pipeline`torch-to-scf` there is a check to see if loop is
for like. A primLoopOp is for like when the input condition is a
`trueBoolConstant`. To adapt the onnx to torch lowering to take
advantage of it, the implementation checks for specific op patterns in
the loodBody region and decides if loop is for like and uses the right
input condition op.
- to adapt the onnxLoopBody to torchLoopBody, we need to adapt the input
block arguments and set the correct output condition variable in the
loop body.
- scanOutput variables are currently not supported.
2024-06-27 17:08:44 +05:30
Matthias Gehre 6678e1a256
TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475)
Before this PR, a statically shaped aten.convolution would generate
dynamically shaped linalg IR, and even `-canonicalize` would not be able
to fold it back into static shapes. This PR ensure that shape
calculations are folded on construction to directly generate statically
shaped linalg IR.

We achieve that by ensuring that `arith` ops involved in computing
shapes are created via `createOrFold`, so that later uses of
`getAsOpFoldResult` see constants instead of those ops.

For example
```
module {
  func.func @forward(%arg0: !torch.vtensor<[32,336,112,112],f32>,
                        %arg1: !torch.vtensor<[336,168,3,3],f32>, 
                        %arg2: !torch.vtensor<[336],f32>) 
                        -> !torch.vtensor<[32,336,56,56],f32> {
    %false = torch.constant.bool false
    %int2 = torch.constant.int 2
    %int1 = torch.constant.int 1
    %0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.prim.ListConstruct  : () -> !torch.list<int>
    %3 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %0, %0, %false, %2, %int2 
    : !torch.vtensor<[32,336,112,112],f32>, !torch.vtensor<[336,168,3,3],f32>, !torch.vtensor<[336],f32>, !torch.list<int>,
      !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int
   -> !torch.vtensor<[32,336,56,56],f32>
    return %3 : !torch.vtensor<[32,336,56,56],f32>
  }
}
```
would result in
```
[...]
  %padded = tensor.pad %2 low[%14, %15, %16, %17] high[%14, %15, %16, %17] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
    } : tensor<32x336x112x112xf32> to tensor<?x?x?x?xf32>
[...]
  %45 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
    ins(%expanded, %expanded_37 : tensor<?x2x?x?x?xf32>, tensor<2x168x168x3x3xf32>)
    outs(%expanded_44 : tensor<32x2x168x?x?xf32>) -> tensor<32x2x168x?x?xf32>
[...]
```
and with this PR all shapes are static.
2024-06-27 08:43:10 +02:00
Suraj Sudhir 6eebe61bfe
[Tosa] Conversion from torch.__interpolate to tosa.resize() (#3488)
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
2024-06-26 09:10:14 -07:00
Ramiro Leal-Cavazos e29191bd08
[LINALG] Broadcast `values` to shape of slize in `index_put` (#3487)
The `index_put` operation, `input[indices] = values`, allows for the
values to be any shape that is broadcastable to the slice
`input[indices]`. This commit adds broadcasting support to the Linalg
lowering of `IndexPutHackedTwinOp`.

Fixes: #3465
2024-06-26 08:59:49 +00:00
zjgarvey d2bc70f188
[TorchToLinalg][ONNX] Add Basic Determinant Support (#3481)
This adds support for a few ops:

- torch.linalg_det
- torch._linalg_det (if the LU and pivot returns are unused)
- onnx.Det

An scf loop is used, since the row reduction algorithm applied here has
some loop-carried dependencies.
The current support being added here is very basic, and only works if no
permutations are required during row reduction, and assumes the matrices
are non-singular.
2024-06-25 13:34:19 -05:00
zjgarvey 368fabf0c1
[ONNX] Basic Support for DeformConv (#3469)
This adds a torchvision op to torch-mlir and a path from onnx.DeformConv
to torchvision.deform_conv2d.

I'm not implementing the torch->linalg lowering for the torchvision op
yet, but posting this PR to get feedback on some of the choices being
made here and to flesh out the onnx frontend a bit.
2024-06-25 12:16:51 -05:00
zjgarvey e346c911f7
[ONNX] Add basic support for RoiAlign (#3493)
This adds an onnx->torch conversion for onnx.RoiAlign into
torchvision.roi_align or torchvision.roi_pool, and adds those two
torchvision ops to torch-mlir.
2024-06-25 11:02:45 -05:00
Vinayak Dev 02340408b7
[torch] Add OnnxToTorch lowering for Onnx.STFT op (#3492)
Adds OnnxToTorch lowering for `Onnx.STFT` op.
2024-06-25 19:00:45 +05:30
Vivek Khandelwal 3c3fbe4680
[ONNX] Add OnnxToTorch lowering for Onnx.Upsample Op (#3371)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-25 12:58:31 +05:30
Chi_Liu fc19709daa
[ONNX] Add averagepool dilations support (#3490)
- To fix dilations issue: https://github.com/llvm/torch-mlir/issues/3428
- Test by: https://github.com/nod-ai/SHARK-TestSuite/pull/268
2024-06-21 17:24:57 -07:00
Branko Trifkovic 98c6971a01
Implement lowering of torch.aten.triu_indices (#3451)
Closes
[nod-ai/SHARK-Turbine/issues/709](https://github.com/nod-ai/SHARK-Turbine/issues/709)

---------

Co-authored-by: Branko Trifkovic <branko.trifkovic@syrmia.com>
2024-06-21 16:16:38 -07:00
Matthias Gehre acd57a3520
Support fake_quantize_per_tensor_affine_cachemask (#3477)
Add a new op with shape/dtypes and decompose into
`fake_quantize_per_tensor_affine` when the second result is unused.

The xfail_set change is on ONNX because torch cannot export this op to
ONNX.
2024-06-21 07:15:31 +00:00
Vivek Khandelwal 83bfb6fb19
[ONNX] Add OnnxToTorch lowering for OptionalHasElement op (#3472)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-21 11:19:00 +05:30
Vivek Khandelwal d29ad4dfbd
[ONNX] Fix Onnx.Hardsigmoid lowering (#3239)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-21 11:18:14 +05:30
zjgarvey 694210f429
[TorchToLinalg] Fix Quantized Convolution Accumulator Type (#3459)
1. truncates zero-points to i32
2. modifies the default accumulator type for i8 from i64 to i32. 
3. now uses the input dtype to infer accumulator dtype.
2024-06-20 13:54:20 -07:00
Xinyu Yang c7d52f63b4
[stablehlo] add aten::_int_mm lowering (#3474)
as title
2024-06-20 16:10:31 +08:00
Vivek Khandelwal 822d763308
[ONNX] Add OnnxToTorch lowering for Optional, OptionalGetElement op (#3467)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-18 19:40:18 +05:30
Branko Trifkovic 676fa8cc09
Implement lowering of torch.aten.renorm (#3388)
Closes
[nod-ai/SHARK-Turbine/issues/689](https://github.com/nod-ai/SHARK-Turbine/issues/689)

---------

Co-authored-by: Branko Trifkovic <branko.trifkovic@syrmia.com>
2024-06-17 10:40:57 -07:00
Umang Yadav 59bade3376
[ONNX] Add missing "Abs" in GlobalLpPool (#3460)
Taking `abs` is required to mimic same logic as onnx/onnxruntime. 
Without `abs`, it wouldn't produce correct results for negative values. 

Reference code : 

f5b6f6dc26/onnxruntime/core/providers/cpu/nn/pool_functors.h (L604)


375c161c67/onnx/reference/ops/op_lp_pool.py (L31)
2024-06-17 11:17:16 +05:30
ptrifunovic98 4555629246
Implement lowering of torch.aten.kthvalue (#3360)
Closes
[nod-ai/SHARK-Turbine#620](https://github.com/nod-ai/SHARK-Turbine/issues/620)
2024-06-15 11:18:39 +05:30
Manupa Karunaratne d2b663ece7
Add onnx op LRN lowering (#3432)
This commit adds support for lowering
Onnx LRN op to aten.
2024-06-14 16:44:43 +00:00
Arham Khan 09c988046c
[ONNX] Add OnnxToTorch lowering for Onnx.NegativeLogLikelihoodLoss Op (#3380)
This implements the Onnx.NegativeLogLikelihoodLoss op using the
signature provided
[here](https://onnx.ai/onnx/operators/onnx__NegativeLogLikelihoodLoss.html)
by replacing it with a `NLLLossForward` op.

Additionally, I included a helper function `get_loss_reduction_enum` to
convert from a string `reduction` parameter to the corresponding
intended integer value since this is an operation that will be reused
for any loss function module. This differs from `get_reduction_enum` in
`TorchUpstream.cpp` which handles the `reduce` parameter from
`scatter_reduce` type operations.
2024-06-14 22:01:11 +05:30
Vivek Khandelwal 2ea2bc3948
[ONNX] Add OnnxToTorch Lowering for GroupNormalization op (#3458)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-14 16:18:53 +00:00
Umang Yadav 04c6479350
[ONNX] Add onnx parser for LpPool operator (#3449)
Similar to https://github.com/llvm/torch-mlir/pull/3435

Solves https://github.com/nod-ai/SHARK-Turbine/issues/728
2024-06-14 21:41:18 +05:30
Xinyu Yang 6f94c7b0aa
[Torch] Add support for Meshgrid (#3462) 2024-06-14 23:59:08 +08:00
Phaneesh Barwaria 919b599ebe
onnx.MaxPool add atenMaxPool1d lowering support (#3452)
fixes #3422
2024-06-13 15:37:11 +05:30
Vinayak Dev 39d882f7c9
[torch] Add OnnxToTorch lowering for the Col2Im op (#3424)
Adds OnnxToTorch lowering for the `onnx.Col2Im` op.
2024-06-13 08:42:06 +00:00
Surya Jasper de7f058a0e
[MLIR][ONNX] Add OnnxToTorch support for MaxRoiPool Op (#3395)
This PR adds OnnxToTorch support for MaxRoiPool op
2024-06-13 10:46:14 +05:30
Umang Yadav 9b76a2e3eb
[ONNX] add onnx lowering for global lp pool operator (#3435)
Solves https://github.com/nod-ai/SHARK-Turbine/issues/727

Uses AvgPool to implement GlobalLpPool similar to this
https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_lp_pool.py

cc: @vivekkhandelwal1
2024-06-13 10:37:08 +05:30
Lei Zhang 77d7f64472
Update to llvm/llvm-proect@27ac46e6be (2024-6-12) (#3454)
This would require to bump stablehlo at the same time.
2024-06-12 19:34:01 -07:00
Chi_Liu ae6f5e8251
[ONNX] Fix AveragePool attributes support (#3235)
Issues was found here https://github.com/nod-ai/SHARK-Turbine/issues/643
    - [ONNX] Fix padding attributes for onnx.AveragePool
    - [Linalg] Add countIncludePad false support for AtenAvgPool1/2dOp
    - [Linalg] Add an avg_pool2d countIncludePad False e2e tests
    - [Linalg] Fix conflict with AtenAvgPool3dOp
    - [Linalg] Fix e2e crash with AtenAvgPool1dOp
    - [Linalg] Add dynamic dim support for AtenAvgPool2dOp
    - [Linalg] Fix AvgPool2dDivisorOverrideModule crash
2024-06-12 12:16:43 -07:00
Suraj Sudhir 41d04a8995
[onnx] Resize supports default-valued attributes (#3450)
Handles onnx exporters emitting default-valued attributes.

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
2024-06-12 09:23:42 -07:00
zjgarvey de28c8540b
[ONNX] add int16 quantization support (#3446)
There is currently no int16 quantization support in torch. This patch
adds a new mlir type to correspond to the missing "torch.qint16" type,
and enables lowering of quantization-related onnx ops using int16 types.

In follow-up patches, custom quantization logic for ops like
aten.matmul/aten.mm/aten.convolution may need to be revisited to allow
support for qint16. The passes in FuseQuantizedOps.cpp may also need
slight modifications.
2024-06-12 10:37:22 +05:30
zjgarvey 7cd3368b20
[ONNX] Fix resize ceil numerics and add half_pixel_symmetric support (#3443)
This patch fixes several failing tests in our [external test
suite](https://github.com/nod-ai/SHARK-TestSuite/tree/main/iree_tests/onnx/node/generated),
and addresses some of the issues discussed in #3420
2024-06-11 22:35:50 -05:00
Matthias Gehre e07a0bfc54
onnx.resize: Add support for coordTfMode "half_pixel" (#3441)
half_pixel is also the default mode used by ONNX, see
https://onnx.ai/onnx/operators/onnx__Resize.html
2024-06-10 20:59:29 +02:00
Aart Bik d77bab37d1
[torch-mlir][sparse] re-enable all sparse tests (#3444)
this fixes the following issue:

https://github.com/llvm/torch-mlir/issues/3418
2024-06-10 11:19:32 -07:00
Vivek Khandelwal 5bc626465b
[ONNX] Lower Onnx.Concat lowering version (#3437)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-09 12:07:20 +05:30
Vivek Khandelwal d35b6b412a
[ONNX] Add OnnxToTorch Lowering for Sequence Ops (#3425)
This commit adds the lowering for SequenceAt, SequenceEmpty,
SequenceInsert, SequenceErase op

Signed-Off By: Vivek Khandelwal<vivekkhandelwal1424@gmail.com>
2024-06-08 09:58:11 +05:30
Yuanqiang Liu 689efc8917
[Torch] fix toBuiltinTensor() (#3415)
* Let `toBuiltinTensor()` reflects the original dtype of
`!torch.vtensor`.
* Backend handles dtype conversion themselves.
2024-06-08 09:36:32 +08:00
Rob Suderman 75af64fc12
[torch] Add support for f8 types for linalg conversion (#3436)
Linalg conversion requires mapping for f8 types
2024-06-07 13:59:38 -07:00
aldesilv f794582b18
add resize nearest mode round_prefer_floor, round_prefer_ceil, ceil (#3421) 2024-06-07 14:04:11 -05:00