Commit Graph

781 Commits (9e2fe47c5dacfe935ccb3292338c767358305e07)

Author SHA1 Message Date
Yuanqiang Liu fab2696489
[Torch] support aten.trunc (#3219)
decompose `trunc(x)` to `sign(x) * floor(abs(x))`
2024-04-24 14:32:33 +08:00
Yuanqiang Liu dc470e65c8
add torch.qint32 to dtype-spec in TorchTypes.td (#3206) 2024-04-24 11:49:26 +08:00
Xinyu Yang 4da3d714cc
[Torch] Support AtenProdOp on linalg and stablehlo (#3215) 2024-04-24 11:14:04 +08:00
Vivek Khandelwal 3c252cdd44
[onnx] Add `onnx-to-torch` lowering for random ops (#3193)
This commit adds the OnnxToTorch lowering for Onnx's RandomNormal, RandomNormalLike, RandomUniform, and RandomUniformLike op.
2024-04-22 22:28:07 +05:30
Vivek Khandelwal 6abc7371c8
[MLIR][TORCH] Fix OnnxToLinalg lowering issue for Squeeze and Unsqueeze op (#2991)
This commit also cleans up the OnnxToTorch lowering for the Squeeze and
Unsqueeze op and adds the support for handling edge cases.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-04-22 08:52:42 +00:00
penguin_wwy e5bdd71baf
[Torch] Emit and decompose prims.iota op (#3132) 2024-04-21 19:45:01 -07:00
Xinyu Yang 790a697245
[Torch] Add folder for AtenIntOp, AtenFloatOp (#3189)
See unit test below:
```
// CHECK-LABEL:   func.func @torch.aten.tensor.float(
// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %float1.000000e01 = torch.constant.float 1.000000e+01
  %67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32>
  return %67 : !torch.vtensor<[],f32>
}

// CHECK-LABEL:   func.func @torch.aten.tensor.int(
// CHECK-NEXT: torch.vtensor.literal(dense<45> : tensor<si32>) : !torch.vtensor<[],si32>
func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> {
  %none = torch.constant.none
  %false = torch.constant.bool false 
  %int45 = torch.constant.int 45
  %67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>
  return %67 : !torch.vtensor<[],si32>
}

```
2024-04-19 22:17:06 +08:00
Xinyu Yang d4313eed4a
[Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075)
Decomposition RepeatInterleaveSelfInt with following ops:
```python

def my_repeat_interleave(input, repeats, dim=None):
    if dim is None:
        # Flatten the input and then repeat
        return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
    else:
        # Calculate the shape after repeat
        expanded_shape = list(input.shape)
        expanded_shape[dim] *= repeats
        # Repeat the tensor along the specified dimension
        repeat_shape = [1] * (input.dim() + 1)
        repeat_shape[dim + 1] = repeats
        input = input.unsqueeze(-1)

        # Tile and then reshape
        tiled = torch.tile(input, repeat_shape)
        # Rearrange and reshape
        repeated = tiled.reshape(*expanded_shape)
    return repeated

```

I passed the tests of stablehlo and linalg. When testing onnx, strange
things happened.
In torch-mlir's CI **torch_nightly** and my own
environment(torch==2.4.0.dev20240318+cpu), it can **pass the pass**.
In torch-mlir's CI  **torch_stable**, it **failed**.
The test case is `RepeatInterleaveSelfIntNoDimModule_basic`, the result
shape should be [120].
```python
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float32, True),
    ])
    def forward(self, x):
        return x.repeat_interleave(2)


@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
  Unexpected outcome summary: (onnx)
  
  ****** Failed tests - 1 tests
      FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
          @ trace item #0 - call to "forward"
          @ output of call to "forward"
          ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```

@rsuderman 
Would you please help me check what's wrong with my PR? Thanks a lot.
2024-04-18 06:27:51 +08:00
Xinyu Yang d2ba956e69
[Torch] Support Aten_CastLongOp. (#3160)
By canonicalize Aten_CastLongOp into AtenToDtypeOp
2024-04-17 21:58:32 +08:00
IanWood1 5708ee7ec9
Added 2 Ops: Floor divide scalar and Floor divide scalar mode (#3156)
- Added linalg lowering for `AtenFloorDivideScalarOp`
  - Needed `AtenDivScalarModeOp` for the decomp.
- Added linalg lowering for `AtenDivScalarModeOp`
- Moved linalg payload logic to `createDivModePayload()` since the logic
was nearly identical for both `AtenDivScalarModeOp` and
`AtenDivTensorModeOp`. Just a template function
 -  Added `AtenDivScalarModeOp` lowering for stablehlo
 

Pytorch's
[`torch.floor_divide()`](https://pytorch.org/docs/stable/generated/torch.floor_divide.html)
in a previous version (for a reason unknown to me) preformed a
truncation instead of "floor". The already implemented op
`AtenFloorDivideTensorOp` was done before this change. However, this
wasn't caught because our testcases only tested positive floor division.
I changed this to floor as well as adding a few test cases.
2024-04-15 13:45:10 -07:00
penguin_wwy d4a30b7e67
Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130)
We should prefer functional style as the method style is deprecated
https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecated
(https://mlir.llvm.org/deprecation/)
2024-04-11 06:47:35 -07:00
Xinyu Yang 42a16fa912
[Torch] Support Aten_CastFloatOp. (#3115)
By canonicalize Aten_CastFloatOp into AtenToDtypeOp
2024-04-09 11:06:53 +08:00
Xida Ren (Cedar) dd967eb199
[ONNX] Support onnx.LSTM (#2969)
This PR only performs a lit test. In lieu of an e2e test, https://github.com/nod-ai/SHARK-TestSuite/pull/142 makede sure that the lowering works & the numbers check out.

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-04-08 12:23:33 -07:00
Xinyu Yang 84c24e5771
[Torch] Support Aten__And__ScalarOp (#3114) 2024-04-08 20:24:17 +08:00
Yuanqiang Liu 2c56ef9252
[Torch Dialect] canonicalize aten.sign to aten.sgn (#3112)
* `aten.sign` is a sub-set of `aten.sgn` (`aten.sgn` support complex
type).
2024-04-08 20:05:42 +08:00
Vivek Khandelwal ce7d4f1660
[MLIR][TORCH] Fix Onnx.ReduceSum lowering for failing e2e tests (#3095)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-04-03 09:57:19 +05:30
Rob Suderman f97cd4893f
[torch] Improve shape inference for dynamic shapes (#3091)
Shapes can be processed as tensors to represent the set of dimensions.
As reshapes take a list of scalars this can result in a single dynamic
dimension blocking the adjacent static dimensions.

This pass attempts to de-couple tensor computations related to shapes
and propagate values to better support lowering scalar tensor
computations.
2024-04-02 16:19:57 -07:00
Thomas Dietert 3c33dbd987
[MLIR][Torch] Canonicalize torch.from_i1 and torch.to_i1 (#3067)
When lowering `torch.aten.convolution`, it is expected that the
'transposed' argument is a torch.constant operation. In some cases, the
argument was a `from_i1` operation converting an `arith.constant`
operation into a torch.bool. This is not wrong semantically, but instead
of generalizing the legality of the `torch.aten.convolution` op, we
canonicalize `arith.constant` ops followed by `from_i1` ops to
`torch.bool` ops.

For example:
```
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.convolution'(0x124705b90) {
  %33 = "torch.aten.convolution"(%arg0, %20, %21, %31, %29, %30, %19, %32, %0) : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.vtensor<[1,10,24,24],f32>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : unimplemented: only constant transposed supported.      <-- Resolved by this PR
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : not a supported Scalar to Tensor like op
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : not a supported elementwise op
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : not a supported reduce op
  } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
<stdin>:21:11: error: failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal
    %17 = torch.operator "onnx.Conv"(%arg0, %0, %1) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [5 : si64, 5 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>) -> !torch.vtensor<[1,10,24,24],f32> 
          ^
<stdin>:21:11: note: see current operation: %33 = "torch.aten.convolution"(%arg0, %20, %21, %31, %29, %30, %19, %32, %0) : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.vtensor<[1,10,24,24],f32>
```

Additionally, we require the canonicalization of `to_i1` operating on a
torch.constant bool to an `arith.constant ... : i1` for the e2e tests to
pass successfully.
2024-04-01 14:25:51 -07:00
Stella Laurenzo 6d680ff445
[ods] Allow all tensor returns to be optional. (#3082)
This was found while tracing backwards graphs: the convolution_backwards
op will return None if the first result is not needed. Confirmed by
defining a custom op with a `Tensor` return signature and having its
meta kernel return None.
2024-03-29 23:09:34 -07:00
Yuanqiang Liu 0a581a97a7
[Torch Dialect] enhance aten.int.tensor's canonicalize (#3058)
support fold with literal vtensor.  
change it to canonicalize because this pattern will create new op.
2024-03-27 09:51:58 +08:00
Rob Suderman 14b548f968
[torch] Improve shape inference for `torch-to-linalg` path for reshapes (#3055)
Reshaping tensors depend on directly matching individual dimensions to
their corresponding dim in the `torch.view` reshape dimensions. This
involves decoupling dynamic dimensions from their static counterparts
and support cleanup / canonicalization.
2024-03-26 12:41:40 -07:00
schnkmwt 1fcbfa87ec
Implement linalg lowering of diag_embed torch op (#2885)
This PR adds lowering of diag_embed to linalg dilect.
Tracked in https://github.com/nod-ai/SHARK-Turbine/issues/288

---------

Co-authored-by: sachink <sachink@xilinx.com>
2024-03-22 16:32:50 -07:00
zjgarvey 99b3a5f117
Converts all Adaptive Pooling Ops to Linalg (#2808)
The previous conversions for AtenAdaptiveAvgPool1dOp and
AtenAdaptiveMaxPool2dOp are refactored into a general templated
conversion that works for all of the AtenAdaptive...PoolNdOp's.

New support is added for the following ops:

1. AtenAdaptiveMaxPool1d
2. AtenAdaptiveMaxPool3d
3. AtenAdaptiveAvgPool3d

Support is also provided for passing inputs without batch dimensions.
For example, applying adaptive_avg_pool2d to an input tensor of rank 3.

After [pytorch #118162](https://github.com/pytorch/pytorch/pull/118162)
gets down to torch-mlir, I'll add a test for AdaptiveMaxPool1d with
return_indices (which will pass with that upstream fix).

---------

Co-authored-by: James Newling <james.newling@gmail.com>
2024-03-22 11:05:20 -07:00
Yuanqiang Liu 4282eb9e76
[Torch Dialect] support aten.fake_quantize_per_tensor_affine (#3014) 2024-03-15 08:53:29 +08:00
Yuanqiang Liu 43c6996a31
[Torch Dialect] add folder for aten.ceil and unify patterns of ceil, … (#3010)
…floor, round
2024-03-14 07:41:58 +08:00
ptrifunovic98 524ff99216
Implement lowering of torch.aten.linalg_cross (#2986)
Closes
[nod-ai/SHARK-Turbine#497](https://github.com/nod-ai/SHARK-Turbine/issues/497)
2024-03-13 12:17:22 -07:00
Nithin Meganathan 5ecc1d5c0d
Align softmax accumulation types with Torch's CUDA implementation (#2996) 2024-03-12 15:07:45 -07:00
Yuanqiang Liu 229ca3a9e1
[Torch Dialect] emit aten::mul and add folder (#3007) 2024-03-11 19:59:34 +08:00
Yuanqiang Liu a3fe130f73
[Torch Dialect] emit aten::warn (#3003)
* torch-mlir may not handle `aten.warn`. But it could be handled by
custom users' backend which involves torch-mlir.
2024-03-10 08:29:08 +08:00
Rob Suderman 0723584936
[torch] Add folder for torch.aten.*.Scalar comparisons (#3000)
This folds small version of the tensor-scalar comparison operators as
they are commonly used for shape computations. This includes le, lt, ge,
gt, eq, and ne.
2024-03-08 13:44:00 -08:00
Ze Zhang aa7c9a9653
e2e support aten.linalg_norm to aten.linalg_vector_norm (#2953)
Add e2d support for `aten.linalg_norm` by decompose it to
`aten.linalg_vector_norm`.

Lowering to `aten.linalg_matrix_norm` is still unsupported.

To Test: 

`python -m e2e_testing.main -v`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-03-05 16:31:01 -08:00
Rob Suderman bc0527676b
[torch] Add support for `torch.split_with_sizes` via decompose (#2979)
Convert to individiual slices and tuple together as a list.

---------

Co-authored-by: Scott Todd <scott.todd0@gmail.com>
2024-03-05 15:01:21 -08:00
Rob Suderman 61f0a5facf
[torch] Add an `aten.cat` length-0 canonicalization (#2966)
If an input is length-0 along the dimension of canonicalization we can
remove the tensor from the list
2024-03-01 21:41:12 -08:00
Vivek Khandelwal 579ac8b666
[MLIR][TORCH] Fix OnnxToLinalg lowering issue for sub and sum op (#2954)
This commit adds the support for scalar conversion to byte. 
This commit also fixes the OnnxToLinalg lowering issue for Onnx.Sub and
Onnx.Sum op.
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/466 
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/467

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-02-29 21:48:46 +05:30
Rob Suderman e48fe45886
[onnx] Import `onnx` import to pass remaining tests (#2951)
Finish supporting importing the vast majority of `onnx` operations. This
includes:
- region support
- region value inherentance
- `torch.string` support
- `torch.list` support
- `torch.optional` support
2024-02-28 12:18:02 -08:00
Rob Suderman 6f3d62ab04
[torch] Fix folders and `cat` and `view` torch lowerings (#2963)
A bunch of small fixes are interlinked and trigger crashes if not
addressed as a group. This includes:

- aten view when expand from a rank-0 tensor
- slice folder with negative indices
- `aten._shape_as_tensor` folder on a rank-0 tensor
- `aten.cat` of a tensor with a length-0 tensor
2024-02-28 12:04:52 -08:00
Rob Suderman e30a083aff
[torch] Rework lowering to tm_tensor.scatter to stop serialization (#2940)
We collapsed and broadcasted scatter indices to a single element
version. We should instead upport `tm_tensor.scatter`s support for
multiple indices and the implicitly broadcasted behavior. This avoids
the serialization and materializing a needlessly large indices tensor.
2024-02-27 11:46:57 -08:00
Vivek Khandelwal d81747eadb
[MLIR][TORCH] Extend support for OnnxToLinalg lowering for Dropout and Div op (#2938)
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/451,
https://github.com/nod-ai/SHARK-Turbine/issues/452
2024-02-27 11:02:05 +05:30
ptrifunovic98 c5a1da1910
Implement lowering of torch.aten.norm.Scalar (#2899)
Closes
[nod-ai/SHARK-Turbine#365](https://github.com/nod-ai/SHARK-Turbine/issues/365)
2024-02-26 08:46:56 -08:00
Andreas Falkenberg 55dc8deb92
[torch] GridSample TorchToLinalg lowering (#2883)
Lowers `torch.grid_sample` to the equilvalent `linalg` representation.
2024-02-23 09:14:38 -08:00
Rob Suderman df2aa1a369
[torch] Fixed edge conditions for strided slicing (#2929)
Strided slicing can occur with a negative stride. In these cases we need
to bound end differently. This included removing a function that was
generating bad limits.
2024-02-21 21:28:44 -08:00
Stella Laurenzo 4446fa00d8
Migrate passes in TorchConversion to use FunctionOpInterface. (#2935)
This enables better re-use in downstreams which use different func
implementations and should have no impact on those that don't except in
opt pipelines if using the old form. With interfaces, explicit pipelines
via `--pass-pipeline=` must be used.
2024-02-20 08:54:02 -08:00
Rob Suderman 135c81a416
[torch] Add folder for `prim.NumToTensor.Scalar` (#2921)
Useful for `slice` lowerings that depend on tensors made form scalars.
2024-02-19 11:55:54 -08:00
Rob Suderman e80054a3cc
[torch] Folders for `torch.aten.*.tensor` operators [add, sub, mul] (#2878)
Simple folder for limited size aten tensor operations. This is primarily
useful for shape computation folding as they unfortunately can use
`aten` operators. Add, sub, mul are common examples of these folders.
2024-02-19 10:28:23 -08:00
aldesilv d29157b33f
OnnxToTorch support for onnx.InstanceNormalization op (#2710)
https://github.com/nod-ai/SHARK-Turbine/issues/327
2024-02-19 19:53:48 +05:30
Vivek Khandelwal d6d1a173dc
[MLIR][Torch] Add OnnxToTorch and TorchToLinalg support for trig ops (#2903)
This commit adds the OnnxToTorch lowering for cosh, acosh, asin, asinh,
and atanh op.
This commit also adds the TorchToLinalg lowering for acosh, asin, asinh,
and atanh op.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-02-14 11:58:09 +05:30
Rob Suderman e9cdd6cbc5
[torch] Fix tm_tensor.attention for end-to-end (#2907)
Some operations include a backend matcher for specialized operations. We
map these back to generics so they appropriately match to the high
performance versions. This is done for the attention operation.
2024-02-13 21:18:01 -08:00
Rob Suderman c0f139be0f
[torch] Add `torch.aten.eq.Tensor` comparison folder (#2889)
Added a folded for a equals operator. This allows an equivalent
comparison folder, primarily for when shape computations occur small
size tensor.
2024-02-09 15:02:20 -08:00
Rob Suderman 7d33ba69ac
[torch] Folder for torch.aten.select.int for splat cases (#2890)
If the input or result is a splat value we can just constant fold the
result. This is common for shape computations and can help with shape
inference.
2024-02-09 14:02:54 -08:00
Franz Haniel 4cc62aeb24
Implement trace (#2790)
The lowering decomposes AtenTraceOp into an AtenDiagonalOp followed by
AtenSumOp.

The progress is tracked in
https://github.com/nod-ai/SHARK-Turbine/issues/333.

---------

Co-authored-by: Franz Haniel <franz.haniel@amd.com>
2024-02-09 08:00:24 -08:00