Commit Graph

1720 Commits (5f74de5ba0cd9fcb8d5af75a38de5899d3875de6)

Author SHA1 Message Date
Yuanqiang Liu 5f74de5ba0
[Stablehlo] support aten.all.dim (#3746) 2024-09-30 15:59:27 +08:00
yyp0 eb4e59e189
[Torch] support binary_cross_entropy_with_logits decomposition (#3741) 2024-09-29 17:41:20 +08:00
jinchen a33d1232c5
[onnx] Fix onnx.Shape lowering with scalar input (#3716)
Address https://github.com/nod-ai/SHARK-Turbine/issues/826
2024-09-27 13:30:02 -07:00
Xida Ren (Cedar) 9938abf25e
AtenCumprodOp (#3737) 2024-09-26 18:17:22 -04:00
yyp0 335cf5f6d0
[stablehlo] support aten_adaptive_max_pool1d lowering (#3728) 2024-09-26 11:42:38 +08:00
Xida Ren (Cedar) aa7e77ee64
Better errmsg upon getScalarTypeForType failure (#3734)
Instead of 
`Unhandled type in getScalarTypeForType`

You now get

Unhandled type in getScalarTypeForType: (type name)
Type properties:
  Is integer: yes
  Bit width: 
...


The root cause is https://github.com/llvm/torch-mlir/issues/3720, at
least for unsigned integer issues.
2024-09-25 16:32:26 +00:00
Vinayak Dev 67732883fa
[torch] Fix unsqueezed output shape in canonicalization of AtenUnflattenIntOp (#3730)
Fixes https://github.com/iree-org/iree/issues/18562.

During canonicalization pass on `AtenUnflattenIntOp`, if the second dim
was statically equal to one, we would create an `AtenAddIntOp` to add
one to the dimension obtained from `op.getDim()`. This, when passed into
`Torch::unsqueezeTensor()`, would make it get interpreted as
non-constant, which would lead to MLIR failing an assertion when
`UnsqueezeOp` would later get lowered into `ExpandShapeOp`, as the
output of the `UnsqueezeOp` would consist of only dynamic dims.

This patch fixes this behavior, by extracting the integer value from the
dim if it was constant, and then emitting a `ConstantIntOp` from
(dim+1). This creates an output with static shape.
2024-09-24 11:45:18 -05:00
giacs-epic 99848265c3
[onnx] Relax constraints on input tensors in `onnx.STFT` conversion to torch dialect (#3676)
- When the signal tensor is real, onnx allows its shape to be
`[batch][length]` as well as `[batch][length][1]`.
- Onnx also allows to specify `frame_length` together with `window` (not
empty), given that it matches the window size.
- Adding checks on signal and result shapes.
2024-09-23 12:09:29 +05:30
Justin Ngo 3f79a2982a
[TOSA] Extend Torch to TOSA legalization coverage (#3718)
- Add Torch to TOSA legalization for the following ops:
  + aten.logical_not
  + aten.logical_xor
  + aten.cos
  + aten.sin
  + aten.pow.Scalar
  + aten.pow.Tensor_Tensor
  + aten.erf
  + aten.bitwise_and.Scalar
  + aten.bitwise_left_shift.Tensor
  + aten.bitwise_right_shift.Tensor
  + aten.le.Tensor
  + aten.le.Scalar
- Update e2e tests in xfail_sets
- Update basic.mlir with newly legalized ops

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I4aa5790073ef2e5ec0e9b374da42887242f8dabc

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-20 14:33:55 -07:00
Justin Ngo abaff58c6d
[TOSA] Add div rounding mode, remainder, fmod, and ge.Tensor ops support (#3717)
- Add legalization for aten.div rounding mode:
  + trunc: rounds division results towards zero
  + floor: rounds division results down
- Add legalization for aten.remainder.Scalar and aten.fmod ops
- Add legalization for aten.ge.Tensor op
- Update e2e tests in xfail_sets.py
- Update basic.mlir with new legalized ops

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: Icedd23205254fb893ce6f3de08956772b83b4320

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-20 13:34:09 -07:00
Rob Suderman 5ce48dfacd
[torch] Fix attention on linalg for dynamic shapes (#3714)
Current version does not work for a mixture of dynamic and static shaped
batch dimensions. Rework to grab the correct dynamic shapes.

---------

Co-authored-by: dan <danimal197@gmail.com>
2024-09-18 14:52:54 -05:00
zjgarvey d2c387dd04
[ONNX] Fix issue with absent value in onnx.ConstantOfShape (#3713)
Previously, if the value was absent, this conversion was creating a
dense resource of value 0 with shape equal to the result shape, then
later re-extracting a splat value. This only works if the shape is
statically known, and even when the shape is known, this is completely
unnecessary since the value's shape should be `[1]` and not the result
shape.

This patch simply sets the `splatvalue` to a `torch.constant.float 0.0`
when the onnx op's `value` attr is absent, and adds `nullptr` checks to
the subsequent conditionals to avoid them in the case where an `attr` is
not given.

Addresses <https://github.com/nod-ai/SHARK-Turbine/issues/831>.
2024-09-17 16:01:01 -05:00
justin-ngo-arm 14ef05a292
[TOSA] Extend Torch to TOSA reduction ops legalization (#3710)
- Add Torch to TOSA legalization for the following reduction ops:
  + aten.min.dim
  + aten.min
  + aten.max
  + aten.prod
  + aten.prod.dim_int
  + aten.all.dim
- Add dtype casting support for reduce sum and prod ops
- Extend aten.max.dim legalization to a template to support aten.min.dim
legalization
- Update end-to-end tests sets in xfail_sets.py

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I854dd6c0c55e570c1fb7242f20c85cf64d6e7fe0

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-16 12:40:24 -07:00
Srinath Avadhanula bc70c50373
Delete unnecessary linalg conversion for aten.fmod (#3707)
Follow up cleanup for [this
PR](https://github.com/llvm/torch-mlir/pull/3689), which introduced a
decomposition for `aten.fmod.Tensor`. This means that the lowering for
this operator in linalg is no longer needed.

Thanks to @vivekkhandelwal1 for pointing this out.

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
2024-09-13 09:39:58 -07:00
Yuanqiang Liu 7b94ced39a
[Stablehlo] fix aten compare ops' promote rules (#3709)
previous PR(https://github.com/llvm/torch-mlir/pull/3702)
2024-09-13 18:48:41 +08:00
zjgarvey d61986cfcf
Add Decompostion for `Aten_SafeSoftmaxOp` (#3708)
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-09-12 16:58:10 -05:00
yyp0 edf725ef42
[Torch] add AtenAsStridedOp in torch dialect (#3706) 2024-09-12 19:07:11 +08:00
Yuanqiang Liu 3f07077ff9
[Torch] enhance fold of aten.alias (#3705) 2024-09-12 17:04:57 +08:00
Branko Trifkovic 1c4b9d6a0e
Implement lowering of torch.aten.hstack (#3563) 2024-09-11 16:41:47 +05:30
Rob Suderman 6934ab81b0
Bump llvm/llvm-project@b6603e1bf1 (#3697)
Bump forward and refactor inline global slots to no longer track via
symlinks. This appears to make the tests past until we manage to remove
torchscript work.
2024-09-10 08:57:15 -07:00
giacs-epic b35675a78e
[onnx] Add support for `auto_pad` in `onnx.Conv` (#3670)
Add logic for `auto_pad` attribute in the conversion of `onnx.Conv`
torch dialect.
Add lit tests covering different configurations of `auto_pad`.
2024-09-10 20:31:53 +05:30
rohan-tan-bhowmik e86f56bc76
[Torch] [TMTensor] Added mask and is_causal support for torch.aten.scaled_dot_product_attention (#3690)
Enabled mask and is_causal parameters for torch.aten.scaled_dot_product
attention + relevant comments + tests.

The tests added highlight the new capabilities introduced in this PR,
including:

Attention with F16 mask
Attention with Boolean mask
Causal attention with same Q K V shapes
Causal attention without Q K V shapes

Made sure that one cannot input both mask and is_causal.
2024-09-09 15:51:41 -07:00
Srinath Avadhanula 0a788e0467
Decompose aten.fmod into aten.mul,sub,div etc. (#3689)
As titled, create a new decomposition for `aten.fmod.Tensor` to
`aten.div`, `aten.trunc`, `aten.mul` and `aten.sub`. Note that we only
use `aten.trunc` for floating point operations. This further gets
decomposed to `aten.where` etc. by other existing decompositions.

This decomposition now makes TOSA pass for a simple model with
`aten.fmod` while it makes `stablehlo` fail. For now, we disallow this
decomposition for `stablehlo`

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
2024-09-09 09:00:11 -07:00
Felix Schneider df6098e43d
[TorchToLinalg] Use `linalg.transpose` instead of `generic` when lowering `aten.T` (#3660)
The lowering pattern for `aten.T` uses transposition implemented via
`linalg.generic`. For downstream passes it is advantageous to use named
ops wherever possible, so this patch changes the lowering to use
`linalg.transpose` instead.
2024-09-07 08:09:10 +02:00
Branko Trifkovic 70d5730c87
[LINALG] Implement lowering of torch.aten.rot90 (#3551) 2024-09-06 10:36:17 +05:30
justin-ngo-arm d4b5e05ac1
[TOSA] Add Torch to Tosa Legalization for torch.tril (#3678)
Change-Id: Ie5ba31a27394c3adcea00266a9d562862dbd8b08

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-05 11:27:29 -07:00
zjgarvey 295bf418a4
Add a canonicalization pattern for `aten.unflatten.int` (#3656)
Addresses an issue in <https://github.com/llvm/torch-mlir/issues/3651>
where some unflatten ops generated from onnx models weren't propagating
static shape information. It may be necessary to add further
optimizations for the more general case when some static information is
present in the unflatten (or possibly reshape/view) op's `sizes` list,
but not reflected in the output shape. These ops will only successfully
infer shapes if the `sizes` list is gotten from a list of constant ints
(with possibly one -1). A common example where this fails is when some
of the `sizes` are determined from `aten.size.int` ops on dynamic
tensors, and other `sizes` are known statically.

This PR includes:
- a canonicalizer for `aten.unflatten.int` which converts to
`aten.unsqueeze` when it is expanding one dim to two, and one of the new
dims is statically 1.
- an improvement to the folder for `aten.__or__.bool` which does not
rely on *both* operands being static.
2024-09-03 16:38:20 -07:00
Ze Zhang b3942ff984
Add canonicalize pattern for aten.mul.int and aten.floordiv.int (#3680)
This PR add `floordiv` to the `PY_BUILTIN_TO_TORCH_OP`. For
`aten.mul.int` and `aten.floordiv.int` ops, we add new Canonicalization
Patterns as follow:

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.mul.int %1, %const-6
```

Will be replaced by

`torch.aten.mul.int %input, %const-30`


And 

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.floordiv.int %1, %const-5
```
Will directly return `%input`


This PR also relaxes the `float` type constraint in TorchToTosa for the
`AtenRsubScalarOp` conversion.



To test:

`cmake --build build --target check-torch-mlir-all`
2024-09-03 09:13:59 -07:00
Vivek Khandelwal 567ed44fd0
[MLIR][TORCH] Add E2E support for aten.polar op (#3671)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-09-03 10:51:03 +05:30
jinchen fd759e4b1f
Fix onnx.Gather lowering with dynamic shapes (#3675)
Supports the result with dynamic shape and scalar indices like
```
func.func @test_gather_scalar(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[], si64>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
  %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[], si64>) -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}
```

`Torch::AtenSqueezeOp` is referring to the result shape, so it will
failed on lowering if the result shape is dynamic.
2024-08-29 17:02:16 -07:00
lingzhiz1998 5bc59ce1fa
[TorchToLinalg] Support lowering MaxPool3dWithIndices (#3652)
Support torch.MaxPool3dWithIndices lowering to linalg backend.
2024-08-27 14:14:25 -05:00
Xida Ren (Cedar) eb7bf78a9c
Add RestructureNonConstantAxes pass to address reduce op tests failing on non constant axes (#3600) 2024-08-26 14:06:06 -07:00
Felix Schneider 638ef14512
[TorchToLinalg] Use `linalg.broadcast` instead of `generic` for conv bias (#3661)
The current implementation uses a `linalg.generic` to broadcast the bias
tensor for the lowering of convolutions. This is suboptimal for later
pattern matching. This patch changes it to use the respective named op,
`linalg.broadcast`, instead.
2024-08-26 20:29:11 +02:00
Rob Suderman f9766c89f6
[onnx] Handle `torch.aten` for inner product case (#3634)
The following case was failing to lower for einsum. This fixes up the
inner product issue.
2024-08-24 11:41:25 -07:00
Rob Suderman 6cf139687d
[onnx] Support for optional `axis` attribute for `onnx.Pad` (#3635)
The `axis` attribute is optionally available. Added support by computing
the pad based on the axis values.

---------

Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
2024-08-24 11:41:08 -07:00
Rob Suderman b3b8e2e96a
[torch] Fix lowerings of rshift and lshift (#3665)
I missed adding second operand conversion and adding them to the set of
rewrite patterns.
2024-08-24 03:27:18 +00:00
Rob Suderman 9a4c8c606c
[torch] Add `torch.aten.view.dtype` to op list (#3664)
Support dtype conversion between types. This is useful for bitcasting
buffers between differing bit depths.
2024-08-23 19:02:53 -07:00
Phaneesh Barwaria 9a6fe58a02
onnx.MelWeightMatrix Onnx to Torch to Linalg (#3659)
- This PR adds new (and equivalent) more tensorized impl of
MelWeightMatrix which lowers all the way to linalg.
- [Ref Pytorch
Impl](https://gist.github.com/PhaneeshB/4e6dfcded3007b1b686fbe28f07a67cd)
- Thanks to @rsuderman for pointing out the difficulties [earlier
impl](#3503) posed during lowering to linalg and also for providing a
better numpy impl 🙏
2024-08-22 08:55:03 -07:00
Vivek Khandelwal fcc5f444cd
MLIR][TORCH] Fix GroupNorm decomposition by adding shape info (#3658)
This commit adds the shape info for the tensors created during the
decomposition of GroupNorm op.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-08-22 21:20:40 +05:30
lingzhiz1998 7f886cc270
[TorchToLinalg] Support torch.isclose lower to linalg (#3631) 2024-08-21 11:55:54 +08:00
Ian Wood a24114efa3
[TorchToLinalg] remove `extract_slice` grid_sample lowering (#3483)
Instead of using extract_slice for grid sampler, use affine constants to access the X and Y values in the generic op's region.
2024-08-20 14:23:43 -07:00
zjgarvey f66908f190
[TorchToLinalg] address a dtype mismatch in `aten.multinomial` lowering (#3630)
Resolves <https://github.com/llvm/torch-mlir/issues/3628>
Unblocks a compile failure for one of the MiGraphx models
(`AgentModel`).
2024-08-20 15:14:48 -05:00
Vivek Khandelwal 0a86deb59a
build: manually update PyTorch version (#3627)
Set PyTorch and TorchVision version to nightly release 2024-08-18.
This commit also updates the `scaled_dot_product_attention` op. 
A new attribute `enable_gqa` has been added. As of now, only the
default value for the same is supported.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-08-19 12:03:56 +05:30
Rob Suderman 78deb175b3
[onnx] Fix shortcircuit path (#3633)
The implementation was short circuiting the second result. Updated to
guarantee we do not short circuit.
2024-08-16 09:23:47 -07:00
Rob Suderman 3a599bec80
[onnx] Fix onnx.ThresholdedRelu crash (#3638)
Result type was not fetched causing a crash on construction
2024-08-16 09:23:38 -07:00
Rob Suderman f09cb766dc
[onnx] Fix `torch` lowering for determinant (#3639)
The determinant lowering had some extract / insert shape mismatches.
Replumbed shape manipulations to correctly implement the determinant
operation.
2024-08-15 15:41:50 -07:00
yyp0 43e3118eb9
[Stablehlo] use stablehlo specs lowering AtenSliceScatterOp (#3592) 2024-08-15 20:06:29 +08:00
Yevhenii Havrylko 64b0d4aed3
Add missing dependency to TorchMLIRRefBackend target (#3107)
Discovered in https://github.com/llvm/torch-mlir/issues/3104
Most likely when building with stablehlo, while waiting for it missing
dependency was generated to location shared with another dependency.
2024-08-14 23:41:51 +08:00
pkapris-syrmia 23ec5399e5
Implement lowering of aten.atleast_2d (#3546)
This operator is needed to implement aten.vstack, which will be
submitted in a subsequent PR
2024-08-14 18:52:31 +05:30
Branko Trifkovic da877a781e
Added support for integer to complex conversion (#3604) 2024-08-14 18:13:00 +05:30