Commit Graph

352 Commits (42a16fa9121f5e9725786d5e3ef8a4da0cdc0d3c)

Author SHA1 Message Date
Rob Suderman 9d9a05366e
[torch] Fix aten.squeeze lowering to use result shape (#3106)
Squeezes can be ambiguous without the output shape information. For
instance (1, 1, 256) squeezed can be either (1, 256) or (256). We need
to check the resulting shape to know what the shape should look like.
2024-04-04 09:43:12 -07:00
Rob Suderman f97cd4893f
[torch] Improve shape inference for dynamic shapes (#3091)
Shapes can be processed as tensors to represent the set of dimensions.
As reshapes take a list of scalars this can result in a single dynamic
dimension blocking the adjacent static dimensions.

This pass attempts to de-couple tensor computations related to shapes
and propagate values to better support lowering scalar tensor
computations.
2024-04-02 16:19:57 -07:00
Xinyu Yang ac1cd3d78a
[Torch] Support AtenDivTensorModeOp with static int input for linalg and stablehlo backend (#3088) 2024-04-02 17:28:53 +08:00
Thomas Dietert d2432bbe5a
[MLIR][Torch] Do not convert bias tensor to element type if NoneType (#3072)
The `convertTensorToElementType` function expects it's argument to have
a valid tensor type that is not `Torch::NoneType`. This PR checks that
the bias tensor is not of type `Torch::NoneType` before calling
`convertTensorToElementType` on the bias tensor argument in the
`matchAndRewrite` member function of the `ConvertAtenConvolutionOp`
class.
2024-04-02 14:19:26 +05:30
ptrifunovic98 1c8c47d483
Add complex support for aten.norm and similar operations (#3052)
Add support for complex-type input tensors for norm, vector norm, and
Frobenius norm operations.
2024-04-02 14:03:30 +05:30
zjgarvey 532d297c46
[ONNX] Preliminary Work Towards Supporting QuantizedMLP_basic onnx e2e test (#3089)
See the related issues here:
[SHARK-Turbine#556](https://github.com/nod-ai/SHARK-Turbine/issues/556)

1. Adds uint8 casting to onnx.Cast op
2. Fixes an issue with onnx.DequantizeLinear when the scale comes with
shape [1].
3. Adds support for unsigned types in an AtenItemOp folder
4. Adds a simpler quantized model for easier debugging
5. Adds a fusion pass to convert [quant -> dequant -> transpose -> mm]
patterns to [transpose -> quant -> mm].
6. Moved some xfails that are still not passing, but for different
reasons than onnx.cast failures.
2024-04-01 16:21:05 -07:00
Vivek Khandelwal 6844c84702
[MLIR][Torch] Fix OnnxToLinalg lowering for AvgPool op (#3076)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-04-01 22:14:14 +05:30
Rob Suderman 14b548f968
[torch] Improve shape inference for `torch-to-linalg` path for reshapes (#3055)
Reshaping tensors depend on directly matching individual dimensions to
their corresponding dim in the `torch.view` reshape dimensions. This
involves decoupling dynamic dimensions from their static counterparts
and support cleanup / canonicalization.
2024-03-26 12:41:40 -07:00
schnkmwt 1fcbfa87ec
Implement linalg lowering of diag_embed torch op (#2885)
This PR adds lowering of diag_embed to linalg dilect.
Tracked in https://github.com/nod-ai/SHARK-Turbine/issues/288

---------

Co-authored-by: sachink <sachink@xilinx.com>
2024-03-22 16:32:50 -07:00
zjgarvey 99b3a5f117
Converts all Adaptive Pooling Ops to Linalg (#2808)
The previous conversions for AtenAdaptiveAvgPool1dOp and
AtenAdaptiveMaxPool2dOp are refactored into a general templated
conversion that works for all of the AtenAdaptive...PoolNdOp's.

New support is added for the following ops:

1. AtenAdaptiveMaxPool1d
2. AtenAdaptiveMaxPool3d
3. AtenAdaptiveAvgPool3d

Support is also provided for passing inputs without batch dimensions.
For example, applying adaptive_avg_pool2d to an input tensor of rank 3.

After [pytorch #118162](https://github.com/pytorch/pytorch/pull/118162)
gets down to torch-mlir, I'll add a test for AdaptiveMaxPool1d with
return_indices (which will pass with that upstream fix).

---------

Co-authored-by: James Newling <james.newling@gmail.com>
2024-03-22 11:05:20 -07:00
Rob Suderman 3a56714bff
[torch] Fix clamp ranges on quantize_per_tensor on unsigned (#3018)
SExtValue was used for `int` and `uint` clamp values. This caused the
result to always be outputed as `zero`.
2024-03-20 13:37:47 -07:00
Nithin Meganathan 798bfd7dff
Adds accumulator types in TorchToLinalg for `AtenMmOp` and `AtenConvolutionOp` (#3027) 2024-03-14 16:40:40 -07:00
Rob Suderman 1964208d19
[onnx] Fix constant pad for dynamic shape (#2989)
The current padding operation was not functional for dynamic shapes.
Updated and enabled tests so that onnx.pad tests pass.

Work TBD for reflection padding.
2024-03-07 13:29:50 -08:00
Rob Suderman a78659742a
[onnx] Migrate `onnx.ReduceMax` to match `onnx.ReduceMin` (#2981)
This mostly copy-pastes the reduce minimum implementation to reduce max
to improve test coverage. We also improve the aten lowering for min/max
dim for unsigned types.
2024-03-06 16:48:21 -08:00
Andreas Falkenberg ea76dd12ba
[onnx][torch] Gridsampler E2E test and corrections of gridsampler (#2987)
The addition of an e2e test is actually provided in the Shark-Testsuite.
This adds 2 test cases for the gridsampler e2e test. 
Also as intended there were some items found which needed correction, so
the Gridsampler op has also a change.
2024-03-06 10:56:58 -08:00
Rob Suderman 19d4888278
[torch] Make torch.aten.unflatten lower directly to linalg (#2971)
Existing lowering via aten.view does not work as well for dynamic shapes
as the lowering to tensor.expand must re-infer dynamic shape matching.
Better to directly lower.
2024-03-04 10:17:42 -08:00
Rob Suderman d030bffc62
[torch] Support `aten.view` rank-0 collapse (#2965)
Collapsing to a rank-0 tensor using `aten.view` was currently bailing
out. Added the special case.
2024-03-01 12:31:07 -08:00
Vivek Khandelwal 579ac8b666
[MLIR][TORCH] Fix OnnxToLinalg lowering issue for sub and sum op (#2954)
This commit adds the support for scalar conversion to byte. 
This commit also fixes the OnnxToLinalg lowering issue for Onnx.Sub and
Onnx.Sum op.
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/466 
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/467

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-02-29 21:48:46 +05:30
mmakevic 76b81e0ccd
Implement lowering of torch.aten.fmod.Tensor (#2767)
Closing https://github.com/nod-ai/SHARK-Turbine/issues/351
2024-02-29 11:22:03 +05:30
Rob Suderman 6f3d62ab04
[torch] Fix folders and `cat` and `view` torch lowerings (#2963)
A bunch of small fixes are interlinked and trigger crashes if not
addressed as a group. This includes:

- aten view when expand from a rank-0 tensor
- slice folder with negative indices
- `aten._shape_as_tensor` folder on a rank-0 tensor
- `aten.cat` of a tensor with a length-0 tensor
2024-02-28 12:04:52 -08:00
Rob Suderman 4a7a7d76f8
[onnx] Fix ReduceMean lowering to torch (#2956)
Torch lowering only supported the most recent version. Refactored the
lowering so more easily handle default values and optional operands /
attributes.
2024-02-27 22:48:07 -08:00
Vivek Khandelwal d628b5fd06
[MLIR][TORCH] Add support for tanh approximation for Gelu op (#2941)
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/461

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-02-27 19:26:01 +05:30
Vivek Khandelwal d81747eadb
[MLIR][TORCH] Extend support for OnnxToLinalg lowering for Dropout and Div op (#2938)
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/451,
https://github.com/nod-ai/SHARK-Turbine/issues/452
2024-02-27 11:02:05 +05:30
ptrifunovic98 c5a1da1910
Implement lowering of torch.aten.norm.Scalar (#2899)
Closes
[nod-ai/SHARK-Turbine#365](https://github.com/nod-ai/SHARK-Turbine/issues/365)
2024-02-26 08:46:56 -08:00
Andreas Falkenberg 55dc8deb92
[torch] GridSample TorchToLinalg lowering (#2883)
Lowers `torch.grid_sample` to the equilvalent `linalg` representation.
2024-02-23 09:14:38 -08:00
Rob Suderman df2aa1a369
[torch] Fixed edge conditions for strided slicing (#2929)
Strided slicing can occur with a negative stride. In these cases we need
to bound end differently. This included removing a function that was
generating bad limits.
2024-02-21 21:28:44 -08:00
Rob Suderman 135c81a416
[torch] Add folder for `prim.NumToTensor.Scalar` (#2921)
Useful for `slice` lowerings that depend on tensors made form scalars.
2024-02-19 11:55:54 -08:00
Rob Suderman fd08578bdb
[torch] Support dynamic step size for `torch.slice` (#2922)
For some reason we did not directly use the step size dynamically
despite its constructed using the dynamic value.
2024-02-19 10:26:21 -08:00
Rob Suderman d65925a8b4
[onnx] Fix `onnx.sigmoid` for integer inputs/outputs (#2914)
Sample compilation crashes due to sigmoid with integer inputs/outputs.
This fix avoids crashing but still experiences an error.
2024-02-16 13:35:25 -08:00
Rob Suderman 074f112d6a
[onnx] Add testing using the `onnx` compilation using torch tests (#2795)
We can route the torch tests via `onnx` using the `torch.onnx.export`
tooling. We can then reimport, lower to torch, and compile to linalg to
validate the onnx path is working correctly.

The current implementation exposes some failures in the `onnx` path so
we cannot enable the onnx test suite yet due to segmentation faults.
2024-02-15 10:17:13 -08:00
Vivek Khandelwal d6d1a173dc
[MLIR][Torch] Add OnnxToTorch and TorchToLinalg support for trig ops (#2903)
This commit adds the OnnxToTorch lowering for cosh, acosh, asin, asinh,
and atanh op.
This commit also adds the TorchToLinalg lowering for acosh, asin, asinh,
and atanh op.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-02-14 11:58:09 +05:30
Xida Ren (Cedar) bfb93cb99f
Fix test_add_uint8 failure to lower to linalg (#2893)
By updating convertScalarToDtype invocation pass original source and
destination datatypes for the add op. Also fixes a potential problem
with the sub op.

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-02-12 09:19:39 -08:00
Rob Suderman d83b576c6e
Bump LLVM to llvm/llvm-project@bb180856ec (#2895)
Includes some minor first for `AffineMap::inferFromExprList`
2024-02-09 14:07:49 -08:00
Avinash Sharma 9659a436d1
Add lowering support for math::AbsIOp (#2875)
There is no lowering support for math::AbsIOp, so if the operand is an
integer type, it will fail to lower to math::AbsFOp since the op operand
#0 must be floating-point-like.
2024-02-08 14:53:40 -08:00
mmakevic 32dbf99ce2
Implement lowering of torch.aten.all.dim (#2873)
Lowering of torch.aten.all.dim to linalg.

Per PyTorch documentation:

> This function matches the behaviour of NumPy in returning output of
dtype bool for all supported dtypes except uint8. For uint8 the dtype of
output is uint8 itself.

Since there is no support for ui8 in torch-mlir currently
(https://github.com/llvm/torch-mlir/pull/1384#issuecomment-1260011334)
implementation returns failure for that case.
2024-02-07 12:34:52 -08:00
Rob Suderman e3faef5224
[onnx] Convert `onnx.QLinearConv` to `torch` (#2851)
Leaning on the QDQ functionality in torch we can support the QLinearConv
operation by piggybacking through `torch.Convolution`. This includes
some changes such as allowing the `onnx` rewriter to run recursively.
Doing so allows `QLinearConv` to decopmose to `onnx.Convolution` which
is then lowered to `torch`.
2024-02-05 16:09:41 -08:00
Rob Suderman 34f6948533
[torch] Support `!countIncludePad` when unpadded for average pool (#2836)
We do not support average pool when `countIncludePad is set to false.
However if the input is unpadded then the setting of the boolean is
unneeded. Extended use by checking if padding is zero before rejecting
the lowering.
2024-01-31 15:09:36 -08:00
Rob Suderman 25a5a22cbd
[torch] Support `torch.convolution` quantized lowering to `linalg` (#2811)
Linalg has quantized specific operations. We can lower to these
operations when there is a known zeropoint and scale operations. This
allows the `convolution` to occur with lower bitwidth's, improving the
overall performance.
2024-01-30 13:46:47 -08:00
Quinn Dawkins 494089d53d
Clang format refresh (#2812)
After noticing a number of commits with unrelated formatting changes,
I think something was changed with clang-format at one point and we're
seeing a number of unrelated changes. Doing a refresh can help avoid
this.

The changes made here came from
```
find lib -iname *.h -o -iname *.cpp  | xargs clang-format -i --style=llvm
find include -iname *.h -o -iname *.cpp  | xargs clang-format -i --style=llvm
find projects -iname *.h -o -iname *.cpp  | xargs clang-format -i --style=llvm
```
2024-01-29 12:59:33 -05:00
Aart Bik 46a25d7241
[torch-mlir][sparse] preserve sparsity during lowering torch to linalg (#2809)
This preserves sparsity at the most obvious places of lowering TORCH
tensors to MLIR RankedTensorType tensors. Other places are marked for
audit. With some initial lowering tests.
2024-01-26 10:54:59 -08:00
Rob Suderman 2ef228328f
[torch] `torch.dequantize` for per channel tensors to` linalg` (#2769)
Support a lowering for dequantization for per channel tensors from
`torch` dialect to a linalg decomposition. Tested via a numerical
`torch` test.
2024-01-25 16:40:21 -08:00
Rob Suderman f6f890520b
[torch][quant] Quantized `torch.mm` for linalg with end-to-end test (#2750)
This includes custom op matching for decomposed operations and fusing
dequantization into dense operations. As a validation we compare
to the dequant+mm torch implementation.
2024-01-24 14:02:50 -08:00
zjgarvey c531f5495b
AtenAdaptiveMaxPool2d Conversion to Linalg (#2779)
The logic here is very similar to the conversion for AdaptiveAvgPool1d
#2661 with a few modifications:

1. buffVal = -inf instead of 0
2. the main linalg generic op accumulates a max, instead of a sum, to
the first output tensor
3. avg pooling requires dividing the sum pool by the kernel width, which
we stored as an auxilliary tensor (kSizeTensor). Here, the auxiliary
tensor will be recording the indices. Strangely enough, the only
signature available for this function is to return indices, and it
appears that they must be computed whether the user desires them or not.
See
[pytorch/torch/nn/functional.py](https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L1174).

Before writing other adaptive pooling conversions, the logic of this
decomposition should be rolled into a helper function that will work for
both max and avg pooling ops. Even the auxiliary tensor should likely be
automated. This code was written in a slightly more tedious way than
strictly necessary (often using loops to fill SmallVectors up to rank-2,
which is only two in this case), in order to more easily facilitate the
transition to a helper function.
2024-01-24 09:09:56 -08:00
Xida Ren (Cedar) ccaac85788
implement aten.conv1d, aten.conv3d, and aten.conv_tbc (#2757)
convolution with [time,batch,channel] ordering, as opposed to the
default [batch, channel, time]. Currently implementing by transposing
the input and output, but may need to get its own implementation in the
future because this is supposed to be an op that gives a speedup. This
is used by fairseq
(https://github.com/facebookresearch/fairseq/issues/172).

(in case you were wondering like me, this is different from transposed
convolution. Transposed convolution has fractional strides).

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
Co-authored-by: Frederik Harwath <frederik.harwath@amd.com>
2024-01-23 21:30:03 -08:00
Ramiro Leal-Cavazos 5883ef0f21
Fix unused variable warnings (#2775) 2024-01-22 11:05:55 -08:00
Franz Haniel b9806cfa38
[TorchToLinalg] Add lowering for torch.aten.diagonal (#2632) 2024-01-22 12:47:13 -05:00
James Newling 50ac3b1912
g++ build fix (#2778)
Introduced in 704cfdaf08 of @wu-s-john 

g++ compiler error: 

Pooling.cpp:177:13: error: explicit specialization in non-namespace
scope ‘class

Design looks good, g++ is just freaking out for no good reason.
Un-nesting the template classes fixes the error.

We don't have g++ CI. This hopefully happens infrequently enough that we
can just fix manually. My service to those folks who really like
building with g++... :)
2024-01-19 19:12:29 -08:00
John Wu 704cfdaf08
Add aten.pool_max3d support to torch-to-linalg (#2735)
Added verification logic to the abstract_interpreter_lib_gen.py

Also made some unit tests

Initially, I thought we can use `linalg::pooling_ndhwc_max` to help
implement this problem. However, on a 5-dimensional matrix it does the
pooling on dimensions (2, 3, 4) which is not what we want. We want
pooling on dimensions (3, 4, 5).

To achieve this, we would need to lower our code using the `linalg`
dialect.


Turns out the pooling code in `linalg` looks like this.

```
func @max_pooling_ncdhw(%I: memref<?x?x?x?x?xf32>, %K: memref<3xindex>, %O: memref<?x?x?x?x?xf32>,
                        %strides: memref<3xindex>, %dilations: memref<3xindex>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %N = memref.dim %I, %c0 : memref<?x?x?x?x?xf32>
    %C = memref.dim %I, %c1 : memref<?x?x?x?x?xf32>
    %D = memref.dim %I, 2 : memref<?x?x?x?x?xf32>
    %H = memref.dim %I, 3 : memref<?x?x?x?x?xf32>
    %W = memref.dim %I, 4 : memref<?x?x?x?x?xf32>

    %kernel_d = memref.load %K[%c0] : memref<3xindex>
    %kernel_h = memref.load %K[%c1] : memref<3xindex>
    %kernel_w = memref.load %K[2] : memref<3xindex>
    %stride_d = memref.load %strides[%c0] : memref<3xindex>
    %stride_h = memref.load %strides[%c1] : memref<3xindex>
    %stride_w = memref.load %strides[2] : memref<3xindex>
    %dilation_d = memref.load %dilations[%c0] : memref<3xindex>
    %dilation_h = memref.load %dilations[%c1] : memref<3xindex>
    %dilation_w = memref.load %dilations[2] : memref<3xindex>

    linalg.generic {
        indexing_maps = [
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d * %stride_d + kd * %dilation_d, h * %stride_h + kh * %dilation_h, w * %stride_w + kw * %dilation_w)>,  // Map for input tensor
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (kd, kh, kw)>,                                              // Map for kernel tensor
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d, h, w)>                                            // Map for output tensor
        ],
        iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"],
        doc = "3D Max Pooling NCDHW with Strides, Dilations, and Kernel Size"
    } ins(%I, %K : memref<?x?x?x?x?xf32>, memref<3xindex>) outs(%O : memref<?x?x?x?x?xf32>) {
        ^bb0(%input_elem: f32, %kernel_elem: index, %output_elem: f32):
            %max_val = arith.maxf %input_elem, %output_elem : f32
            linalg.yield %max_val : f32
    }
    return
}

```

This was implemented based on it's source code with the adjustments
mentioned above:

4ca1b5e094/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (L5647)

Issues related to this can be found here

https://github.com/nod-ai/SHARK-Turbine/issues/324
2024-01-19 21:09:46 +05:30
Ilija Kalinić faa4517e83
Implement lowering of torch.aten.remainder.Tensor (#2763)
Closes nod-ai/SHARK-Turbine#349
2024-01-19 18:09:08 +05:30
lisaliu1 09421b1cf3
[TorchToLinalg] Add lowering for aten.replication_pad2d (#2715)
Co-authored-by: Lisa Liu <lingl@xilinx.com>
2024-01-15 14:02:27 -05:00