Commit Graph

409 Commits (main)

Author SHA1 Message Date
Giacomo Serafini 46a5772d92
[TorchToLinalg] Add `aten.fft_rfft` and lowering (#3857)
- Add `AtenFftRfftOp` to Torch dialect.
- Add conversion of `AtenFftRfftOp` to Linalg, using a `linalg.matmul`
per output component (real and imaginary). Computing the DFT is
_O(n^2)_.
- Add decomposition of `AtenFftRfftOp` into Torch-level ops (same
paradigm as above).
- Add unit and end-to-end tests.
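For illustration, a minimal NumPy sketch of the O(n^2) matmul formulation
(it mirrors the idea only; it is not the Linalg lowering, and the helper name
is made up):

```python
import numpy as np

def rfft_via_matmul(x):
    """One-sided DFT of a real signal as two matmuls: one per output component."""
    n = x.shape[-1]
    k = np.arange(n // 2 + 1)                 # output frequency bins
    t = np.arange(n)                          # input sample indices
    angle = 2.0 * np.pi * np.outer(t, k) / n  # (n, n//2 + 1)
    real = x @ np.cos(angle)                  # one matmul for the real part
    imag = x @ (-np.sin(angle))               # one matmul for the imaginary part
    return real + 1j * imag

x = np.random.rand(3, 16)
assert np.allclose(rfft_via_matmul(x), np.fft.rfft(x))
```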
2024-11-27 10:24:36 -06:00
Giacomo Serafini 44985690a7
[Torch Dialect] Emit `torch.aten.mul.float_int`, add folder and conversion to Arith. (#3750)
A folder is required to simplify the shape calculation of
`torch.aten.__interpolate.size_list_scale_list`:

5eab669c4a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp (L6900-L6907)

(I've re-run `build_tools/update_abstract_interp_lib.sh`)
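For reference, the folding rule itself is just constant float-times-int
arithmetic; a rough sketch of the semantics (hypothetical helper, not the
actual C++ folder):

```python
def fold_mul_float_int(lhs, rhs):
    """Fold torch.aten.mul.float_int when both operands are known constants.

    Mirrors the Python semantics float * int -> float; returns None (no fold)
    when either operand is unknown.
    """
    if lhs is None or rhs is None:
        return None
    return float(lhs) * int(rhs)

assert fold_mul_float_int(0.5, 4) == 2.0
assert fold_mul_float_int(None, 4) is None
```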

---------

Co-authored-by: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
2024-11-27 10:23:35 -06:00
zjgarvey 99115dcdc8
[Torch] Address unnecessary dynamic shapes in argmax decomposition (#3889)
Addresses <https://github.com/iree-org/iree/issues/19262#issue>
2024-11-22 16:03:29 -08:00
zjgarvey 0913b967ac
convert to double before float materialization in scalarize shapes (#3887)
Addresses a bug when trying to materialize a non-fp64 attr as a constant
float op in scalarize shapes.
2024-11-22 14:05:24 -06:00
Giacomo Serafini 1b8d7e094b
[Torch Dialect] Add `torch.aten.mul.int_float` (required to simplify shape calculation of `upsample_nearest2d`) (#3764)
As per title. See also
[PR](https://github.com/llvm/torch-mlir/pull/3750) for
`torch.aten.mul.float_int`.

---------

Co-authored-by: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
2024-11-21 00:43:06 +08:00
giacs-epic 06d17897f0
[Torch Dialect] Allow simplification of shape calculations of aten.tile, col2im, aten.stft (#3785)
- Add `aten.mul.left_t` (+ canonicalizer) to allow simplification of
aten.tile.
- Change syntax of the computation of col2im shape to allow the use of
an already existing canonicalization pattern (for `aten.add.t`) for its
simplification.
- Add `aten.eq.bool` (+ folder) to allow simplification of aten.stft.
2024-11-14 15:14:39 -06:00
zjgarvey cd38ecf6c2
Add Scalarization Patterns for `AtenToDtypeOp`, `AtenNegOp`, `AtenRemainderTensorOp` (#3861)
1. adds a lowering for `aten.neg.int` and `aten.remainder.int` to arith.
2. adds a scalarization pattern for `aten.neg` and
`aten.remainder.Tensor` ops.
3. improves folding of `aten.mul.int`
4. adds a scalarization pattern for `aten.to.dtype` which relies on
scalar cast ops and basic C++ casting between `double` and `int64_t`.
5. improves rank-0 case handling for `FoldAtenSplatPattern`
6. fixes a bug where the `aten.unflatten.int` decomposition incorrectly
generated a constant size int from a dynamic shape.
7. simplifies the dim list for `aten.unflatten.int` ops generated from
the `aten.view` canonicalization in scalarize shapes.

All of these changes were necessary to unblock
<https://github.com/iree-org/iree/issues/18899>.
2024-11-12 14:25:02 -06:00
zjgarvey 8519ecc4d7
Generalize `aten.view` pattern in scalarize shapes (#3856)
Extends the existing pattern to allow finding matching dims from the
back as well as the front.
2024-11-07 15:26:07 -06:00
zjgarvey 3104b66560
Fix Slice Folder OOB Crash and onnx.Shape lowering (#3843)
1. Clamps OOB start index to 0 in slice folder
2. Adds a more descriptive `emitError` in slice folder if the creation
of the `DenseElementsAttr` would fail due to a bad result shape.
3. Fixes the `onnx.Shape` lowering to default `end` to `inputRank` instead
of `-1`. With `end == -1`, the last element was missing from the slice.
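A small Python illustration of items 1 and 3, with list slicing standing in
for the folded shape-tensor slice (illustrative only):

```python
shape = [2, 3, 5, 7]          # stand-in for a constant shape tensor of rank 4
rank = len(shape)

# 1. An out-of-bounds start index is clamped before folding the slice.
start = -9                    # normalizes to start + rank = -5, still OOB
start = max(start + rank if start < 0 else start, 0)
assert shape[start:rank] == [2, 3, 5, 7]

# 3. Defaulting `end` to -1 silently drops the last dimension, while
#    defaulting it to the input rank keeps every element.
assert shape[0:-1] == [2, 3, 5]          # old behavior: last dim missing
assert shape[0:rank] == [2, 3, 5, 7]     # fixed behavior
```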
2024-11-01 15:33:21 -05:00
zjgarvey 738d45d3bb
add scalarization patterns to support dynamic pytorch pad exports (#3838)
1. Adds case handling for `aten.slice.Tensor` shape inference with
negative strides. This is not technically allowed by native PyTorch, but it
is useful for ONNX ingest. We were getting some incorrect shapes for
these negative-strided slice ops.
2. Adds scalarization support for ops seen in PyTorch pad exports to
ONNX. These are typically `aten.view`, `aten.transpose.int`, and
`aten.slice.Tensor` with negative strides (and rank 2).
3. Allows view op `self` to be added to the worklist conditionally,
based on whether the view op actually occurs as a middle point in a
shape computation.
2024-11-01 14:56:48 -05:00
zjgarvey 1259e8a00a
Add Some Folders For Small Reshape Ops (#3813)
### Changes

1. Folders for view-like ops: `aten.view`, `aten.flatten.using_ints`,
and `aten.unflatten.int`
2. Folder for transpose
3. Extended support for the `aten.slice.Tensor` op folder to include
negative strides.


### Motivation

The biggest motivation for this patch is to fold the extremely
convoluted IR that gets generated when exporting a PyTorch model with an
`aten.pad` op to ONNX, then re-importing and lowering back to torch. For
example, the verbose output of the e2e test `PadModule_basic` with `-c
onnx`:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64> 
    %2 = torch.operator "onnx.ConstantOfShape"(%0) {torch.onnx.value = dense_resource<__2> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64> 
    %3 = torch.operator "onnx.Concat"(%1, %2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[8],si64> 
    %4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__3> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    %5 = torch.operator "onnx.Reshape"(%3, %4) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[8],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],si64> 
    %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__4> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__5> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__6> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %9 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__7> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %10 = torch.operator "onnx.Slice"(%5, %7, %8, %6, %9) : (!torch.vtensor<[4,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,2],si64> 
    %11 = torch.operator "onnx.Transpose"(%10) {torch.onnx.perm = [1 : si64, 0 : si64]} : (!torch.vtensor<[4,2],si64>) -> !torch.vtensor<[2,4],si64> 
    %12 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__8> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %13 = torch.operator "onnx.Reshape"(%11, %12) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[2,4],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[8],si64> 
    %14 = torch.operator "onnx.Cast"(%13) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[8],si64>) -> !torch.vtensor<[8],si64> 
    %15 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__9> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %16 = torch.operator "onnx.Pad"(%arg0, %14, %15) {torch.onnx.mode = "constant"} : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    return %16 : !torch.vtensor<[?,?,?,?],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _: "0x080000000400000000000000",
      __1: "0x080000000000000000000000010000000000000002000000000000000300000000000000",
      __2: "0x080000000000000000000000",
      __3: "0x08000000FFFFFFFFFFFFFFFF0200000000000000",
      __4: "0x080000000000000000000000",
      __5: "0x08000000FFFFFFFFFFFFFFFF",
      __6: "0x080000000100000000000080",
      __7: "0x08000000FFFFFFFFFFFFFFFF",
      __8: "0x08000000FFFFFFFFFFFFFFFF",
      __9: "0x080000000000C03F"
    }
  }
#-}
```

This gets converted to the following Torch IR:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} {
    %float1.500000e00 = torch.constant.float 1.500000e+00
    %int-9223372036854775807 = torch.constant.int -9223372036854775807
    %int-1 = torch.constant.int -1
    %int7 = torch.constant.int 7
    %int6 = torch.constant.int 6
    %int5 = torch.constant.int 5
    %int3 = torch.constant.int 3
    %int8 = torch.constant.int 8
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %int4 = torch.constant.int 4
    %int0 = torch.constant.int 0
    %0 = torch.vtensor.literal(dense<[0, 1, 2, 3, 0, 0, 0, 0]> : tensor<8xsi64>) : !torch.vtensor<[8],si64>
    %1 = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
    %3 = torch.aten.slice.Tensor %2, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
    %4 = torch.aten.transpose.int %3, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
    %5 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
    %6 = torch.aten.view %4, %5 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
    %7 = torch.aten.slice.Tensor %6, %int0, %int0, %int1, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %8 = torch.aten.item %7 : !torch.vtensor<[1],si64> -> !torch.int
    %9 = torch.aten.slice.Tensor %6, %int0, %int1, %int2, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %10 = torch.aten.item %9 : !torch.vtensor<[1],si64> -> !torch.int
    %11 = torch.aten.slice.Tensor %6, %int0, %int2, %int3, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %12 = torch.aten.item %11 : !torch.vtensor<[1],si64> -> !torch.int
    %13 = torch.aten.slice.Tensor %6, %int0, %int3, %int4, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
    %15 = torch.aten.slice.Tensor %6, %int0, %int4, %int5, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %16 = torch.aten.item %15 : !torch.vtensor<[1],si64> -> !torch.int
    %17 = torch.aten.slice.Tensor %6, %int0, %int5, %int6, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %18 = torch.aten.item %17 : !torch.vtensor<[1],si64> -> !torch.int
    %19 = torch.aten.slice.Tensor %6, %int0, %int6, %int7, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %20 = torch.aten.item %19 : !torch.vtensor<[1],si64> -> !torch.int
    %21 = torch.aten.slice.Tensor %6, %int0, %int7, %int8, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %22 = torch.aten.item %21 : !torch.vtensor<[1],si64> -> !torch.int
    %23 = torch.prim.ListConstruct %14, %22, %12, %20, %10, %18, %8, %16 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %24 = torch.aten.constant_pad_nd %arg0, %23, %float1.500000e00 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?,?,?],f32>
    return %24 : !torch.vtensor<[?,?,?,?],f32>
  }
}
```

***All of these operations are useless***. They exist only to reverse (and
change the lexicographic ordering of) the padding ints between the torch and
ONNX pad conventions, which is then subsequently undone by our ONNX->Torch
lowering (reflected in the ordering of the generated list construct).

With the added folders in this patch, the torch IR becomes:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.5.0"} {
    %float1.500000e00 = torch.constant.float 1.500000e+00
    %int0 = torch.constant.int 0
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %int1 = torch.constant.int 1
    %0 = torch.prim.ListConstruct %int0, %int1, %int2, %int3, %int0, %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.constant_pad_nd %arg0, %0, %float1.500000e00 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?,?,?],f32>
    return %1 : !torch.vtensor<[?,?,?,?],f32>
  }
}
```
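
The round trip being folded away is just the bookkeeping between the two pad
layouts; a minimal Python sketch of that reordering (illustrative only,
assuming a full-length pad list; the helper names are made up):

```python
def torch_pad_to_onnx_pads(pad, rank):
    """torch.constant_pad_nd lists (lo, hi) pairs starting from the last dim;
    onnx.Pad lists all the lo values (dims 0..N-1) followed by all the hi values."""
    lows, highs = [0] * rank, [0] * rank
    for i in range(len(pad) // 2):
        dim = rank - 1 - i                    # pairs run from the last dim backwards
        lows[dim], highs[dim] = pad[2 * i], pad[2 * i + 1]
    return lows + highs

def onnx_pads_to_torch_pad(pads):
    rank = len(pads) // 2
    lows, highs = pads[:rank], pads[rank:]
    pad = []
    for dim in reversed(range(rank)):
        pad += [lows[dim], highs[dim]]
    return pad

p = [0, 1, 2, 3, 0, 0, 0, 0]                  # the folded pad list above
assert onnx_pads_to_torch_pad(torch_pad_to_onnx_pads(p, rank=4)) == p
```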
2024-10-24 12:09:00 -05:00
lingzhiz1998 2f9a68cc1e
Add canonicalization pattern for maxpool3d with indices op (#3704)
As discussed in https://github.com/llvm/torch-mlir/pull/3652, we should
replace `max_pool3d_with_indices` with `max_pool3d` when the indices result has no users.
2024-10-23 18:31:20 +05:30
zjgarvey 55ff110dc2
[MLIR][TORCH] Only unroll prim loop-like ops within a `torch.shape.calculate` region (#3812)
Reports a match failure for the pattern `FullyUnrollPrimLoop` when the
loop op is not in a region defined by a `torch.shape.calculate` op.

This is needed to avoid unrolling prim loops generated by ONNX IR, since
we are applying shape refinement in the
`torch-onnx-to-torch-backend-pipeline` introduced in fa4794d.

See also the discussion in
<https://github.com/iree-org/iree/pull/18867#discussion_r1811101655>
2024-10-23 13:38:55 +05:30
zjgarvey 140cad5659
Add More Scalarize Shapes Patterns (#3810)
### new patterns:

1. Propagates `aten.broadcast_to` ops of a single value to an
`aten.full` op
2. Propagates arithmetic operations through a templated class which
associates some tensor arithmetic ops to their integer-scalar
counterparts. These are a major blocker right now, since some models
have a bunch of rank 0 arithmetic being done with tensor ops. See the
lit test for an interesting example that pads an input to the smallest
shape which will become divisible by twelve in `dim0`. If you think this
is convoluted, you haven't been staring at ONNX generated IR long
enough.
3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to
`false`. See the comment in that conversion pattern for more
justification as to why it is acceptable to make this assumption here.
This is another major blocker for models, since this lack of folding
propagates to lack of folding for subsequent `where.self` operations.
4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern`

### other changes:
 
1. Add two new anchor ops: `AtenArangeStartStepOp` and
`Torch::RuntimeAssertOp`. I've checked all possible sources of the
runtime assert ops and it is always shape related. The Arange op only
takes int inputs, and these are all shape related. Adds a size check to
getting a list from literal ops.
2. Improved folders for int arithmetic ops to fold some common patterns.
3. adds the ability to get some values from scalar-tensor ops to
getListFromTensor.
4. further cleans up getListFromTensor for readability.

### points to scrutinize:

1. I made the choice to scalarize `div.Tensor` (int dtype result) to
`floordiv.int`. This is because our shape computations involving this
kind of arithmetic are never negative in practice, and we don't have a
"round towards zero" scalar int divide counterpart (see the short sketch
after this list for why the distinction only matters for negative operands).
2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if
someone happens to add a runtime assert in the future that doesn't boil
down to shapes, then it would add to the worklist considerably. We might
be able to get around this by adding "NoMemoryEffect" to ops which are
"ReadOnly" so that the inputs for the runtime asserts get cse'd with
existing elements of the worklist before we even get to this pass.
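
A quick reminder, in plain Python, of why the rounding mode only matters for
negative operands (illustrative only):

```python
import math

a, b = -7, 2
assert a // b == -4               # floor division: rounds toward negative infinity
assert math.trunc(a / b) == -3    # "round towards zero" division
# For the non-negative operands seen in shape arithmetic, the two agree:
assert 7 // 2 == math.trunc(7 / 2) == 3
```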
2024-10-21 19:42:39 -05:00
zjgarvey a83e106f92
Rework Scalarize Shapes Pass (#3799)
This is a first step towards reworking the scalarize-shapes pass which
has been integral to our ONNX frontend path detangling shape
computations.

## Purpose:

1. Restrict the scope of the pass to only apply to op sequences which
are used to compute shapes.
2. Make the pass more efficient by applying patterns in an appropriate
order for scalarization propagation.
3. Report failed scalarization patterns for easier debugging (Not yet
implemented). I can't seem to find a good path for this right now to
capture the right diagnostics. I'd like to defer this addition to a
later patch so we can add some high-value patterns to this pass in the
meantime.

With these changes, some reworking of the conversions themselves will be
necessary.

1. The removal of the SqueezeDim fold pattern was an appropriate fix to
avoid folding a pattern that may be needed to propagate further. The
reversal of pattern application order uncovered this bug. The addition
of rank 0 item logic was added to replace the functionality needed from
the squeeze dim pattern.
2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to
allow processing value tensor literals without immediately materializing
the ints. This should factor out a significant portion of code that was
used in specific cases to handle constants.

## RFC 1:

Currently, we are going to add all prim list of int ops to the worklist.
Can anyone identify problems with uniformly anchoring on prim lists of
ints? E.g. Does there exist a Torch Op satisfying all of the following
conditions:

1. Accepts a list of constant ints, LIST, as an input
2. The role of LIST is **not** shape related. All the examples I can
think of are indeed shape related: padding ints passed to a pad op,
kernel size ints passed to a conv op, size ints passed to a view op,
etc.
3. The LIST is not already built entirely from scalars.

If there does not exist a torch op satisfying all three of those
conditions, I think it will be safe to "anchor" on prim lists of ints.

### Conclusion for RFC 1: 

I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for
all references of `AnyTorchListOfTorchIntType` and verified this will
not be problematic to apply in any of those cases.

## RFC 2:

What should I use to report failed scalarization?

My naive idea was just to walk back through the func op after applying
the patterns and check whether anything in the worklist is still a
tensor. If so, emit/log a warning. It certainly works, since you can
just look at the warnings and start debugging from the last printed
warning upwards, but there has to be a better way to handle this without
walking back through the func.func op.

### Conclusion for RFC 2:

I tried a few things without much success. The fundamental problem is
that identifying the cause of a failed scalarization could be myriad:

1. We could be missing a pattern for an op entirely: E.g., a pattern we
need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp ->
AtenMulIntOp).
2. We could fail a scalarization pattern because it should fold instead.
This is specifically the case for rank0 where.self ops. These ops MUST
fold, or we need to have custom lowering logic for the rank 0 case.
3. Walking through the func op a second time and emitting a warning for
ops that have tensor result types seems to give locations that are
inconsistent or hard to track in the converted IR. Doing this on IR that
doesn't apply any patterns seems to give decent information, but it's
still dramatically insufficient considering how complex these patterns
can get, and still takes manually reading IR to try and figure out what
is really blocking the simplification.

I'd like to skip out on fleshing out the error reporting for now and
come back to it after iterating a few times on the patterns.
2024-10-21 12:47:19 -05:00
Vivek Khandelwal fa4794dae2
[MLIR][TORCH] Add torch-onnx-to-torch-backend pipeline (#3801)
This commit adds the torch-onnx-to-torch-backend pipeline, which
converts Torch ONNX IR to Torch Backend IR.

This commit also moves the `ScalarizeShapes` pass from the
`torch-backend-to-linalg-on-tensors-backend-pipeline` to the
`torch-onnx-to-torch-backend` pipeline, since the primary goal of
this pass is to scalarize the shapes in the IR coming from
ONNX models.
2024-10-21 11:20:44 -05:00
yyp0 dc7a1ff7d9
[Torch] add fold logic for some ops (#3794) 2024-10-16 16:00:58 +08:00
zjgarvey 1e431c6a90
Add AtenSliceTOp Canonicalization to SimplifyShapeCalculations pass (#3791)
Some ops were failing to infer the static component of partially dynamic
shapes, and the cause was a missing aten.slice.t pattern.

The lit test included here is an IR dump created before
DropAbstractInterpCalculations for an unflatten op that was failing to
infer shapes before the change.
2024-10-14 14:41:31 -05:00
zjgarvey ab62f35373
Add more patterns to scalarize-shapes pass (#3781)
- Adds patterns for propagating shapes through AtenWhereSelf and
AtenEqTensor
- Adds a fold pattern for a rank-0 squeezeDim of a full op
- Adds support for getting a list from a splat ValueTensorLiteralOp for
materializing scalar comparisons in where.self and eq.tensor

With a bit of hammering, these changes should unblock several IREE
inference failures.
2024-10-11 11:15:17 -05:00
zjgarvey 2665ed343b
adds a few common patterns to scalarize shapes pass (#3779)
This patch adds two things:

1. support for folding scalar patterns like [1]---squeeze--->[]
---unsqueeze--->[1].
2. a canonicalizer for aten.view that applies when we can statically or
dynamically (through the scalarized view shapes) infer that it is a
flatten or unflatten op in the last dim.

I'm not sure if this is the right place to be adding such a view
canonicalizer. Catastrophically, there is a decomposition from flatten
and unflatten into aten.view. Until this gets deleted (and it definitely
should be deleted), I felt like this would be an appropriate temporary
home. We run scalarize shapes after lowering to the backend contract
(i.e., decomposing), and scalarize shapes is required to be able to
infer dynamic dims coming from size int ops.
2024-10-10 10:16:45 -05:00
Rob Suderman 2374b9e02d
Bump to llvm/llvm-project@e813750354 (#3765)
Includes stablehlo bump
2024-10-04 12:08:35 -07:00
Prathamesh Tagore 617c1c76ce
[torch.bind_symbolic_shape] Fix verifier for shapeSymbol detection (#3751)
The op can be valid with no attached shape symbols if they are not
required by the corresponding affine map. Fix the verifier to consider the
number of arguments for both.
2024-10-02 05:55:54 -07:00
Srinath Avadhanula 0a788e0467
Decompose aten.fmod into aten.mul,sub,div etc. (#3689)
As titled, create a new decomposition for `aten.fmod.Tensor` to
`aten.div`, `aten.trunc`, `aten.mul` and `aten.sub`. Note that we only
use `aten.trunc` for floating point operations. This further gets
decomposed to `aten.where` etc. by other existing decompositions.

This decomposition now makes the TOSA pipeline pass for a simple model with
`aten.fmod`, while it makes `stablehlo` fail. For now, we disallow this
decomposition for `stablehlo`.
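
A minimal PyTorch sketch of the floating-point decomposition (an assumed
equivalent formulation for illustration, not the pattern code itself):

```python
import torch

def decomposed_fmod(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """fmod(a, b) = a - trunc(a / b) * b (float case; trunc is skipped for ints)."""
    return torch.sub(a, torch.mul(torch.trunc(torch.div(a, b)), b))

a = torch.tensor([5.5, -5.5, 7.0])
b = torch.tensor([2.0, 2.0, -3.0])
assert torch.allclose(decomposed_fmod(a, b), torch.fmod(a, b))
```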

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
2024-09-09 09:00:11 -07:00
Ze Zhang b3942ff984
Add canonicalize pattern for aten.mul.int and aten.floordiv.int (#3680)
This PR adds `floordiv` to `PY_BUILTIN_TO_TORCH_OP`. For the
`aten.mul.int` and `aten.floordiv.int` ops, we add new canonicalization
patterns as follows:

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.mul.int %1, %const-6
```

Will be replaced by

`torch.aten.mul.int %input, %const-30`


And 

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.floordiv.int %1, %const-5
```
Will directly return `%input`


This PR also relaxes the `float` type constraint in TorchToTosa for the
`AtenRsubScalarOp` conversion.



To test:

`cmake --build build --target check-torch-mlir-all`
2024-09-03 09:13:59 -07:00
Muhammad Abubakar 98e08023bb
Bump llvm to f9031f00f2c9 (#3672)
As title

---------

Co-authored-by: Muhammad Abubakar <jane.doe@getcruise.com>
2024-08-28 11:29:10 -07:00
Rob Suderman f9766c89f6
[onnx] Handle `torch.aten` for inner product case (#3634)
The following case was failing to lower for einsum. This fixes up the
inner product issue.
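
For reference, the inner-product einsum case looks like this in PyTorch (an
assumed minimal reproducer, not the original test):

```python
import torch

a, b = torch.rand(5), torch.rand(5)
# Inner product: all indices are contracted away, leaving a 0-d result.
out = torch.einsum("i,i->", a, b)
assert torch.allclose(out, (a * b).sum())
```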
2024-08-24 11:41:25 -07:00
Vivek Khandelwal 0a86deb59a
build: manually update PyTorch version (#3627)
Set PyTorch and TorchVision version to nightly release 2024-08-18.
This commit also updates the `scaled_dot_product_attention` op.
A new attribute, `enable_gqa`, has been added; as of now, only its
default value is supported.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-08-19 12:03:56 +05:30
Felix Schneider 0314188dbe
[torch] Basic support for per-channel quantized graphs (#3623)
This patch adds basic support for lowering graphs with per-channel
quantization. Per-channel quantized ops have to be excluded from
`FuseQuantizedOps` for now but can be used in QDQ quantized form.

Using this patch, we're able to import and execute (on the linalg
backend) graphs with per-channel quantization applied using the "new"
PyTorch 2.0 Export Quantization.
2024-08-10 15:51:09 +02:00
Rob Suderman fd98476f77
[torch] Unpacking sometimes misses shape inference (#3609)
It is possible that the unpacked tensor does not match the previously inferred
shapes. This is pretty common when ingesting from the `onnx` frontend.
2024-08-08 16:17:31 -07:00
yyp0 22cd4441e7
[Torch] Add support for static uneven divisible AdaptiveAvgPool2d (#3566)
The static uneven divisible AdaptiveAvgPool2d case means that although the
input size is not an integer multiple of the output size, the kernel and
stride sizes can still be fixed (not dynamic). The derivation of the kernel
and stride sizes is consistent with
torch/_decomp/decompositions.py:adaptive_avg_pool2d, as described in the
following:

1. Stride Size
First, derive the start index of each reduce operation from the output size
(`n`): `start_index = ([0, 1, ..., n - 1] * input_size) // output_size`. For
each index `k`, if `k * (input_size % output_size) < output_size`, then the
stride between the current and previous start index equals
`input_size // output_size`. So if `(n - 1) * (input_size % output_size) <
output_size`, the stride stays static throughout the whole AdaptiveAvgPool2d
computation, equal to `input_size // output_size`.

2. Kernel Size
torch/_decomp/decompositions.py:adaptive_avg_pool2d calculates a static
kernel size when the input/output sizes satisfy either of two conditions:
`input_size % output_size == 0` or
`output_size % (input_size % output_size) == 0`. If
`input_size % output_size == 0`, the kernel size equals
`input_size // output_size`; otherwise it is `input_size // output_size + 1`.
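
A small Python check of the derivation above (illustrative only; it mirrors
the closed forms, not the decomposition code):

```python
def adaptive_windows(input_size: int, output_size: int):
    """Per-output (start, end) pooling windows along one dimension."""
    starts = [(i * input_size) // output_size for i in range(output_size)]
    ends = [((i + 1) * input_size + output_size - 1) // output_size  # ceil div
            for i in range(output_size)]
    return starts, ends

input_size, output_size = 7, 3                # not evenly divisible
starts, ends = adaptive_windows(input_size, output_size)
rem = input_size % output_size

# Static stride when (n - 1) * (input_size % output_size) < output_size.
if (output_size - 1) * rem < output_size:
    assert {b - a for a, b in zip(starts, starts[1:])} == {input_size // output_size}

# Static kernel when rem == 0 or output_size % rem == 0.
if rem == 0 or output_size % rem == 0:
    expected = input_size // output_size + (0 if rem == 0 else 1)
    assert {e - s for s, e in zip(starts, ends)} == {expected}
```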
2024-08-01 11:37:53 +08:00
Rob Suderman 7f475e174e
Add extf-trunc f32-f64-f32 elision (#3579)
Torch has all scalars represented as i64 and f64 types, which results in
extraneous trunc-extf operations. We can rework this by eliding the
widen-narrow cases.
2024-07-31 16:50:00 -07:00
Yuanqiang Liu 003b06dfa1
[Torch] enhance naryFolderHelper to support mixed dtypes (#3559)
* so that it can support mixed-dtype folds like `i64 + f64 => f64`.
* also unify `aten.log`'s folder code to use `naryFolderHelper`.
2024-07-24 17:54:59 +08:00
Yuanqiang Liu aad1604046
[Torch] enhance fold of aten.squeeze.dim (#3558) 2024-07-24 14:13:48 +08:00
Ze Zhang d1e172f418
Register fake_quantize_cachemask ops and add their decompose patterns (#3556)
Test:

`cmake --build build --target check-torch-mlir-all`
2024-07-23 11:33:12 -07:00
Yuanqiang Liu 21ad890009
[Torch] enhance fold of aten.slice.Tensor (#3557)
so that it can fold slices with any static shape.
2024-07-23 22:53:03 +08:00
zjgarvey 0fb8b017d8
Adds misc fixes for some padding related issues (#3528)
This patch adds a few misc pad op related changes:

1. Addresses issue <https://github.com/llvm/torch-mlir/issues/3457>
2. Addresses issue <https://github.com/llvm/torch-mlir/issues/3442>
3. Fixes the padding order for asymmetrically padded onnx.Conv ops
4. Enables passing quantization through those onnx.Conv op pre-paddings
5. Modifies the torch-to-linalg lowering of AtenReplicationPad2d op to
enable support for input rank != 4

Unfortunately, even with all of these changes, the e2e tests for the
ReplicationPad2d still fail the onnx config, since the torch export
procedure for rearranging the pad order is complicated enough that the
padding ints end up not being able to fold back to constants.
2024-07-11 20:01:45 -05:00
Ze Zhang d466d5b809
Register fake_quantize related ops (#3522)
Register `aten.fake_quantize_per_channel_affine` and
`aten.fake_quantize_per_tensor_affine.tensor_qparams` ops

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-07-05 11:02:03 -07:00
Christopher McGirr 7e6d76e997
[Torch] Fix torch.constant.int operation parsing (#3476)
Due to the custom operation parser, the printer and the parser were expecting
two different forms: one with the attribute dictionary before the value and
the other with it after.

Following the format of the other constant ops, constant.int now follows
the `value attr-dict` format. Updated the parser accordingly.
2024-06-28 16:06:52 +02:00
Sambhav Jain 09f502667b
`AtenTensorOp::fold` should not fold when result type is not fully specified (#3494)
In one of our downstreams, we encountered an internal assertion failure
in an intermediate pass from `AtenTensorOp::fold` invocation:
```
external/llvm-project/llvm/include/llvm/Support/Casting.h:650: decltype(auto) llvm::dyn_cast(const From &) [To = mlir::torch::Torch::NonValueTensorType, From = mlir::Type]: Assertion `detail::isPresent(Val) && "dyn_cast on a non-existent value"' failed.
```

for this snippet in the IR:
```
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,1,15360],f32>}
...
    %218 = torch.aten.size %arg1 : !torch.tensor -> !torch.list<int>
    %219 = torch.aten.tensor %218, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
```

Turns out this was
[fixed](https://github.com/llvm/torch-mlir/pull/3189/files#diff-dc8ed165c207918e606490eee3984b1ad51d7034e6aac36fc046bf47f6f03f4fR3719)
eventually (and we were on an old hash of torch-mlir). This PR submits
just the lit test for test coverage on that specific change:
```c++
OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  // lit test this
  if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
    return nullptr;
  ...
```
2024-06-24 15:22:50 -07:00
Peiming Liu ba16bad8c7
[torch-mlir] bump stablehlo/llvm version (#3471)
Update to llvm/llvm-project@5207632f86
Update to openxla/stablehlo@d41390c3a7
2024-06-18 16:59:53 -07:00
zjgarvey de28c8540b
[ONNX] add int16 quantization support (#3446)
There is currently no int16 quantization support in torch. This patch
adds a new MLIR type corresponding to the missing "torch.qint16" type,
and enables lowering of quantization-related ONNX ops using int16 types.

In follow-up patches, custom quantization logic for ops like
aten.matmul/aten.mm/aten.convolution may need to be revisited to allow
support for qint16. The passes in FuseQuantizedOps.cpp may also need
slight modifications.
2024-06-12 10:37:22 +05:30
Sambhav Jain d0a818a03e
Representing Symbolic Shape Expressions in Torch Dialect (#3372)
Torch Dialect with symbolic shape expressions:
```mlir
module {                                                                                                                                                                                                     
  func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {                                                                                   
    %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int                                                                                                                                    
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int                                                                                                                                   
    %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int                                                                                                                                    
    
    torch.bind_symbolic_shape %arg0, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    torch.bind_symbolic_shape %arg1, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    
    %3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                                  
    torch.bind_symbolic_shape %3, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                               
    torch.bind_symbolic_shape %4, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>                                               
    %int1 = torch.constant.int 1                                                                                                                                                                             
    %6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>                                                                                                          
    torch.bind_symbolic_shape %6, [%0, %1, %2], #affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>                                                                            
    
    return %6 : !torch.vtensor<[?,?,3],f32>                                                                                                                                                                  
  }                                                                                                                                                                                                          
}              
```

For reference, this is the TorchDynamo exported program with symbolic
shape expressions that the above Torch dialect program is imported from:
```py
ExportedProgram:                                                                                                                                                                                             
    class GraphModule(torch.nn.Module):                                                                                                                                                                      
        def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"):                                                                                                                                         
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x)                                        
            tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x);  x = None                                                                                                                               
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y)                                     
            sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y);  y = None                                                                                                                         
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1)                       
            cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1);  tanh = sigmoid = None                                                                                      
            return (cat,)                                                                                                                                                                                    
                                                                                                                                                                                                             
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])                                               
Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)} 
```

Huge credit to @stellaraccident for the inputs that helped evaluate the
various design options and arrive at the representation of choice.


- [x] Op definitions for symbolic_int and bind_symbolic_shape ops
- [x] fx_importer updates to import range constraints + create
symbolic_int ops
- [x] fx_importer changes for AffineMapAttr building + adding
bind_symbolic_shape ops
- [x] custom printer/parser for inlined AffineMap expressions in mlir
assembly
- [x] Dialect lit test
- [x] fx_importer python lit tests
- [ ] Cleanup pass to remove these ops (can add in a follow-on)
2024-06-07 04:04:03 -07:00
Angel Zhang 2e194e13d6
[Torch] Fix bugs for `Torch::AtenOneHotOp` (#3350)
This PR fixes the bugs for `Torch::AtenOneHotOp` by:

1) Using `Torch::kUnknownSize` as the default value for `numClasses` in
   the pattern matching stage in `DecomposeAtenOneHotOp`
2) Adding `AtenIntScalarOp` to the patterns in `TorchToArith`
3) Handling both `int` and `float` types for `off` and `on` values in
`TorchOnnxToTorch` conversion

It also includes:

1) A new test in `TorchToArith/basic.mlir`, for `torch.aten.Int.Scalar`,
and
2) A new test in `decompose-complex-ops.mlir`, for `torch.aten.one_hot`

**Dependencies**

This PR is dependent on #3334.
2024-05-22 17:19:08 +00:00
Aaron St George ba32b9cee7
Don't fold `aten.clone` if result isn't same type as input (#3347)
Similar to https://github.com/llvm/torch-mlir/pull/2824, we were seeing
some assertion failures after the checks around folders were tightened up
in LLVM: https://github.com/llvm/llvm-project/pull/75887.
This PR essentially moves the logic that used to be applied at the LLVM
level into the folder, which seems to be the suggested fix.
2024-05-16 00:07:45 +08:00
zjgarvey 75d1d72059
Generalize Operand Quantization in FuseQuantizeOps (#3327)
This change enables more customization with operand quantization, and
generalizes the patterns QuantizeOperands and QuantizeTransposeOperands
to QuantizeOperandsPastCommutingOps.

This allows for passing quantization through operations which are
functionally unaffected by quantization, such as view-like ops. The
purpose of this change is to address a myriad of quantization issues
seen in quantized onnx models that have some reshape-like operations
sandwiched in between a dequant and something like a matmul (whose other
operand is immediately quantizable).
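
Conceptually the rewrite moves the dequantize past shape-only ops so that the
consumer sees a quantized operand; a rough numerical sketch of why that is
valid (hypothetical values, not the pass itself):

```python
import torch

q = torch.randint(-128, 127, (2, 8), dtype=torch.int8)
scale, zero_point = 0.05, 3

def dequant(t):
    return (t.to(torch.float32) - zero_point) * scale

# Original graph: dequantize, then a view-like op feeding the matmul operand.
a = dequant(q).reshape(4, 4)
# After the rewrite: the view-like op acts on the quantized tensor first.
b = dequant(q.reshape(4, 4))
assert torch.equal(a, b)   # reshape commutes with elementwise (de)quantization
```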
2024-05-12 20:49:59 -07:00
Benoit Jacob bce800a3f4
Integrate llvm-project at dabdec1001dc368373dd581cf72f37a440873ce3 (#3300)
Co-authored-by: Jacques Pienaar <jpienaar@google.com>
2024-05-08 14:43:06 -04:00
Vivek Khandelwal e60160d793
Revert "Decompose AtenNonzeroOp" (#3289)
Reverts llvm/torch-mlir#3281
2024-05-06 09:52:04 -07:00
Xida Ren (Cedar) 1af00e6040
Decompose AtenNonzeroOp (#3281)
This fixes some onnx lit tests not lowering to linalg in
https://github.com/nod-ai/SHARK-Turbine/issues/450
2024-05-05 21:59:25 +08:00
Ze Zhang 11cd7cd9e7
Folder and Canonicalizer for PrimsConvertElementTypeOp and AtenMaxPool2dWithIndicesOp (#3272)
While playing with TorchDynamo on ResNet18, I noticed the following issues:

- `prims.convert_element_type` can’t be canonicalized even if the input
and the output share the same type

- `aten.max_pool2d_with_indices` is always used instead of
`aten.max_pool2d`, even if the second returned output (indices) has no
user

This PR fixes the above issues by adding a folder to
`PrimsConvertElementTypeOp` and a canonicalizer to
`AtenMaxPool2dWithIndicesOp`.
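
The two rewrites, stated in PyTorch terms (an assumed illustration; the actual
folder/canonicalizer operate on Torch dialect IR):

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 8, 8)

# Folder: converting to the dtype the tensor already has is the identity.
assert torch.equal(torch.ops.prims.convert_element_type(x, torch.float32), x)

# Canonicalizer: when the indices result has no users, max_pool2d_with_indices
# can be replaced by plain max_pool2d.
values, _indices = torch.ops.aten.max_pool2d_with_indices(x, [2, 2])
assert torch.equal(values, F.max_pool2d(x, kernel_size=2))
```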


Lit test:

`cmake --build build --target check-torch-mlir-all`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-05-02 00:03:41 -07:00
Yuanqiang Liu aed2cf3351
[Torch] emit aten.__contains__.str_list and add folder (#3249) 2024-04-29 10:51:17 +08:00