Commit Graph

397 Commits (d49eabb3fce6e24fa19b344096b7c5d61e367115)

Author SHA1 Message Date
Stephen Baione d49eabb3fc
Add Op for `torch.aten.unfold` (#3772)
# Description

Implementation of the op for `torch.aten.unfold`: [TorchToLinalg Op
Support #347](https://github.com/nod-ai/SHARK-ModelDev/issues/849)

Documentation of op can be found here: [PyTorch
Docs](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)

For this op, we apply a sliding window of some `size` along a single
`dimension`, with `step` in between iterations.

`Declaration: aten::unfold(Tensor(a) self, int dimension, int size, int
step) -> Tensor(a)`

The resulting `unfolded` tensor modifies the shape of `dimension` to be
equal to the number of blocks that the sliding windows extracts/inserts,
with an additional dimension of `size` appended (the number of cols of
the output tensor directly translates from the size of the sliding
window).

So if we had a tensor of rank 3 (A x B x C), with dimension = 1, size =
2 and step = 2:

    (A x B x C) |=> (A x (B - size) // step + 1 x C x size)

After extracting the window from the input tensor, we insert the (1 x
size) slice into the output tensor. We can make this simpler by mapping
the output indices from the input indices, like they do in the official
implementation:

[PyTorch
Code](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/lowering.py#L1694)
2024-10-08 21:10:43 +00:00
Vivek Khandelwal 614fcdd153
[MLIR][TORCH] Add support for 1-d group convolution (#3770)
This commit adds the support for the 1-d depthwise convolution as a
special case of 1-d group convolution.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-10-08 10:48:47 +05:30
Justin Ngo b08d08682f
[TOSA] Add legalization for fill, flip, and round (#3768)
- Add Torch to TOSA lowering for aten.fill.Scalar/Tensor, aten.flip, and
aten.round
- Fix torchScalarToTosaTensor function to correctly convert Torch scalar
input to TOSA tensor
- Update xfail_sets.py with new e2e results
- Update basic.mlir with LIT tests for new ops


Change-Id: If1e42c2e582710dd8ad0465eed29806fbcdbde41

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-10-07 10:28:26 -07:00
Chi_Liu f4840ed886
[ONNX] Fix onnx.ScatterElements with AtenScatterReduceTwoOp lowering to tm_tensor/linalg_ext dialect (#3754)
- To fix issue onnx.ScatterElements: https://github.com/nod-ai/SHARK-ModelDev/issues/823
- E2E test: https://github.com/nod-ai/SHARK-TestSuite/pull/363
2024-10-05 22:22:41 -07:00
Justin Ngo e9ed4af9ce
[TOSA] Add legalization for aten.index_select (#3760)
- Add Torch to TOSA legalization for aten.index_select
- Fix createOneDimTfIndices function in TosaLegalizeCommon.cpp to
correctly convert Torch indices to TF-style indices, which is used in
convertGatherNdOp
- Update e2e tests in xfail_sets.py
- Update basic.mlir with new LIT test for aten.index_select

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I52519246183949353a3cf22f0a685fe3df8ec8ff

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-10-04 12:24:22 -07:00
Rob Suderman 2374b9e02d
Bump to llvm/llvm-project@e813750354 (#3765)
Includes stablehlo bump
2024-10-04 12:08:35 -07:00
zjgarvey f08bfc4ff8
[ONNX] simplify shapes fed to broadcast in Expand lowering (#3756)
Addresses ~200 onnx model compile failures in
<https://github.com/nod-ai/SHARK-TestSuite> related to
<https://github.com/iree-org/iree/issues/18631>.

This change simplifies the result of the generated broadcast op
substantially, but reduces the case coverage slightly.

The case which will become unsupported: 
- trying to actually broadcast a dynamic dim that is secretly 1. 

When does this case appear in practical scenarios?
- for a model where onnx shape inference cannot figure out that a dim
should be 1.

Why do I think we should not support this case for now?
1. For all models with dynamic dim expand ops, the previous path
uniformly generates uglier linalg IR (making it harder for IREE to fuse
properly with other ops).
2. For models failing shape inference castastrophically enough to fail
to see a dim is statically 1, we can try to apply constant folding in
the onnx model before importing.

Leaving this as a draft PR, since it may be more appropriate to fix the
compilation failure in IREE rather than torch-mlir.

### Example of broadcast required in previous path:

```mlir
    %300 = linalg.generic {indexing_maps = [#map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%299 : tensor<?x12x?x?xi1>) {
    ^bb0(%out: i1):
      %306 = linalg.index 0 : index
      %307 = linalg.index 3 : index
      %308 = arith.index_cast %285 : i64 to index
      %309 = arith.cmpi eq, %308, %c1 : index
      %310 = arith.select %309, %c0, %306 : index
      %311 = arith.index_cast %286 : i64 to index
      %312 = arith.cmpi eq, %311, %c1 : index
      %313 = arith.select %312, %c0, %307 : index
      %extracted_79 = tensor.extract %reshape_78[%310, %c0, %c0, %313] : tensor<?x1x1x?xi1>
      linalg.yield %extracted_79 : i1
    } -> tensor<?x12x?x?xi1>
```

### Example of broadcast with simplified shape list:

```mlir
    %409 = linalg.generic {indexing_maps = [#map15, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%reshape_135 : tensor<?x1x1x?xi1>) outs(%408 : tensor<?x12x?x?xi1>) {
    ^bb0(%in: i1, %out: i1):
      linalg.yield %in : i1
    } -> tensor<?x12x?x?xi1>
```
2024-10-03 20:11:51 -05:00
Rob Suderman 9ab0db5789
[torch] `torch.aten.complex` operation with lowering (#3738)
Add the operation with lowering to linalg. Includes a test for
end-to-end correctness.
2024-10-03 11:09:52 -07:00
Justin Ngo 5eab669c4a
[TOSA] Add legalization for aten.diagonal (#3740)
- Add lowering from Torch to TOSA for aten.diagonal
- Clean up some code
- Update xfail_sets.py with the new e2e results
- Update basic_mlir with the new op mlir test

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I99bed685455752d09ed96edd837c4dfbee152701

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-30 08:24:31 -07:00
Yuanqiang Liu 5f74de5ba0
[Stablehlo] support aten.all.dim (#3746) 2024-09-30 15:59:27 +08:00
yyp0 eb4e59e189
[Torch] support binary_cross_entropy_with_logits decomposition (#3741) 2024-09-29 17:41:20 +08:00
Xida Ren (Cedar) 9938abf25e
AtenCumprodOp (#3737) 2024-09-26 18:17:22 -04:00
yyp0 335cf5f6d0
[stablehlo] support aten_adaptive_max_pool1d lowering (#3728) 2024-09-26 11:42:38 +08:00
Justin Ngo 3f79a2982a
[TOSA] Extend Torch to TOSA legalization coverage (#3718)
- Add Torch to TOSA legalization for the following ops:
  + aten.logical_not
  + aten.logical_xor
  + aten.cos
  + aten.sin
  + aten.pow.Scalar
  + aten.pow.Tensor_Tensor
  + aten.erf
  + aten.bitwise_and.Scalar
  + aten.bitwise_left_shift.Tensor
  + aten.bitwise_right_shift.Tensor
  + aten.le.Tensor
  + aten.le.Scalar
- Update e2e tests in xfail_sets
- Update basic.mlir with newly legalized ops

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I4aa5790073ef2e5ec0e9b374da42887242f8dabc

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-20 14:33:55 -07:00
Justin Ngo abaff58c6d
[TOSA] Add div rounding mode, remainder, fmod, and ge.Tensor ops support (#3717)
- Add legalization for aten.div rounding mode:
  + trunc: rounds division results towards zero
  + floor: rounds division results down
- Add legalization for aten.remainder.Scalar and aten.fmod ops
- Add legalization for aten.ge.Tensor op
- Update e2e tests in xfail_sets.py
- Update basic.mlir with new legalized ops

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: Icedd23205254fb893ce6f3de08956772b83b4320

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-20 13:34:09 -07:00
Rob Suderman 5ce48dfacd
[torch] Fix attention on linalg for dynamic shapes (#3714)
Current version does not work for a mixture of dynamic and static shaped
batch dimensions. Rework to grab the correct dynamic shapes.

---------

Co-authored-by: dan <danimal197@gmail.com>
2024-09-18 14:52:54 -05:00
Vivek Khandelwal 3f46348e8e
build: manually update PyTorch version (#3715)
Set PyTorch and TorchVision version to nightly release 2024-09-16.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-09-18 12:00:15 +05:30
justin-ngo-arm 14ef05a292
[TOSA] Extend Torch to TOSA reduction ops legalization (#3710)
- Add Torch to TOSA legalization for the following reduction ops:
  + aten.min.dim
  + aten.min
  + aten.max
  + aten.prod
  + aten.prod.dim_int
  + aten.all.dim
- Add dtype casting support for reduce sum and prod ops
- Extend aten.max.dim legalization to a template to support aten.min.dim
legalization
- Update end-to-end tests sets in xfail_sets.py

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I854dd6c0c55e570c1fb7242f20c85cf64d6e7fe0

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-16 12:40:24 -07:00
Yuanqiang Liu 7b94ced39a
[Stablehlo] fix aten compare ops' promote rules (#3709)
previous PR(https://github.com/llvm/torch-mlir/pull/3702)
2024-09-13 18:48:41 +08:00
zjgarvey d61986cfcf
Add Decompostion for `Aten_SafeSoftmaxOp` (#3708)
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-09-12 16:58:10 -05:00
yyp0 edf725ef42
[Torch] add AtenAsStridedOp in torch dialect (#3706) 2024-09-12 19:07:11 +08:00
Branko Trifkovic 1c4b9d6a0e
Implement lowering of torch.aten.hstack (#3563) 2024-09-11 16:41:47 +05:30
penguin_wwy 04740824ae
[ci] enable fx_importer2stablehlo ci test (#3698) 2024-09-11 09:53:23 +08:00
rohan-tan-bhowmik e86f56bc76
[Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690)
Enabled mask and is_causal parameters for torch.aten.scaled_dot_product
attention + relevant comments + tests.

The tests added highlight the new capabilities introduced in this PR,
including:

Attention with F16 mask
Attention with Boolean mask
Causal attention with same Q K V shapes
Causal attention without Q K V shapes

Made sure that one cannot input both mask and is_causal.
2024-09-09 15:51:41 -07:00
Srinath Avadhanula 0a788e0467
Decompose aten.fmod into aten.mul,sub,div etc. (#3689)
As titled, create a new decomposition for `aten.fmod.Tensor` to
`aten.div`, `aten.trunc`, `aten.mul` and `aten.sub`. Note that we only
use `aten.trunc` for floating point operations. This further gets
decomposed to `aten.where` etc. by other existing decompositions.

This decomposition now makes TOSA pass for a simple model with
`aten.fmod` while it makes `stablehlo` fail. For now, we disallow this
decomposition for `stablehlo`

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
2024-09-09 09:00:11 -07:00
Branko Trifkovic 70d5730c87
[LINALG] Implement lowering of torch.aten.rot90 (#3551) 2024-09-06 10:36:17 +05:30
justin-ngo-arm d4b5e05ac1
[TOSA] Add Torch to Tosa Legalization for torch.tril (#3678)
Change-Id: Ie5ba31a27394c3adcea00266a9d562862dbd8b08

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-05 11:27:29 -07:00
zjgarvey 295bf418a4
Add a canonicalization pattern for `aten.unflatten.int` (#3656)
Addresses an issue in <https://github.com/llvm/torch-mlir/issues/3651>
where some unflatten ops generated from onnx models weren't propagating
static shape information. It may be necessary to add further
optimizations for the more general case when some static information is
present in the unflatten (or possibly reshape/view) op's `sizes` list,
but not reflected in the output shape. These ops will only successfully
infer shapes if the `sizes` list is gotten from a list of constant ints
(with possibly one -1). A common example where this fails is when some
of the `sizes` are determined from `aten.size.int` ops on dynamic
tensors, and other `sizes` are known statically.

This PR includes:
- a canonicalizer for `aten.unflatten.int` which converts to
`aten.unsqueeze` when it is expanding one dim to two, and one of the new
dims is statically 1.
- an improvement to the folder for `aten.__or__.bool` which does not
rely on *both* operands being static.
2024-09-03 16:38:20 -07:00
Ze Zhang b3942ff984
Add canonicalize pattern for aten.mul.int and aten.floordiv.int (#3680)
This PR add `floordiv` to the `PY_BUILTIN_TO_TORCH_OP`. For
`aten.mul.int` and `aten.floordiv.int` ops, we add new Canonicalization
Patterns as follow:

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.mul.int %1, %const-6
```

Will be replaced by

`torch.aten.mul.int %input, %const-30`


And 

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.floordiv.int %1, %const-5
```
Will directly return `%input`


This PR also relaxes the `float` type constraint in TorchToTosa for the
`AtenRsubScalarOp` conversion.



To test:

`cmake --build build --target check-torch-mlir-all`
2024-09-03 09:13:59 -07:00
Vivek Khandelwal 567ed44fd0
[MLIR][TORCH] Add E2E support for aten.polar op (#3671)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-09-03 10:51:03 +05:30
lingzhiz1998 5bc59ce1fa
[TorchToLinalg] Support lowering MaxPool3dWithIndices (#3652)
Support torch.MaxPool3dWithIndices lowering to linalg backend.
2024-08-27 14:14:25 -05:00
Vivek Khandelwal b92e61832f
build: manually update PyTorch version (#3666)
Set PyTorch and TorchVision version to nightly release 2024-08-25.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-08-27 09:28:30 -07:00
penguin_wwy 6eba5bc9ee
[Torch] Extract TensorPlaceholder to a common interface (#3668) 2024-08-27 23:31:28 +08:00
Rob Suderman 9a4c8c606c
[torch] Add `torch.aten.view.dtype` to op list (#3664)
Support dtype conversion between types. This is useful for bitcasting
buffers between differing bit depths.
2024-08-23 19:02:53 -07:00
Xida Ren (Cedar) 4358aaccd6
Add per-test timeouts to catch infinite loops (#3650)
Previously we only had full suite timeouts, making it impossible to
identify
which specific tests were hanging. This patch adds:

1. Per-test timeout support in the test framework
2. A default 600s timeout for all tests
3. A deliberately slow test to verify the timeout mechanism works

The timeout is implemented using Python's signal module. Tests that
exceed
their timeout are marked as failures with an appropriate error message.

This should help catch and isolate problematic tests that enter infinite
loops, without needing to re-run the entire suite multiple times.
2024-08-21 11:37:31 -07:00
lingzhiz1998 7f886cc270
[TorchToLinalg] Support torch.isclose lower to linalg (#3631) 2024-08-21 11:55:54 +08:00
zjgarvey f66908f190
[TorchToLinalg] address a dtype mismatch in `aten.multinomial` lowering (#3630)
Resolves <https://github.com/llvm/torch-mlir/issues/3628>
Unblocks a compile failure for one of the MiGraphx models
(`AgentModel`).
2024-08-20 15:14:48 -05:00
Vivek Khandelwal 0a86deb59a
build: manually update PyTorch version (#3627)
Set PyTorch and TorchVision version to nightly release 2024-08-18.
This commit also updates the `scaled_dot_product_attention` op. 
A new attribute `enable_gqa` has been added. As of now, only the
default value for the same is supported.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-08-19 12:03:56 +05:30
yyp0 43e3118eb9
[Stablehlo] use stablehlo specs lowering AtenSliceScatterOp (#3592) 2024-08-15 20:06:29 +08:00
pkapris-syrmia 23ec5399e5
Implement lowering of aten.atleast_2d (#3546)
This operator is needed to implement aten.vstack, which will be
submitted in a subsequent PR
2024-08-14 18:52:31 +05:30
Branko Trifkovic da877a781e
Added support for integer to complex conversion (#3604) 2024-08-14 18:13:00 +05:30
pkapris-syrmia 10fe5d08d1
Implement lowering for torch.aten.rad2deg (#3586) 2024-08-14 16:37:28 +05:30
Rob Suderman 9ab93436c4
[torch] Support diagonal `einsum.Diagonal` (#3618)
The einsum lowering was missing the behavior for duplicate indices in
the equation. This amounts to a diagonalization along duplicate pairs of
indices in the equation.
2024-08-13 09:38:43 -07:00
pkapris-syrmia d11d6f6fea
[TorchToLinalg] Fix torch.aten.remainder for negative operands (#3581)
Closes #3575

The PyTorch remainder operator is meant to compute the Python modulus
operator entrywise:

https://pytorch.org/docs/stable/generated/torch.remainder.html#torch.remainder

In python the modulus operator is meant to always return a result with
the same sign as the divisor:

https://docs.python.org/3/reference/expressions.html#binary-arithmetic-operations

In other words, torch.aten.remainder should return a Python-style
modulus instead of a C-style modulus. However the remainder operator was
simply translated into arith.ModSI or arith.ModF, which both effectively
compute the C-style modulus. Now the lowering has been modified so that
the modulus operator works properly with negative numbers, both in the
dividend, and the divisor.
2024-08-13 21:17:21 +05:30
Yuanqiang Liu c5b3cf299a
[Torch] emit upsample_nearest1d/2d/vec, and add shape/dtype functions (#3629) 2024-08-13 19:14:24 +08:00
Matthias Gehre 334633b738
e2e: Enable generate-runtime-verification pass (#3615)
This adds the `generate-runtime-verification` pass into the linalg
refbackend, and moves all tests that now abort at runtime into the crash
set, sorted by their respective errors.

I have fixed on set of errors found that way, which are mismatches
between the static dimensions we cast to and the actual dynamic
dimensions. This was caused by wrong annotations on the test cases, like
in
https://github.com/llvm/torch-mlir/pull/3615/files#diff-48bfbf41fcad5fa01b49197d251114f84a2b8de4f1d87ab938a061aedd1419b1R1931
2024-08-12 14:15:12 +02:00
Felix Schneider 0314188dbe
[torch] Basic support for per-channel quantized graphs (#3623)
This patch adds basic support for lowering graphs with per-channel
quantization. Per-channel quantized ops have to be excluded from
`FuseQuantizedOps` for now but can be used in QDQ quantized form.

Using this patch, we're able to import and execute (on the linalg
backend) graphs with per-channel quantization applied using the "new"
PyTorch 2.0 Export Quantization.
2024-08-10 15:51:09 +02:00
Vivek Khandelwal f91f816336
Bump llvm to 585523750e2bbe374d1cb3bf4ff9d53de29b9593 (#3613)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-08-09 00:36:10 +08:00
zjgarvey 8d95fe9eeb
[TorchToArith] Add a lowering for `torch.add.float_int` (#3594) 2024-08-07 11:55:27 -05:00
Branko Trifkovic 2d6bfb2dec
[LINALG] Added support for conversion from float to complex. (#3595) 2024-08-07 12:36:48 +05:30