Commit Graph

3218 Commits (a82ba1c42282bff79d7b5f0a1c25601d265bc7f2)
 

Author SHA1 Message Date
Vivek Khandelwal f6721e5999
[MLIR][TORCH] Add support for negative step in aten.slice.Tensor op (#3763)
This commit adds the support for negative step values in
aten.slice.Tensor op. Although, PyTorch does not allow negative step
value for slice op but the Onnx.Slice op supports negative step value
which eventually lowers to torch.aten.slice.Tensor op. Hence, the
support is added for handling those kind of values during the
Torch->Linalg lowering of aten.slice.Tensor op.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-10-08 10:34:27 +05:30
Justin Ngo b08d08682f
[TOSA] Add legalization for fill, flip, and round (#3768)
- Add Torch to TOSA lowering for aten.fill.Scalar/Tensor, aten.flip, and
aten.round
- Fix torchScalarToTosaTensor function to correctly convert Torch scalar
input to TOSA tensor
- Update xfail_sets.py with new e2e results
- Update basic.mlir with LIT tests for new ops


Change-Id: If1e42c2e582710dd8ad0465eed29806fbcdbde41

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-10-07 10:28:26 -07:00
Chi_Liu f4840ed886
[ONNX] Fix onnx.ScatterElements with AtenScatterReduceTwoOp lowering to tm_tensor/linalg_ext dialect (#3754)
- To fix issue onnx.ScatterElements: https://github.com/nod-ai/SHARK-ModelDev/issues/823
- E2E test: https://github.com/nod-ai/SHARK-TestSuite/pull/363
2024-10-05 22:22:41 -07:00
Rob Suderman 53f7532e76
Revert "[TorchToLinalg] perform rank0 elementwise computations outside linalg generic ops (#3762)" (#3767)
Reverted due to downstream model changes. Will reland with fixes post
integration.

This reverts commit 6e8c7bed4b.
2024-10-04 14:48:02 -07:00
Justin Ngo e9ed4af9ce
[TOSA] Add legalization for aten.index_select (#3760)
- Add Torch to TOSA legalization for aten.index_select
- Fix createOneDimTfIndices function in TosaLegalizeCommon.cpp to
correctly convert Torch indices to TF-style indices, which is used in
convertGatherNdOp
- Update e2e tests in xfail_sets.py
- Update basic.mlir with new LIT test for aten.index_select

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I52519246183949353a3cf22f0a685fe3df8ec8ff

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-10-04 12:24:22 -07:00
Rob Suderman 2374b9e02d
Bump to llvm/llvm-project@e813750354 (#3765)
Includes stablehlo bump
2024-10-04 12:08:35 -07:00
zjgarvey 6e8c7bed4b
[TorchToLinalg] perform rank0 elementwise computations outside linalg generic ops (#3762)
This is motivated by the fact that shapes are stored as tensors in ONNX,
and IREE tries to perform tensor arithmetic on the device. This causes
unnecessary dispatches, and makes it harder for the compiler to reason
about shapes.

Here is a small snippet of torch-IR that is typical seen coming from
ONNX models:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,768],f32>, %arg1: !torch.vtensor<[?,?,768],f32>) -> !torch.vtensor<[],si64> {
    %int0 = torch.constant.int 0
    %0 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
    %1 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,768],f32> -> !torch.vtensor<[3],si64>
    %2 = torch.aten.index_select %1, %int0, %0 : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
    %3 = torch.aten.squeeze.dim %2, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %4 = torch.aten.item %3 : !torch.vtensor<[],si64> -> !torch.int
    %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool
    %6 = torch.aten.Int.bool %5 : !torch.bool -> !torch.int
    %7 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
    %8 = torch.prim.NumToTensor.Scalar %6 : !torch.int -> !torch.vtensor<[],i1>
    %9 = torch.prim.NumToTensor.Scalar %7 : !torch.int -> !torch.vtensor<[],si64>
    %10 = torch.prim.NumToTensor.Scalar %4 : !torch.int -> !torch.vtensor<[],si64>
    %11 = torch.aten.where.self %8, %9, %10 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
    return %11 : !torch.vtensor<[],si64>
  }
}
```

Without the change in this PR, the result would be:

```mlir
#map = affine_map<() -> ()>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = tensor.empty() : tensor<i1>
    %6 = linalg.fill ins(%3 : i1) outs(%5 : tensor<i1>) -> tensor<i1>
    %7 = tensor.empty() : tensor<i64>
    %8 = linalg.fill ins(%4 : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %9 = linalg.fill ins(%extracted : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %10 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%6, %8, %9 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%7 : tensor<i64>) {
    ^bb0(%in: i1, %in_1: i64, %in_2: i64, %out: i64):
      %11 = arith.select %in, %in_1, %in_2 : i64
      linalg.yield %11 : i64
    } -> tensor<i64>
    return %10 : tensor<i64>
  }
}
```

With the change in this PR, we would instead get:

```mlir
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = arith.select %3, %4, %extracted : i64
    %6 = tensor.empty() : tensor<i64>
    %7 = linalg.fill ins(%5 : i64) outs(%6 : tensor<i64>) -> tensor<i64>
    return %7 : tensor<i64>
  }
}
```

Some related issues for context:
1. <https://github.com/iree-org/iree/issues/18677>
2. <https://github.com/iree-org/iree/issues/18631>
2024-10-04 11:27:00 -05:00
zjgarvey f08bfc4ff8
[ONNX] simplify shapes fed to broadcast in Expand lowering (#3756)
Addresses ~200 onnx model compile failures in
<https://github.com/nod-ai/SHARK-TestSuite> related to
<https://github.com/iree-org/iree/issues/18631>.

This change simplifies the result of the generated broadcast op
substantially, but reduces the case coverage slightly.

The case which will become unsupported: 
- trying to actually broadcast a dynamic dim that is secretly 1. 

When does this case appear in practical scenarios?
- for a model where onnx shape inference cannot figure out that a dim
should be 1.

Why do I think we should not support this case for now?
1. For all models with dynamic dim expand ops, the previous path
uniformly generates uglier linalg IR (making it harder for IREE to fuse
properly with other ops).
2. For models failing shape inference castastrophically enough to fail
to see a dim is statically 1, we can try to apply constant folding in
the onnx model before importing.

Leaving this as a draft PR, since it may be more appropriate to fix the
compilation failure in IREE rather than torch-mlir.

### Example of broadcast required in previous path:

```mlir
    %300 = linalg.generic {indexing_maps = [#map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%299 : tensor<?x12x?x?xi1>) {
    ^bb0(%out: i1):
      %306 = linalg.index 0 : index
      %307 = linalg.index 3 : index
      %308 = arith.index_cast %285 : i64 to index
      %309 = arith.cmpi eq, %308, %c1 : index
      %310 = arith.select %309, %c0, %306 : index
      %311 = arith.index_cast %286 : i64 to index
      %312 = arith.cmpi eq, %311, %c1 : index
      %313 = arith.select %312, %c0, %307 : index
      %extracted_79 = tensor.extract %reshape_78[%310, %c0, %c0, %313] : tensor<?x1x1x?xi1>
      linalg.yield %extracted_79 : i1
    } -> tensor<?x12x?x?xi1>
```

### Example of broadcast with simplified shape list:

```mlir
    %409 = linalg.generic {indexing_maps = [#map15, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%reshape_135 : tensor<?x1x1x?xi1>) outs(%408 : tensor<?x12x?x?xi1>) {
    ^bb0(%in: i1, %out: i1):
      linalg.yield %in : i1
    } -> tensor<?x12x?x?xi1>
```
2024-10-03 20:11:51 -05:00
Rob Suderman 9ab0db5789
[torch] `torch.aten.complex` operation with lowering (#3738)
Add the operation with lowering to linalg. Includes a test for
end-to-end correctness.
2024-10-03 11:09:52 -07:00
Kyle Wang f0b7ca72f5
Fixed GRU quality issues exposed by e2e tests (#3753)
Issue: https://github.com/nod-ai/SHARK-ModelDev/issues/856

Related tests:
![Screenshot 2024-10-01
175305](https://github.com/user-attachments/assets/0dc0901b-058f-427c-a596-9e806fd38836)
2024-10-02 17:00:19 -04:00
Sambhav Jain f8e4a9a3c2
[Release] Fix binary name for downstream compatibility (#3752)
As of Sep 14, the torch-mlir binary
[wheels](https://github.com/llvm/torch-mlir-release/releases/tag/dev-wheels)
got renamed to `torch-mlir-core` from `torch-mlir`:
![image](https://github.com/user-attachments/assets/152e4977-71ef-4f57-8757-6dc75f72b670)

This was an unintended side-effect of the recent change of
`TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=True`
(https://github.com/llvm/torch-mlir/pull/3711) which skips setting `NAME
= "torch-mlir"` in
[setup.py](https://github.com/llvm/torch-mlir/blob/main/setup.py#L226-L232).

To avoid having multiple downstreams fix their pip deps, this change
allows using the same `torch-mlir` name for binaries, and reserves a
separate `torch-mlir-ext` name for the (less popular) binaries with
extensions enabled.
2024-10-02 11:52:20 -07:00
Samu Tamminen a2bfe47faa
[onnx] Add IDF and TFIDF modes to TFIDF Vectorizer (#3726)
Address https://github.com/nod-ai/SHARK-Turbine/issues/833
2024-10-02 08:17:58 -05:00
Prathamesh Tagore 617c1c76ce
[torch.bind_symbolic_shape] Fix verifier for shapeSymbol detection (#3751)
The op can be valid with no attached shape symbols if they are not
required by the corresponding affine map. Fix the verifier to consider
number of arguments for both.
2024-10-02 05:55:54 -07:00
Marius Brehler b1413a6c7f
Update instructions on creating a virtual env (#3724)
The `python` command is only available on Ubuntu if the
`python-is-python3` package is installed, see
https://packages.ubuntu.com/jammy/python-is-python3 and
https://packages.ubuntu.com/jammy/all/python-is-python3/filelist. As
Python 2 isn't supported anyway, it's safe to point to `python3` here
instead.
2024-10-01 19:12:11 +02:00
Justin Ngo 5eab669c4a
[TOSA] Add legalization for aten.diagonal (#3740)
- Add lowering from Torch to TOSA for aten.diagonal
- Clean up some code
- Update xfail_sets.py with the new e2e results
- Update basic_mlir with the new op mlir test

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I99bed685455752d09ed96edd837c4dfbee152701

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-30 08:24:31 -07:00
Yuanqiang Liu 5f74de5ba0
[Stablehlo] support aten.all.dim (#3746) 2024-09-30 15:59:27 +08:00
yyp0 eb4e59e189
[Torch] support binary_cross_entropy_with_logits decomposition (#3741) 2024-09-29 17:41:20 +08:00
jinchen a33d1232c5
[onnx] Fix onnx.Shape lowering with scalar input (#3716)
Address https://github.com/nod-ai/SHARK-Turbine/issues/826
2024-09-27 13:30:02 -07:00
Xida Ren (Cedar) 9938abf25e
AtenCumprodOp (#3737) 2024-09-26 18:17:22 -04:00
yyp0 335cf5f6d0
[stablehlo] support aten_adaptive_max_pool1d lowering (#3728) 2024-09-26 11:42:38 +08:00
Xida Ren (Cedar) aa7e77ee64
Better errmsg upon getScalarTypeForType failure (#3734)
Instead of 
`Unhandled type in getScalarTypeForType`

You now get

Unhandled type in getScalarTypeForType: (type name)
Type properties:
  Is integer: yes
  Bit width: 
...


The root cause is https://github.com/llvm/torch-mlir/issues/3720, at
least for unsigned integer issues.
2024-09-25 16:32:26 +00:00
Vinayak Dev 67732883fa
[torch] Fix unsqueezed output shape in canonicalization of AtenUnflattenIntOp (#3730)
Fixes https://github.com/iree-org/iree/issues/18562.

During canonicalization pass on `AtenUnflattenIntOp`, if the second dim
was statically equal to one, we would create an `AtenAddIntOp` to add
one to the dimension obtained from `op.getDim()`. This, when passed into
`Torch::unsqueezeTensor()`, would make it get interpreted as
non-constant, which would lead to MLIR failing an assertion when
`UnsqueezeOp` would later get lowered into `ExpandShapeOp`, as the
output of the `UnsqueezeOp` would consist of only dynamic dims.

This patch fixes this behavior, by extracting the integer value from the
dim if it was constant, and then emitting a `ConstantIntOp` from
(dim+1). This creates an output with static shape.
2024-09-24 11:45:18 -05:00
Marius Brehler e4f2bdf0db
Document requirements for `torch_mlir_e2e_test` (#3722)
This documents which CMake options must be set to be able to use
`torch_mlir_e2e_test`, required e.g. for
`projects/pt1/tools/e2e_test.sh`.

Makes progress on #3696.
Closes #3719.
2024-09-24 00:06:54 +02:00
giacs-epic 99848265c3
[onnx] Relax constraints on input tensors in `onnx.STFT` conversion to torch dialect (#3676)
- When the signal tensor is real, onnx allows its shape to be
`[batch][length]` as well as `[batch][length][1]`.
- Onnx also allows to specify `frame_length` together with `window` (not
empty), given that it matches the window size.
- Adding checks on signal and result shapes.
2024-09-23 12:09:29 +05:30
Justin Ngo 3f79a2982a
[TOSA] Extend Torch to TOSA legalization coverage (#3718)
- Add Torch to TOSA legalization for the following ops:
  + aten.logical_not
  + aten.logical_xor
  + aten.cos
  + aten.sin
  + aten.pow.Scalar
  + aten.pow.Tensor_Tensor
  + aten.erf
  + aten.bitwise_and.Scalar
  + aten.bitwise_left_shift.Tensor
  + aten.bitwise_right_shift.Tensor
  + aten.le.Tensor
  + aten.le.Scalar
- Update e2e tests in xfail_sets
- Update basic.mlir with newly legalized ops

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I4aa5790073ef2e5ec0e9b374da42887242f8dabc

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-20 14:33:55 -07:00
Justin Ngo abaff58c6d
[TOSA] Add div rounding mode, remainder, fmod, and ge.Tensor ops support (#3717)
- Add legalization for aten.div rounding mode:
  + trunc: rounds division results towards zero
  + floor: rounds division results down
- Add legalization for aten.remainder.Scalar and aten.fmod ops
- Add legalization for aten.ge.Tensor op
- Update e2e tests in xfail_sets.py
- Update basic.mlir with new legalized ops

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: Icedd23205254fb893ce6f3de08956772b83b4320

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-20 13:34:09 -07:00
Rob Suderman 5ce48dfacd
[torch] Fix attention on linalg for dynamic shapes (#3714)
Current version does not work for a mixture of dynamic and static shaped
batch dimensions. Rework to grab the correct dynamic shapes.

---------

Co-authored-by: dan <danimal197@gmail.com>
2024-09-18 14:52:54 -05:00
Vivek Khandelwal 3f46348e8e
build: manually update PyTorch version (#3715)
Set PyTorch and TorchVision version to nightly release 2024-09-16.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-09-18 12:00:15 +05:30
zjgarvey d2c387dd04
[ONNX] Fix issue with absent value in onnx.ConstantOfShape (#3713)
Previously, if the value was absent, this conversion was creating a
dense resource of value 0 with shape equal to the result shape, then
later re-extracting a splat value. This only works if the shape is
statically known, and even when the shape is known, this is completely
unnecessary since the value's shape should be `[1]` and not the result
shape.

This patch simply sets the `splatvalue` to a `torch.constant.float 0.0`
when the onnx op's `value` attr is absent, and adds `nullptr` checks to
the subsequent conditionals to avoid them in the case where an `attr` is
not given.

Addresses <https://github.com/nod-ai/SHARK-Turbine/issues/831>.
2024-09-17 16:01:01 -05:00
justin-ngo-arm 14ef05a292
[TOSA] Extend Torch to TOSA reduction ops legalization (#3710)
- Add Torch to TOSA legalization for the following reduction ops:
  + aten.min.dim
  + aten.min
  + aten.max
  + aten.prod
  + aten.prod.dim_int
  + aten.all.dim
- Add dtype casting support for reduce sum and prod ops
- Extend aten.max.dim legalization to a template to support aten.min.dim
legalization
- Update end-to-end tests sets in xfail_sets.py

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I854dd6c0c55e570c1fb7242f20c85cf64d6e7fe0

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-16 12:40:24 -07:00
saienduri d6cf718f10
Redefine TorchMLIRPythonModules to avoid building empty libraries. (#3711)
Trying to build empty libraries causes weird failures based on clang/gcc
and doesn't work with certain versions of python as well. We should
avoid this wherever possible, and this specifically has been leading to
the following issues/failures:

https://github.com/llvm/torch-mlir/issues/3663

https://github.com/llvm/torch-mlir-release/actions/runs/10558518843/job/29248139823
2024-09-13 10:41:34 -07:00
Srinath Avadhanula bc70c50373
Delete unnecessary linalg conversion for aten.fmod (#3707)
Follow up cleanup for [this
PR](https://github.com/llvm/torch-mlir/pull/3689), which introduced a
decomposition for `aten.fmod.Tensor`. This means that the lowering for
this operator in linalg is no longer needed.

Thanks to @vivekkhandelwal1 for pointing this out.

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
2024-09-13 09:39:58 -07:00
Yuanqiang Liu 7b94ced39a
[Stablehlo] fix aten compare ops' promote rules (#3709)
previous PR(https://github.com/llvm/torch-mlir/pull/3702)
2024-09-13 18:48:41 +08:00
zjgarvey d61986cfcf
Add Decompostion for `Aten_SafeSoftmaxOp` (#3708)
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-09-12 16:58:10 -05:00
yyp0 edf725ef42
[Torch] add AtenAsStridedOp in torch dialect (#3706) 2024-09-12 19:07:11 +08:00
Yuanqiang Liu 3f07077ff9
[Torch] enhance fold of aten.alias (#3705) 2024-09-12 17:04:57 +08:00
Rob Suderman bb69014a96
bump llvm/llvm-project@d418a03e01 (#3700)
Forward bump llvm dependency to current head
2024-09-11 11:42:31 -04:00
Branko Trifkovic 1c4b9d6a0e
Implement lowering of torch.aten.hstack (#3563) 2024-09-11 16:41:47 +05:30
penguin_wwy 04740824ae
[ci] enable fx_importer2stablehlo ci test (#3698) 2024-09-11 09:53:23 +08:00
Rob Suderman 6934ab81b0
Bump llvm/llvm-project@b6603e1bf1 (#3697)
Bump forward and refactor inline global slots to no longer track via
symlinks. This appears to make the tests past until we manage to remove
torchscript work.
2024-09-10 08:57:15 -07:00
giacs-epic b35675a78e
[onnx] Add support for `auto_pad` in `onnx.Conv` (#3670)
Add logic for `auto_pad` attribute in the conversion of `onnx.Conv`
torch dialect.
Add lit tests covering different configurations of `auto_pad`.
2024-09-10 20:31:53 +05:30
Vivek Khandelwal b5d95ff399
build: manually update PyTorch version (#3692)
Set PyTorch and TorchVision version to nightly release 2024-09-09.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-09-10 16:02:28 +05:30
Chi_Liu fbb0db17dc
Disable TORCH_MLIR_ENABLE_JIT_IR_IMPORTER and TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS by default (#3693)
Only enable it in CI and debug for update_abstract_interp_lib.sh and update_torch_ods.sh usage.
2024-09-09 22:58:27 -07:00
rohan-tan-bhowmik e86f56bc76
[Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690)
Enabled mask and is_causal parameters for torch.aten.scaled_dot_product
attention + relevant comments + tests.

The tests added highlight the new capabilities introduced in this PR,
including:

Attention with F16 mask
Attention with Boolean mask
Causal attention with same Q K V shapes
Causal attention without Q K V shapes

Made sure that one cannot input both mask and is_causal.
2024-09-09 15:51:41 -07:00
Srinath Avadhanula 0a788e0467
Decompose aten.fmod into aten.mul,sub,div etc. (#3689)
As titled, create a new decomposition for `aten.fmod.Tensor` to
`aten.div`, `aten.trunc`, `aten.mul` and `aten.sub`. Note that we only
use `aten.trunc` for floating point operations. This further gets
decomposed to `aten.where` etc. by other existing decompositions.

This decomposition now makes TOSA pass for a simple model with
`aten.fmod` while it makes `stablehlo` fail. For now, we disallow this
decomposition for `stablehlo`

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
2024-09-09 09:00:11 -07:00
Felix Schneider df6098e43d
[TorchToLinalg] Use `linalg.transpose` instead of `generic` when lowering `aten.T` (#3660)
The lowering pattern for `aten.T` uses transposition implemented via
`linalg.generic`. For downstream passes it is advantageous to use named
ops wherever possible, so this patch changes the lowering to use
`linalg.transpose` instead.
2024-09-07 08:09:10 +02:00
Branko Trifkovic 70d5730c87
[LINALG] Implement lowering of torch.aten.rot90 (#3551) 2024-09-06 10:36:17 +05:30
justin-ngo-arm d4b5e05ac1
[TOSA] Add Torch to Tosa Legalization for torch.tril (#3678)
Change-Id: Ie5ba31a27394c3adcea00266a9d562862dbd8b08

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-05 11:27:29 -07:00
Christopher McGirr b790061b69
[FxImporter] Add InputInfo to Resolve Literal Hook (#3688) 2024-09-06 00:53:11 +08:00
zjgarvey 295bf418a4
Add a canonicalization pattern for `aten.unflatten.int` (#3656)
Addresses an issue in <https://github.com/llvm/torch-mlir/issues/3651>
where some unflatten ops generated from onnx models weren't propagating
static shape information. It may be necessary to add further
optimizations for the more general case when some static information is
present in the unflatten (or possibly reshape/view) op's `sizes` list,
but not reflected in the output shape. These ops will only successfully
infer shapes if the `sizes` list is gotten from a list of constant ints
(with possibly one -1). A common example where this fails is when some
of the `sizes` are determined from `aten.size.int` ops on dynamic
tensors, and other `sizes` are known statically.

This PR includes:
- a canonicalizer for `aten.unflatten.int` which converts to
`aten.unsqueeze` when it is expanding one dim to two, and one of the new
dims is statically 1.
- an improvement to the folder for `aten.__or__.bool` which does not
rely on *both* operands being static.
2024-09-03 16:38:20 -07:00