Commit Graph

1792 Commits (main)

Author SHA1 Message Date
Hanumanth 30c519369e
Support default padding case for tosa::AvgPool in the presence of count_include_pad (#3868)
Essentially, as part of my earlier
[change](7f9f99c6f8)
, I didn't consider the `padding` value while erroring out for
unsupported `count_include_pad` during `torch-to-tosa` lowering for
AvgPool2d. The fix captured in this change addresses this. Please see
[issue](https://github.com/llvm/torch-mlir/issues/3862) for more details
on this.

Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>
2024-11-12 13:48:20 -08:00
zjgarvey cd38ecf6c2
Add Scalarization Patterns for `AtenToDtypeOp`, `AtenNegOp`, `AtenRemainderTensorOp` (#3861)
1. adds a lowering for `aten.neg.int` and `aten.remainder.int` to arith.
2. adds a scalarization pattern for `aten.neg` and
`aten.remainder.Tensor` ops.
3. improves folding of `aten.mul.int`
4. adds a scalarization pattern for `aten.to.dtype` which relies on
scalar cast ops and basic C++ casting between `double` and `int64_t`.
5. improves rank-0 case handling for `FoldAtenSplatPattern`
6. removes a bug with `aten.unflatten.int` decomposition incorrectly
generating a constant size int from a dynamic shape.
7. simplifies the dim list for `aten.unflatten.int` ops generated from
the `aten.view` canonicalization in scalarize shapes.

All of these changes were necessary to unblock
<https://github.com/iree-org/iree/issues/18899>.
2024-11-12 14:25:02 -06:00
aldesilv 889a836b3d
OnnxToTorch bicubic interpolation (#3802)
(https://github.com/nod-ai/SHARK-TestSuite/pull/391)
Repro (using SHARK TestSuite):
1. `python run.py --torchtolinalg -m cl-onnx-iree -t cubic_test`

---------

Co-authored-by: zjgarvey <zjgarvey@gmail.com>
2024-11-12 12:54:29 -06:00
Vivek Khandelwal 17c1985c4d
build: manually update PyTorch version (#3863)
This commit sets the PyTorch and TorchVision version to nightly release
2024-11-07.

This commit also updates the dtype check for the
`aten.fake_quantize_per_tensor_affine` and
`aten.fake_quantize_per_tensor_affine_cachemask` op since the op now
supports bfloat16 input.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-11-11 21:26:56 +05:30
Justin Ngo 8eb34dae78
[TOSA] Add promote type to unary ops and aten.cat lowering (#3860)
Change-Id: I2699bf9007723fe629edb1c524c10ef8142e0234

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-11-08 11:23:39 -08:00
Justin Ngo b6f04fa32b
[TOSA] Fix rsub; add clamp.Tensor, avg_pool1d, max_pool1d, prims.collapse (#3855)
- Fix aten.rsub.Scalar legalization with appropriate type casting
- Add legalization for aten.clamp.Tensor
- Resolve some unexpected test failures from PyTorch update by adding
legalization for the following ops:
  + aten.avg_pool1d
  + aten.max_pool1d
  + torch.prims.collapse
- Update xfail_sets with new e2e results
- Add new LIT tests to basic.mlir


Change-Id: I9762c7d36ca0b0f75ca68d0c71d7f5d5309a96ad

---------

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-11-07 14:09:43 -08:00
zjgarvey 8519ecc4d7
Generalize `aten.view` pattern in scalarize shapes (#3856)
Extends the existing pattern to allow finding matching dims from the
back as well as the front.
2024-11-07 15:26:07 -06:00
yyp0 7058f456b8
[Stablehlo] support aten.isfinite (#3850) 2024-11-07 16:52:39 +08:00
yyp0 dda65b196d
[Torch] support float_power and threshold ops (#3854) 2024-11-07 16:27:51 +08:00
yyp0 2f33f31724
[Torch] support AtenNllLossForwardOp decomposition (#3833) 2024-11-06 11:34:48 +08:00
Yuanqiang Liu 70e089802a
[Torch] emit and lowering frac, signbit, ldexp, copysign ops (#3851)
also fix `aten.exp2` with integer type
2024-11-06 10:21:37 +08:00
Ian Wood e88faf08ff
Create scatter op with unique indicies (#3853)
For the op `index_put_`, if accumulate == false, the behavior is
undefined if the indicies aren't unique
(https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html).
So, when converting `AtenIndexPutHackedTwinOp` to a TMTensor scatter op,
mark the indices as unique if when `accumulate == false`.

This should have no functional effect (unless users are relying on UB)
and assuming unique indices has the benefit of unlocking better
optimizations in further compiler stages.

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
2024-11-05 12:48:34 -08:00
Jiawei Wu b75d0e3f8b
[stablehlo] fix: enhance torch's index-like op lowering to stablehlo's gather/scatter (#3829)
In torch.index_put like ops, `values` is only required to be
broadcastable to `input[indices]`, rather than exact dimension match.
This patch fixes the problem by add additional
stablehlo.dynamic_broadcast_in_dim before creating stablehlo.scatter op.
BTW, this patch also enhance the `getBroadcastResultShape` utility in
hlo namespace.
2024-11-05 19:15:11 +08:00
Justin Ngo 4c1518d365
[TOSA] Add legalization for aten.as_strided (#3848)
- Add Torch to TOSA legalization for aten.as_strided op
- Update xfail_sets with the following:
  + New aten.as_strided results
+ Changes from this commit:
7f9f99c6f8
  + Failed tests from new PyTorch version update
- Add new LIT test to basic.mlir


Change-Id: I6f471ea116ca47f2bf9537b62950fce75a2c624f

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-11-04 09:57:59 -08:00
jinchen 6aa46967b6
Add tosa::getConstTensor with int8_t template (#3845)
Add tosa::getConstTensor with int8_t template used in
https://github.com/llvm/torch-mlir/pull/3827
2024-11-01 21:22:27 +00:00
zjgarvey 3104b66560
Fix Slice Folder OOB Crash and onnx.Shape lowering (#3843)
1. Clamps OOB start index to 0 in slice folder
2. Adds a more descriptive `emitError` in slice folder if the creation
of the `DenseElementsAttr` would fail due to a bad result shape.
3. Fixes the `onnx.Shape` lowering to default to `inputRank` for `end`
instead of `-1`. When `end==-1` the last element was missing when
slicing.
2024-11-01 15:33:21 -05:00
zjgarvey 738d45d3bb
add scalarization patterns to support dynamic pytorch pad exports (#3838)
1. Adds case handling for `aten.slice.tensor` shape inference with
negative strides. This is not technically allowed by native pytorch, but
it is useful for ONNX ingest. We were getting some incorrect shapes for
these negative strided slice ops.
2. Adds scalarization support for ops seen in pytorch pad exports to
ONNX. These are typically `aten.view` `aten.transpose.int` and
`aten.slice.Tensor` with negative strides (and rank 2).
3. Allows view op `self` to be added to the worklist conditionally,
based on whether the view op actually occurs as a middle point in a
shape computation.
2024-11-01 14:56:48 -05:00
jinchen 39d69db5ca
Cast static/dynamic shape for onnx.If branches to match result type (#3828) 2024-11-01 12:10:59 -07:00
zjgarvey 3cfb7c8df6
Add an info cast to `prims.squeeze` decomposition (#3844)
The onnx ingest sometimes has poorly propagated shape information. E.g.:

```mlir
...
    %9020 = torch.prims.squeeze %9010#1, %9019 : !torch.vtensor<[?,384,1],f32>, !torch.list<int> -> !torch.vtensor<[1,384],f32>
    return %9015, %9020 : !torch.vtensor<[1,384],f32>, !torch.vtensor<[1,384],f32>
  }
}
```

This occurred at the boundary of the onnx model
`migraphx_bert__bert-large-uncased`. Evidently, the output value tensor
info had more information than could be propagated forward. The
`PrimsSqueeze` lowering was returning a `!torch.vtensor<[?,384],f32>`
which was causing a type mismatch with the `func.return`.
2024-11-01 12:10:47 -05:00
zjgarvey a82ba1c422
[TorchToArith] add lowerings for some scalar bool binary ops (#3823)
Added lit tests since these scalar operations don't trace well through
the `fx_importer` route.

`XOR` and `NE` are equivalent binary operators, so `aten.ne.bool` is
lowered to `arith.xori`.
2024-11-01 10:40:20 -05:00
Xinyu Yang 3dbeda9082
[Stablehlo] fix template typo (#3842)
I think we should use template parameters. @yyp0 @qingyunqu
2024-11-01 21:10:38 +08:00
Hanumanth 7f9f99c6f8
Fix torchToTosa lowering for avgpool2d to handle unsupported parameters (#3822)
The existing TorchToTosa lowering logic for `torch.aten.avg_pool2d`
doesn't handle some unsupported properties well, leading to a silent
wrong answer(SWA) when we go through
`torch-backend-to-tosa-backend-pipeline.` For instance, with the
existing TOSA avgpool2d specification, we can not represent
`count_include_pad` and `divisor_override,` so during TorchToTosa
lowering, we silently ignore these properties which leads to SWA in some
cases—the fix captured in this change errors for unsupported scenarios.

For details on `count_include_pad` and `divisor_override,` please see
the below link.

https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html

---------

Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>
2024-11-01 08:25:59 -04:00
jinchen 032a636c35
Fix onnx.If lowering with scalar condition tensor (#3846)
Fixes
https://github.com/nod-ai/SHARK-ModelDev/issues/696#issuecomment-2442016530
2024-10-31 20:34:50 -07:00
Rob Suderman 25738b8c19
[linalg] Broadcast batch for mask on sdpa lowering (#3824)
Attention often broadcasts a mask across the batch dimension as masking
is usually performed the same across attention heads. Added this
materialization to the mask dimensions optionally.
2024-10-31 17:59:24 -07:00
Rob Suderman 5aa323dd29
[linalg] Fix torch.aten.add of `torch.bool` (#3820)
Addition of bools saturate which equates to an `or` operator. Updated to
avoid some noticed downstream failures.
2024-10-31 17:37:25 -07:00
Stephen Baione 9c1e3b8154
support `aten._trilinear` and improve `einsum` decomposition (#3784)
# Tracking
[Issue](https://github.com/nod-ai/SHARK-ModelDev/issues/848)
[TorchToLinalg Op
Support](https://github.com/nod-ai/SHARK-ModelDev/issues/347)

# Description

Aten_TrilinearOp is an implementation of a "trilinear einstein sum".
Essentially, just an einsum across 3 tensors.

There are a few inputs:
## Tensor Inputs
- i1, i2, i3 - The three input tensors for the _trilinear op.
## Expands 
These inputs allow you to unsqueeze an input tensor at the specified
dims as a pre-processing step to make the shapes compatible for the rest
of the op:
- expand1: List[int], expand2: List[int], expand3: List[int]

## sumdim
- sumdim: List[int] - After applying element wise multiplication, the
values in sumdim denote where to collapse a dimension by summing over it

## unroll_dim
- unroll_dim: int - In the PyTorch implementation, this specifies a
dimension where you could slice the input tensors, multiply and sum
them, then concatenate the results in an output tensor. This complicates
the implementation significantly, but doesn't change the result, so I
opted against it. Along with that, a previously accepted path for
solving this involved reusing the AtenEinsumOp, which also would also
ignore this input.


# Solution

After trying a bunch of more complicated approaches for it, this op
actually ended up being quite simple: [See
_trilinear](https://dev-discuss.pytorch.org/t/defining-the-core-aten-opset/1464)

`_trilinear = (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) *
i3.unsqueeze(expand3)).sum(sumdim)`

Wish I saw this earlier, but watcha gonna do: 🙃

## Not Reusing AtenEinsumOp
Frankly, I found multiple cases where valid inputs would have numerical
mismatches for EinsumOp, even when running tests against EinsumOp
directly. I think it has something to do with the singleton dimensions.
Will need to look into this further, but once I realized the simplified
approach, it appeared to be more reliable and much simpler.

Either way (credit to @zjgarvey), there are improvements to the einsum
op here. When I was originally trying to use the op, intermediate
tensors were being flattened properly, but then its 0th dimension was
being cast from a static dim to a dynamic dim due to integers not
folding correctly in the MLIR. Figured it's worth keeping these
improvements for future reusers of EinsumOp.

# The zero'd out dim "bug"

For some reason, if you specify a dimension in all `expands`,

```i.e. 
[expand1=[0], expand2=[0], expand3=[0]],
[expand1=[1], expand2=[1], expand3=[1]]
```

The _trilinear op would specify `0` for that dimension in the output
shape, unless it was also included in `sumdim`. This goes against the
implementation of torch.einsum:

```
>>> a, b, c = [torch.rand(1, 3, 3, 3) for i in range(3)] # Simulate expand at dim=0 for all input tensors
>>> torch.einsum('abcd,abcd,abcd->abcd', a, b, c).shape
torch.Size([1, 3, 3, 3])
```

And is just straight up incorrect mathematically. I considered
"replacing" singleton dims with zeroed out dims, but that seemed like
carrying over a bug. Instead, I included a test for the case, verified
that the singleton dimensions were handled the way that torch.einsum
handles it, instead of torch._trilinear, and xfailed it with a note as
to why.
2024-10-31 14:30:40 -05:00
yyp0 9ce2a69703
[Torch] support AtenExp2Op (#3832)
- support AtenExp2Op by decomposing it to aten.pow.scalar
- refine stablehlo pow.scalar pow.Tensor_Scalar pow.Tensor_Tensor
lowering according to https://github.com/llvm/torch-mlir/pull/2983
- Close https://github.com/llvm/torch-mlir/pull/2983
2024-10-31 19:14:05 +08:00
Justin Ngo 4dd213b042
[TOSA] Expand Torch to TOSA legalization coverage (#3827)
- Add/Extend Torch to TOSA legalization for the following ops:
  + Add aten.threshold_backward
  + Fix aten.threshold
  + Re-implement aten.broadcast_to using tosa.reshape and tosa.tile
  + Add support for rank 0 index for aten.index_select
  + Fix aten.index_put.hacked_twin
  + Add aten.uniform
  + Add aten.logical_and
- Update xfail_sets.py with new e2e results
- Add LIT tests to basic.mlir for newly added ops


Change-Id: I8910564a049d18293284fe2e55e82bc1d2cf10e3

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-10-30 16:26:10 -07:00
Max191 8b0bf2e293
Bump LLVM to llvm/llvm-project@6c64c8a6f3 (#3818)
- bumps llvm-project to
6c64c8a6f3
- bumps stablehlo to
6e403b1aa6
- Updates type conversion materialization functions to return Value
after API change in llvm-project.

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
2024-10-30 11:38:51 -04:00
Max191 6b58c89914
Remove variable used for only assertion (#3837)
Removes a boolean variable that is used only for an assertion, and
inlines the condition into the assertion.

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
2024-10-30 10:51:06 -04:00
Yuanqiang Liu 9ab2a150f2
[Torch] emit upsample_bilinear2d(.vec) ops (#3834) 2024-10-30 20:18:24 +08:00
Sayan Saha 2b01f8b7f3
[Tosa] : Add support for negative indices in index.tensor and index.Tensor_hacked_twin for TorchToTosa lowering. (#3790)
1. Negative indices for tensor indexing is handled by wrapping around
the index values by checking their values at run time. Without the fix,
there was a runtime error.
2. Added a lit test to lock down the behavior.
3. Updated the `xfails_set` for `fx_importer_tosa` config to lockdown
the behavior with e2e test as well.

"THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY."
2024-10-25 15:37:19 -07:00
Andrija Bosnjakovic 54d9e24013
[TorchToLinalg] Implement lowering of torch.aten.rrelu_with_noise and torch.aten.rrelu_with_noise_backward ops (fix) (#3748) 2024-10-25 21:31:05 +05:30
zjgarvey 1259e8a00a
Add Some Folders For Small Reshape Ops (#3813)
### Changes

1. Folders for view-like ops: `aten.view`, `aten.flatten.using_ints`,
and `aten.unflatten.int`
2. Folder for transpose
3. Extended support for the `aten.slice.Tensor` op folder to include
negative strides.


### Motivation

The biggest motivation for this patch is to fold the extremely
convoluted ir that gets generated when exporting a pytorch model with an
`aten.pad` op to ONNX, then re-importing and lowering back to torch. For
example, the verbose output of the e2e test `PadModule_basic` with `-c
onnx`:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64> 
    %2 = torch.operator "onnx.ConstantOfShape"(%0) {torch.onnx.value = dense_resource<__2> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64> 
    %3 = torch.operator "onnx.Concat"(%1, %2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[8],si64> 
    %4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__3> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    %5 = torch.operator "onnx.Reshape"(%3, %4) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[8],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],si64> 
    %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__4> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__5> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__6> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %9 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__7> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %10 = torch.operator "onnx.Slice"(%5, %7, %8, %6, %9) : (!torch.vtensor<[4,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,2],si64> 
    %11 = torch.operator "onnx.Transpose"(%10) {torch.onnx.perm = [1 : si64, 0 : si64]} : (!torch.vtensor<[4,2],si64>) -> !torch.vtensor<[2,4],si64> 
    %12 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__8> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %13 = torch.operator "onnx.Reshape"(%11, %12) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[2,4],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[8],si64> 
    %14 = torch.operator "onnx.Cast"(%13) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[8],si64>) -> !torch.vtensor<[8],si64> 
    %15 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__9> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %16 = torch.operator "onnx.Pad"(%arg0, %14, %15) {torch.onnx.mode = "constant"} : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    return %16 : !torch.vtensor<[?,?,?,?],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _: "0x080000000400000000000000",
      __1: "0x080000000000000000000000010000000000000002000000000000000300000000000000",
      __2: "0x080000000000000000000000",
      __3: "0x08000000FFFFFFFFFFFFFFFF0200000000000000",
      __4: "0x080000000000000000000000",
      __5: "0x08000000FFFFFFFFFFFFFFFF",
      __6: "0x080000000100000000000080",
      __7: "0x08000000FFFFFFFFFFFFFFFF",
      __8: "0x08000000FFFFFFFFFFFFFFFF",
      __9: "0x080000000000C03F"
    }
  }
#-}
```

Get's converted to the torch IR:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} {
    %float1.500000e00 = torch.constant.float 1.500000e+00
    %int-9223372036854775807 = torch.constant.int -9223372036854775807
    %int-1 = torch.constant.int -1
    %int7 = torch.constant.int 7
    %int6 = torch.constant.int 6
    %int5 = torch.constant.int 5
    %int3 = torch.constant.int 3
    %int8 = torch.constant.int 8
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %int4 = torch.constant.int 4
    %int0 = torch.constant.int 0
    %0 = torch.vtensor.literal(dense<[0, 1, 2, 3, 0, 0, 0, 0]> : tensor<8xsi64>) : !torch.vtensor<[8],si64>
    %1 = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
    %3 = torch.aten.slice.Tensor %2, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
    %4 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
    %5 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
    %6 = torch.aten.view %4, %5 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
    %7 = torch.aten.slice.Tensor %6, %int0, %int0, %int1, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
    %9 = torch.aten.slice.Tensor %6, %int0, %int1, %int2, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
    %11 = torch.aten.slice.Tensor %6, %int0, %int2, %int3, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
    %13 = torch.aten.slice.Tensor %6, %int0, %int3, %int4, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
    %15 = torch.aten.slice.Tensor %6, %int0, %int4, %int5, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %16 = torch.aten.item %15 : !torch.vtensor<[1],si64> -> !torch.int
    %17 = torch.aten.slice.Tensor %6, %int0, %int5, %int6, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %18 = torch.aten.item %17 : !torch.vtensor<[1],si64> -> !torch.int
    %19 = torch.aten.slice.Tensor %6, %int0, %int6, %int7, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %20 = torch.aten.item %19 : !torch.vtensor<[1],si64> -> !torch.int
    %21 = torch.aten.slice.Tensor %6, %int0, %int7, %int8, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
    %23 = torch.prim.ListConstruct %14, %22, %12, %20, %10, %18, %8, %16 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %24 = torch.aten.constant_pad_nd %arg0, %23, %float1.500000e00 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?,?,?],f32>
    return %24 : !torch.vtensor<[?,?,?,?],f32>
  }
}
```

***All of these operations are useless***. It is literally the result of
needing to reverse (and change the lexicographic order hierarchy of)
padding ints provided via torch vs. ONNX pad ops, which is then
subsequently UNDONE by our ONNX->Torch lowering (represented in the
ordering of the generated list construct).

With the added folders in this patch, the torch IR becomes:

```
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} {
    %float1.500000e00 = torch.constant.float 1.500000e+00
    %int0 = torch.constant.int 0
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %int1 = torch.constant.int 1
    %0 = torch.prim.ListConstruct %int0, %int1, %int2, %int3, %int0, %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.constant_pad_nd %arg0, %0, %float1.500000e00 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?,?,?],f32>
    return %1 : !torch.vtensor<[?,?,?,?],f32>
  }
}
```
2024-10-24 12:09:00 -05:00
Sriram Kumar d6feb2179c
Added support for Maxpool (Autopad) (#3774)
Added autopad. and passed 3 tests 

test_maxpool_2d_precomputed_same_upper
test_maxpool_2d_same_lower'
test_maxpool_2d_same_upper

Address : https://github.com/nod-ai/SHARK-ModelDev/issues/843 

2 attributes yet to complete : storage_order, indices output
2024-10-23 13:04:50 +00:00
lingzhiz1998 2f9a68cc1e
Add canonicalization pattern for maxpool3d with indices op (#3704)
As discussed in https://github.com/llvm/torch-mlir/pull/3652, we should
replace maxpool3dwithindices with maxpool3d if indices have no user.
2024-10-23 18:31:20 +05:30
zjgarvey 55ff110dc2
[MLIR][TORCH] Only unroll prim loop-like ops within a `torch.shape.calculate` region (#3812)
Reports a match failure for the pattern `FullyUnrollPrimLoop` when the
loop op is not in a region defined by a `torch.shape.calculate` op.

This is needed to avoid unrolling prim loops generated by ONNX IR, since
we are applying shape refinement in the
`torch-onnx-to-torch-backend-pipeline` introduced in fa4794d .

See also the discussion in
<https://github.com/iree-org/iree/pull/18867#discussion_r1811101655>
2024-10-23 13:38:55 +05:30
Felix Schneider aca33f1742
[TorchToLinalg] Use Op with native channel order for quantized conv2d (#3807)
I've upstreamed the necessary quantized linalg Op with the
"channel-first" ordering used by torch
(https://github.com/llvm/llvm-project/pull/107740) for 2d convolution.

This patch changes the lowering for the quantized 2d case of
`aten.convolution` accordingly, which saves three transpositions per
convolution (input, weights, result) and therefore removes the
requirement to try to optimize these away in downstream passes.
2024-10-22 20:26:16 +02:00
zjgarvey 140cad5659
Add More Scalarize Shapes Patterns (#3810)
### new patterns:

1. Propagates `aten.broadcast_to` ops of a single value to an
`aten.full` op
2. Propagates arithmetic operations through a templated class which
associates some tensor arithmetic ops to their integer-scalar
counterparts. These are a major blocker right now, since some models
have a bunch of rank 0 arithmetic being done with tensor ops. See the
lit test for an interesting example that pads an input to the smallest
shape which will become divisible by twelve in `dim0`. If you think this
is convoluted, you haven't been staring at ONNX generated IR long
enough.
3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to
`false`. See the comment in that conversion pattern for more
justification as to why it is acceptable to make this assumption here.
This is another major blocker for models, since this lack of folding
propagates to lack of folding for subsequent `where.self` operations.
4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern`

### other changes:
 
1. Add two new anchor ops: `AtenArangeStartStepOp` and
`Torch::RuntimeAssertOp`. I've checked all possible sources of the
runtime assert ops and it is always shape related. The Arange op only
takes int inputs, and these are all shape related. Adds a size check to
getting a list from literal ops.
2. Improved folders for int arithmetic ops to fold some common patterns.
3. adds the ability to get some values from scalar-tensor ops to
getListFromTensor.
4. further cleans up getListFromTensor for readability.

### points to scrutinize:

1. I made the choice to scalarize `div.Tensor` (int dtype result) to
`floordiv.int`. This is because our shape computations involving this
kind of arithmetic are never negative in practice, and we don't have a
"round towards zero" scalar int divide counterpart.
2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if
someone happens to add a runtime assert in the future that doesn't boil
down to shapes, then it would add to the worklist considerably. We might
be able to get around this by adding "NoMemoryEffect" to ops which are
"ReadOnly" so that the inputs for the runtime asserts get cse'd with
existing elements of the worklist before we even get to this pass.
2024-10-21 19:42:39 -05:00
zjgarvey a83e106f92
Rework Scalarize Shapes Pass (#3799)
This is a first step towards reworking the scalarize-shapes pass which
has been integral to our ONNX frontend path detangling shape
computations.

## Purpose:

1. Restrict the scope of the pass to only apply to op sequences which
are used to compute shapes.
2. Make the pass more efficient by applying patterns in an appropriate
order for scalarization propagation.
3. Report failed scalarization patterns for easier debugging (Not yet
implemented). I can't seem to find a good path for this right now to
capture the right diagnostics. I'd like to defer this addition to a
later patch so we can add some high-value patterns to this pass in the
meantime.

With these changes, some reworking of the conversions themselves will be
necessary.

1. The removal of the SqueezeDim fold pattern was an appropriate fix to
avoid folding a pattern that may be needed to propagate further. The
reversal of pattern application order uncovered this bug. The addition
of rank 0 item logic was added to replace the functionality needed from
the squeeze dim pattern.
2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to
allow processing value tensor literals without immediately materializing
the ints. This should factor out a significant portion of code that was
used in specific cases to handle constants.

## RFC 1:

Currently, we are going to add all prim list of int ops to the worklist.
Can anyone identify problems with uniformly anchoring on prim lists of
ints? E.g. Does there exist a Torch Op satisfying all of the following
conditions:

1. Accepts a list of constant ints, LIST, as an input
2. The role of LIST is **not** shape related. All the examples I can
think of are indeed shape related: padding ints passed to a pad op,
kernel size ints passed to a conv op, size ints passed to a view op,
etc.
4. The LIST is not gotten entirely from scalars already. 

If there does not exist a torch op satisfying all three of those
conditions, I think it will be safe to "anchor" on prim lists of ints.

### Conclusion for RFC 1: 

I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for
all references of `AnyTorchListOfTorchIntType` and verified this will
not be problematic to apply in any of those cases.

## RFC 2:

What should I use to report failed scalarization?

Like my dumb idea was just to walk back through the func op after
applying the passes and check if anything in the worklist is still a
tensor. If so, emit/log a warning. It certainly works, since you can
just look at the warnings and start debugging from the last printed
warning upwards, but there has to be a better way to handle this without
walking back through the func.func op.

### Conclusion for RFC 2:

I tried a few things without much success. The fundamental problem is
that identifying the cause of a failed scalarization could be myriad:

1. We could be missing a pattern for an op entirely: E.g., a pattern we
need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp ->
AtenMulIntOp).
2. We could fail a scalarization pattern because it should fold instead.
This is specifically the case for rank0 where.self ops. These ops MUST
fold, or we need to have custom lowering logic for the rank 0 case.
3. Walking through the func op a second time and emiting a warning for
ops that have tensor result types seems to give locations that are
inconsistent or hard to track in the converted IR. Doing this on IR that
doesn't apply any patterns seems to give decent information, but it's
still dramatically insufficient considering how complex these patterns
can get, and still takes manually reading IR to try and figure out what
is really blocking the simplification.

I'd like to skip out on fleshing out the error reporting for now and
come back to it after iterating a few time on the patterns.
2024-10-21 12:47:19 -05:00
Vivek Khandelwal fa4794dae2
[MLIR][TORCH] Add torch-onnx-to-torch-backend pipeline (#3801)
This commit adds the torch-onnx-to-torch-backend pipeline which
converts the Torch Onnx IR to Torch Backend IR.

This commit also moves the `ScalarizeShapes` pass from the
`torch-backend-to-linalg-on-tensors-backend-pipeline` to the
`torch-onnx-to-torch-backend` pipeline since the primary goal of
this pass is to scalarize the shapes in the IR coming from the
Onnx models.
2024-10-21 11:20:44 -05:00
David Tanner 02327af998
Adds onnx ConvTranspose support for autopadding. (#3797)
Adds onnx ConvTranspose support for autopadding
(https://github.com/nod-ai/SHARK-ModelDev/issues/839).

- Adds support for attribute auto_pad="SAME_UPPER" or "SAME_LOWER" which
will automatically calculate padding of input based on output shape.
- Adds support, during auto-padding, for output_shape=[H,W] which
overrides the default output shape of input_shape[i]*stride[i] (for
spatial dimensions only).
- Adds lit test for auto-padding.
- Tests are added by https://github.com/nod-ai/SHARK-TestSuite/pull/370


NOTE: ConvTranspose still doesn't support asymmetric padding, therefore
multiple original onnx tests still won't pass.
2024-10-18 12:31:33 -05:00
Vivek Khandelwal 9c7067649b
build: manually update PyTorch version (#3727)
Set PyTorch and TorchVision version to nightly release 2024-10-15.

Tracker issue for the failing tests added to xfail_set in this PR.
Issue: https://github.com/llvm/torch-mlir/issues/3796
This commit disables the failing sparse tensor tests since they are not 
maintained on day-to-day basis and blocks the roll PyTorch update for now.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-10-18 13:32:14 +05:30
yyp0 dc7a1ff7d9
[Torch] add fold logic for some ops (#3794) 2024-10-16 16:00:58 +08:00
Justin Ngo 45bb17ebfe
[TOSA] Add legalization for empty, scatter, slice_scatter, diag_embed (#3792)
- Add Torch to TOSA legalization for the following ops:
  + aten.empty.memory_format
  + aten.scatter.src
  + aten.slice_scatter
  + aten.diag_embed
- Update xfail_sets.py with new e2e results
- Update basic.mlir with new LIT tests


Change-Id: I817ecf207bcfcf97ca54f30c10c76c4f0f4145ae

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-10-15 08:38:02 -07:00
Hanumanth04 895f490cf5
Remove checking for training specific parameters in EmbeddingBag lowering (#3782)
Torch-to-linalg pass fails for `EmbeddingBag` when the training only
specific properties of the operator are set to `true.` For instance,
this operator's `sparse` input/property is training-specific, and if the
value of this property is `true,` the existing lowering bails out.
However, we don't need to check for training-specific parameters and
bailout from the legalization since we don't care about these properties
during the eval/inference mode.

---------

Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>
2024-10-15 09:37:26 -04:00
zjgarvey 1e431c6a90
Add AtenSliceTOp Canonicalization to SimplifyShapeCalculations pass (#3791)
Some ops were failing to infer the static component of partially dynamic
shapes, and the cause was a missing aten.slice.t pattern.

The lit test included here is an IR dump created before
DropAbstractInterpCalculations for an unflatten op that was failing to
infer shapes before the change.
2024-10-14 14:41:31 -05:00
Marius Brehler edd1bbec46
Integrate LLVM at llvm/llvm-project@c13f806 (#3789) 2024-10-14 15:00:45 +02:00
yyp0 b176939808
[Torch] support 1d aten tensor shape and dtype infer (#3776) 2024-10-12 17:51:15 +08:00
zjgarvey ab62f35373
Add more patterns to scalarize-shapes pass (#3781)
-Adds patterns for propagating shapes through AtenWhereSelf and
AtenEqTensor
-Adds fold pattern for a rank0 squeezeDim of a full op 
-Adds support for getting a list from a splat ValueTensorLiteralOp for
materializing scalar comparisons in where.self and eq.tensor

With a bit of hammering, these changes should unblock several IREE
inference failures.
2024-10-11 11:15:17 -05:00