Commit Graph

3194 Commits (42ba541c6887f0fa2d57c896ded219cb420591ed)
 

Author SHA1 Message Date
Felix Schneider 42ba541c68
[fx] Fix importing and tests for quantized conv (#3809)
The fx tracer does not support tracing "real" quantized tensors
currently. A "real" quantized tensor here means a tensor that is created
using a method like `torch.quantize_per_tensor()` and carries the
quantization parameters (scale, zero_point, scheme) in the object.
However, it seems like the DQ-Q type fake quantizatation is now commonly
used as a high level representation of quantized operators and is only
lowered to native quantized ops (if available) in the respective
hardware backend. Quantization of floating point modules in PyTorch is
recently also performed as a graph transformation after
exporting/tracing the original module.

```python
# Examples of "real"/native quantization
tens = torch.randint(-127, 127, (1,), dtype=torch.int8)
torch._make_per_tensor_quantized_tensor(tens, 1, 0)
# tensor([90.], size=(1,), dtype=torch.qint8,
#       quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)

tens = torch.rand((1,))
torch.quantize_per_tensor(tens, 1, 0, torch.qint8)
# tensor([1.], size=(1,), dtype=torch.qint8,
#       quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)

# Example of DQ/Q quantization
import torch.ao.quantization.fx._decomposed
tens = torch.rand((1,))
torch.ops.quantized_decomposed.quantize_per_tensor.default(tens, 1, 0, -128, 127, torch.int8)
# tensor([1], dtype=torch.int8)
```

This means that a typical import flow for a quantized network
into/through torch-mlir would look like this:
`torch.export() -> quantization transformations on fx graph ->
fx_importer` Where the tensors in the graph are normal float/int tensors
and the quantization parameters are carried by the DQ/Q ops. These kinds
of graphs can be traced without issues.

Currently, our quantized convolution tests use the "real" quantized
tensors. This means that with the retirement of the `jit_ir_importer`,
these tests cannot be imported any longer. In summary, I see no reason
to stick to the "real" quantization in these tests, as both PyTorch 2.0
is using DQ/Q quantization and our linalg backend is also using it.

This patch updates our quantized convolution tests to use the DQ-Q
quantization with the ops from `torch.ops.quantized_decomposed`.

Note: For future reference, there seems to be an ongoing consolidation
of the ops for the DQ/Q scheme on the PyTorch side
(https://github.com/pytorch/ao/issues/986#issuecomment-2390296826).
2024-10-22 18:37:57 +02:00
zjgarvey 140cad5659
Add More Scalarize Shapes Patterns (#3810)
### new patterns:

1. Propagates `aten.broadcast_to` ops of a single value to an
`aten.full` op
2. Propagates arithmetic operations through a templated class which
associates some tensor arithmetic ops to their integer-scalar
counterparts. These are a major blocker right now, since some models
have a bunch of rank 0 arithmetic being done with tensor ops. See the
lit test for an interesting example that pads an input to the smallest
shape which will become divisible by twelve in `dim0`. If you think this
is convoluted, you haven't been staring at ONNX generated IR long
enough.
3. Adds a stronger folder for `aten.eq.int` to fold `size.int == 0` to
`false`. See the comment in that conversion pattern for more
justification as to why it is acceptable to make this assumption here.
This is another major blocker for models, since this lack of folding
propagates to lack of folding for subsequent `where.self` operations.
4. Add `AtenSqueezeDim` to the existing `FoldAtenSqueezeOpPattern`

### other changes:
 
1. Add two new anchor ops: `AtenArangeStartStepOp` and
`Torch::RuntimeAssertOp`. I've checked all possible sources of the
runtime assert ops and it is always shape related. The Arange op only
takes int inputs, and these are all shape related. Adds a size check to
getting a list from literal ops.
2. Improved folders for int arithmetic ops to fold some common patterns.
3. adds the ability to get some values from scalar-tensor ops to
getListFromTensor.
4. further cleans up getListFromTensor for readability.

### points to scrutinize:

1. I made the choice to scalarize `div.Tensor` (int dtype result) to
`floordiv.int`. This is because our shape computations involving this
kind of arithmetic are never negative in practice, and we don't have a
"round towards zero" scalar int divide counterpart.
2. Anchoring on `RuntimeAssertOp` sounds really suspicious, and if
someone happens to add a runtime assert in the future that doesn't boil
down to shapes, then it would add to the worklist considerably. We might
be able to get around this by adding "NoMemoryEffect" to ops which are
"ReadOnly" so that the inputs for the runtime asserts get cse'd with
existing elements of the worklist before we even get to this pass.
2024-10-21 19:42:39 -05:00
zjgarvey a83e106f92
Rework Scalarize Shapes Pass (#3799)
This is a first step towards reworking the scalarize-shapes pass which
has been integral to our ONNX frontend path detangling shape
computations.

## Purpose:

1. Restrict the scope of the pass to only apply to op sequences which
are used to compute shapes.
2. Make the pass more efficient by applying patterns in an appropriate
order for scalarization propagation.
3. Report failed scalarization patterns for easier debugging (Not yet
implemented). I can't seem to find a good path for this right now to
capture the right diagnostics. I'd like to defer this addition to a
later patch so we can add some high-value patterns to this pass in the
meantime.

With these changes, some reworking of the conversions themselves will be
necessary.

1. The removal of the SqueezeDim fold pattern was an appropriate fix to
avoid folding a pattern that may be needed to propagate further. The
reversal of pattern application order uncovered this bug. The addition
of rank 0 item logic was added to replace the functionality needed from
the squeeze dim pattern.
2. Rework getListFromTensor to modify a `SmallVector<OpFoldResult>` to
allow processing value tensor literals without immediately materializing
the ints. This should factor out a significant portion of code that was
used in specific cases to handle constants.

## RFC 1:

Currently, we are going to add all prim list of int ops to the worklist.
Can anyone identify problems with uniformly anchoring on prim lists of
ints? E.g. Does there exist a Torch Op satisfying all of the following
conditions:

1. Accepts a list of constant ints, LIST, as an input
2. The role of LIST is **not** shape related. All the examples I can
think of are indeed shape related: padding ints passed to a pad op,
kernel size ints passed to a conv op, size ints passed to a view op,
etc.
4. The LIST is not gotten entirely from scalars already. 

If there does not exist a torch op satisfying all three of those
conditions, I think it will be safe to "anchor" on prim lists of ints.

### Conclusion for RFC 1: 

I just scanned through the `GeneratedTorchOps.td` and `TorchOps.td` for
all references of `AnyTorchListOfTorchIntType` and verified this will
not be problematic to apply in any of those cases.

## RFC 2:

What should I use to report failed scalarization?

Like my dumb idea was just to walk back through the func op after
applying the passes and check if anything in the worklist is still a
tensor. If so, emit/log a warning. It certainly works, since you can
just look at the warnings and start debugging from the last printed
warning upwards, but there has to be a better way to handle this without
walking back through the func.func op.

### Conclusion for RFC 2:

I tried a few things without much success. The fundamental problem is
that identifying the cause of a failed scalarization could be myriad:

1. We could be missing a pattern for an op entirely: E.g., a pattern we
need is scalarizing rank0 arithmetic ops (e.g. AtenMulTensorOp ->
AtenMulIntOp).
2. We could fail a scalarization pattern because it should fold instead.
This is specifically the case for rank0 where.self ops. These ops MUST
fold, or we need to have custom lowering logic for the rank 0 case.
3. Walking through the func op a second time and emiting a warning for
ops that have tensor result types seems to give locations that are
inconsistent or hard to track in the converted IR. Doing this on IR that
doesn't apply any patterns seems to give decent information, but it's
still dramatically insufficient considering how complex these patterns
can get, and still takes manually reading IR to try and figure out what
is really blocking the simplification.

I'd like to skip out on fleshing out the error reporting for now and
come back to it after iterating a few time on the patterns.
2024-10-21 12:47:19 -05:00
Vivek Khandelwal fa4794dae2
[MLIR][TORCH] Add torch-onnx-to-torch-backend pipeline (#3801)
This commit adds the torch-onnx-to-torch-backend pipeline which
converts the Torch Onnx IR to Torch Backend IR.

This commit also moves the `ScalarizeShapes` pass from the
`torch-backend-to-linalg-on-tensors-backend-pipeline` to the
`torch-onnx-to-torch-backend` pipeline since the primary goal of
this pass is to scalarize the shapes in the IR coming from the
Onnx models.
2024-10-21 11:20:44 -05:00
Vivek Khandelwal d2330df58f
build: manually update PyTorch version (#3808)
Set PyTorch and TorchVision version to nightly release 2024-10-20.
2024-10-21 17:26:09 +05:30
Felix Schneider f5d15ab20e
Bump LLVM to llvm/llvm-project@f0b3b6d1 (#3806) 2024-10-20 09:32:21 -07:00
Sayan Saha 09cdbe4c47
Disable building STABLEHLO and specify USE_MATH_DEFINES for windows builds. (#3805)
I'm trying to build python wheel for windows similar to as done for
linux in https://github.com/llvm/torch-mlir-release/ however turns out
the build process on windows is broken without the following fixes:

1. Building stablehlo for windows fails due to
https://github.com/openxla/stablehlo/issues/1772 -- so disabling
stablehlo in `build_windows_ci.sh` that will be used for building the
python wheels.

2. Add `USE_MATH_DEFINES` to resolve
`torch-mlir\lib\Conversion\TorchOnnxToTorch\DefaultDomainGtoP.cpp(709):
error C2065: 'M_LOG10E': undeclared identifier`
2024-10-18 12:04:37 -07:00
David Tanner 02327af998
Adds onnx ConvTranspose support for autopadding. (#3797)
Adds onnx ConvTranspose support for autopadding
(https://github.com/nod-ai/SHARK-ModelDev/issues/839).

- Adds support for attribute auto_pad="SAME_UPPER" or "SAME_LOWER" which
will automatically calculate padding of input based on output shape.
- Adds support, during auto-padding, for output_shape=[H,W] which
overrides the default output shape of input_shape[i]*stride[i] (for
spatial dimensions only).
- Adds lit test for auto-padding.
- Tests are added by https://github.com/nod-ai/SHARK-TestSuite/pull/370


NOTE: ConvTranspose still doesn't support asymmetric padding, therefore
multiple original onnx tests still won't pass.
2024-10-18 12:31:33 -05:00
Vivek Khandelwal 9c7067649b
build: manually update PyTorch version (#3727)
Set PyTorch and TorchVision version to nightly release 2024-10-15.

Tracker issue for the failing tests added to xfail_set in this PR.
Issue: https://github.com/llvm/torch-mlir/issues/3796
This commit disables the failing sparse tensor tests since they are not 
maintained on day-to-day basis and blocks the roll PyTorch update for now.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-10-18 13:32:14 +05:30
yyp0 dc7a1ff7d9
[Torch] add fold logic for some ops (#3794) 2024-10-16 16:00:58 +08:00
penguin_wwy 6b289f29f2
[FxImporter] Added FxImporter test method to be executed via torch.co… (#3795) 2024-10-16 10:32:52 +08:00
Justin Ngo 45bb17ebfe
[TOSA] Add legalization for empty, scatter, slice_scatter, diag_embed (#3792)
- Add Torch to TOSA legalization for the following ops:
  + aten.empty.memory_format
  + aten.scatter.src
  + aten.slice_scatter
  + aten.diag_embed
- Update xfail_sets.py with new e2e results
- Update basic.mlir with new LIT tests


Change-Id: I817ecf207bcfcf97ca54f30c10c76c4f0f4145ae

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-10-15 08:38:02 -07:00
Hanumanth04 895f490cf5
Remove checking for training specific parameters in EmbeddingBag lowering (#3782)
Torch-to-linalg pass fails for `EmbeddingBag` when the training only
specific properties of the operator are set to `true.` For instance,
this operator's `sparse` input/property is training-specific, and if the
value of this property is `true,` the existing lowering bails out.
However, we don't need to check for training-specific parameters and
bailout from the legalization since we don't care about these properties
during the eval/inference mode.

---------

Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>
2024-10-15 09:37:26 -04:00
zjgarvey 1e431c6a90
Add AtenSliceTOp Canonicalization to SimplifyShapeCalculations pass (#3791)
Some ops were failing to infer the static component of partially dynamic
shapes, and the cause was a missing aten.slice.t pattern.

The lit test included here is an IR dump created before
DropAbstractInterpCalculations for an unflatten op that was failing to
infer shapes before the change.
2024-10-14 14:41:31 -05:00
Marius Brehler edd1bbec46
Integrate LLVM at llvm/llvm-project@c13f806 (#3789) 2024-10-14 15:00:45 +02:00
yyp0 b176939808
[Torch] support 1d aten tensor shape and dtype infer (#3776) 2024-10-12 17:51:15 +08:00
zjgarvey ab62f35373
Add more patterns to scalarize-shapes pass (#3781)
-Adds patterns for propagating shapes through AtenWhereSelf and
AtenEqTensor
-Adds fold pattern for a rank0 squeezeDim of a full op 
-Adds support for getting a list from a splat ValueTensorLiteralOp for
materializing scalar comparisons in where.self and eq.tensor

With a bit of hammering, these changes should unblock several IREE
inference failures.
2024-10-11 11:15:17 -05:00
yyp0 7b11dfc0ee
[Torch] support adaptive_max_pool1d when return_indices equals False (#3783) 2024-10-11 23:42:15 +08:00
Ian Wood 8787970afe
[Torch] Fold no-op reshape (#3769)
This was preventing dynamic dims in an ONNX model from being reified (causing the generation of `tensor.cast`s and preventing fusion in iree):

```mlir
%2 = torch.vtensor.literal(dense<[4, 256]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>]
%7 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%8 = torch.aten.reshape %2, %7 : !torch.vtensor<[2],si64>, !torch.list<int> -> !torch.vtensor<[2],si64>
//... chain of foldable ops linking %2 to the `shape` operand of a `torch.aten.broadcast_to ... -> !torch.vtensor<[?,?],si64>`
```
2024-10-10 18:54:27 -07:00
zjgarvey 2665ed343b
adds a few common patterns to scalarize shapes pass (#3779)
This patch adds two things:

1. support for folding scalar patterns like [1]---squeeze--->[]
---unsqueeze--->[1].
2. a canonicalizer for aten.view that applies when we can statically or
dynamically (through the scalarized view shapes) infer that it is a
flatten or unflatten op in the last dim.

I'm not sure if this is the right place to be adding such a view
canonicalizer. Catastrophically, there is a decomposition from flatten
and unflatten into aten.view. Until this gets deleted (and it definitely
should be deleted), I felt like this would be an appropriate temporary
home. We run scalarize shapes after lowering to the backend contract
(i.e., decomposing), and scalarize shapes is required to be able to
infer dynamic dims coming from size int ops.
2024-10-10 10:16:45 -05:00
yyp0 d0041dc310
[stablehlo] support aten.view.dtype lowering (#3778) 2024-10-10 15:50:17 +08:00
Vivek Khandelwal 94f5410913
[LINALG] Add complex tensor support for `create[Zero|One]InitTensor` utility (#3777)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-10-09 16:15:08 +05:30
Stephen Baione d49eabb3fc
Add Op for `torch.aten.unfold` (#3772)
# Description

Implementation of the op for `torch.aten.unfold`: [TorchToLinalg Op
Support #347](https://github.com/nod-ai/SHARK-ModelDev/issues/849)

Documentation of op can be found here: [PyTorch
Docs](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)

For this op, we apply a sliding window of some `size` along a single
`dimension`, with `step` in between iterations.

`Declaration: aten::unfold(Tensor(a) self, int dimension, int size, int
step) -> Tensor(a)`

The resulting `unfolded` tensor modifies the shape of `dimension` to be
equal to the number of blocks that the sliding windows extracts/inserts,
with an additional dimension of `size` appended (the number of cols of
the output tensor directly translates from the size of the sliding
window).

So if we had a tensor of rank 3 (A x B x C), with dimension = 1, size =
2 and step = 2:

    (A x B x C) |=> (A x (B - size) // step + 1 x C x size)

After extracting the window from the input tensor, we insert the (1 x
size) slice into the output tensor. We can make this simpler by mapping
the output indices from the input indices, like they do in the official
implementation:

[PyTorch
Code](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/lowering.py#L1694)
2024-10-08 21:10:43 +00:00
Phaneesh Barwaria 7830c00ca2
onnx.LSTM - bidirectional, layout attr (#3771)
- Support Bidirectional LSTM (utilising the forward LSTM layer with
flipped Inputs and Outputs)
- Support layout 1 
- Support default cases for attr `clip` and `input_forget`
- Support returning partial outputs (1-3)  
- fixes for alt_e2e_tests lstm tests (1,2,3)
2024-10-08 11:29:49 -07:00
jinchen 58489faf7f
torch.aten.squeeze.dim lowering with dynamic dims (#3749)
Address https://github.com/nod-ai/SHARK-ModelDev/issues/846

Assume the dynamic squeezed dim is 1.
2024-10-08 10:37:31 -07:00
Vivek Khandelwal 614fcdd153
[MLIR][TORCH] Add support for 1-d group convolution (#3770)
This commit adds the support for the 1-d depthwise convolution as a
special case of 1-d group convolution.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-10-08 10:48:47 +05:30
Vivek Khandelwal f6721e5999
[MLIR][TORCH] Add support for negative step in aten.slice.Tensor op (#3763)
This commit adds the support for negative step values in
aten.slice.Tensor op. Although, PyTorch does not allow negative step
value for slice op but the Onnx.Slice op supports negative step value
which eventually lowers to torch.aten.slice.Tensor op. Hence, the
support is added for handling those kind of values during the
Torch->Linalg lowering of aten.slice.Tensor op.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-10-08 10:34:27 +05:30
Justin Ngo b08d08682f
[TOSA] Add legalization for fill, flip, and round (#3768)
- Add Torch to TOSA lowering for aten.fill.Scalar/Tensor, aten.flip, and
aten.round
- Fix torchScalarToTosaTensor function to correctly convert Torch scalar
input to TOSA tensor
- Update xfail_sets.py with new e2e results
- Update basic.mlir with LIT tests for new ops


Change-Id: If1e42c2e582710dd8ad0465eed29806fbcdbde41

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-10-07 10:28:26 -07:00
Chi_Liu f4840ed886
[ONNX] Fix onnx.ScatterElements with AtenScatterReduceTwoOp lowering to tm_tensor/linalg_ext dialect (#3754)
- To fix issue onnx.ScatterElements: https://github.com/nod-ai/SHARK-ModelDev/issues/823
- E2E test: https://github.com/nod-ai/SHARK-TestSuite/pull/363
2024-10-05 22:22:41 -07:00
Rob Suderman 53f7532e76
Revert "[TorchToLinalg] perform rank0 elementwise computations outside linalg generic ops (#3762)" (#3767)
Reverted due to downstream model changes. Will reland with fixes post
integration.

This reverts commit 6e8c7bed4b.
2024-10-04 14:48:02 -07:00
Justin Ngo e9ed4af9ce
[TOSA] Add legalization for aten.index_select (#3760)
- Add Torch to TOSA legalization for aten.index_select
- Fix createOneDimTfIndices function in TosaLegalizeCommon.cpp to
correctly convert Torch indices to TF-style indices, which is used in
convertGatherNdOp
- Update e2e tests in xfail_sets.py
- Update basic.mlir with new LIT test for aten.index_select

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I52519246183949353a3cf22f0a685fe3df8ec8ff

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-10-04 12:24:22 -07:00
Rob Suderman 2374b9e02d
Bump to llvm/llvm-project@e813750354 (#3765)
Includes stablehlo bump
2024-10-04 12:08:35 -07:00
zjgarvey 6e8c7bed4b
[TorchToLinalg] perform rank0 elementwise computations outside linalg generic ops (#3762)
This is motivated by the fact that shapes are stored as tensors in ONNX,
and IREE tries to perform tensor arithmetic on the device. This causes
unnecessary dispatches, and makes it harder for the compiler to reason
about shapes.

Here is a small snippet of torch-IR that is typical seen coming from
ONNX models:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,768],f32>, %arg1: !torch.vtensor<[?,?,768],f32>) -> !torch.vtensor<[],si64> {
    %int0 = torch.constant.int 0
    %0 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
    %1 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,768],f32> -> !torch.vtensor<[3],si64>
    %2 = torch.aten.index_select %1, %int0, %0 : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
    %3 = torch.aten.squeeze.dim %2, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %4 = torch.aten.item %3 : !torch.vtensor<[],si64> -> !torch.int
    %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool
    %6 = torch.aten.Int.bool %5 : !torch.bool -> !torch.int
    %7 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
    %8 = torch.prim.NumToTensor.Scalar %6 : !torch.int -> !torch.vtensor<[],i1>
    %9 = torch.prim.NumToTensor.Scalar %7 : !torch.int -> !torch.vtensor<[],si64>
    %10 = torch.prim.NumToTensor.Scalar %4 : !torch.int -> !torch.vtensor<[],si64>
    %11 = torch.aten.where.self %8, %9, %10 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
    return %11 : !torch.vtensor<[],si64>
  }
}
```

Without the change in this PR, the result would be:

```mlir
#map = affine_map<() -> ()>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = tensor.empty() : tensor<i1>
    %6 = linalg.fill ins(%3 : i1) outs(%5 : tensor<i1>) -> tensor<i1>
    %7 = tensor.empty() : tensor<i64>
    %8 = linalg.fill ins(%4 : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %9 = linalg.fill ins(%extracted : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %10 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%6, %8, %9 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%7 : tensor<i64>) {
    ^bb0(%in: i1, %in_1: i64, %in_2: i64, %out: i64):
      %11 = arith.select %in, %in_1, %in_2 : i64
      linalg.yield %11 : i64
    } -> tensor<i64>
    return %10 : tensor<i64>
  }
}
```

With the change in this PR, we would instead get:

```mlir
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = arith.select %3, %4, %extracted : i64
    %6 = tensor.empty() : tensor<i64>
    %7 = linalg.fill ins(%5 : i64) outs(%6 : tensor<i64>) -> tensor<i64>
    return %7 : tensor<i64>
  }
}
```

Some related issues for context:
1. <https://github.com/iree-org/iree/issues/18677>
2. <https://github.com/iree-org/iree/issues/18631>
2024-10-04 11:27:00 -05:00
zjgarvey f08bfc4ff8
[ONNX] simplify shapes fed to broadcast in Expand lowering (#3756)
Addresses ~200 onnx model compile failures in
<https://github.com/nod-ai/SHARK-TestSuite> related to
<https://github.com/iree-org/iree/issues/18631>.

This change simplifies the result of the generated broadcast op
substantially, but reduces the case coverage slightly.

The case which will become unsupported: 
- trying to actually broadcast a dynamic dim that is secretly 1. 

When does this case appear in practical scenarios?
- for a model where onnx shape inference cannot figure out that a dim
should be 1.

Why do I think we should not support this case for now?
1. For all models with dynamic dim expand ops, the previous path
uniformly generates uglier linalg IR (making it harder for IREE to fuse
properly with other ops).
2. For models failing shape inference castastrophically enough to fail
to see a dim is statically 1, we can try to apply constant folding in
the onnx model before importing.

Leaving this as a draft PR, since it may be more appropriate to fix the
compilation failure in IREE rather than torch-mlir.

### Example of broadcast required in previous path:

```mlir
    %300 = linalg.generic {indexing_maps = [#map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%299 : tensor<?x12x?x?xi1>) {
    ^bb0(%out: i1):
      %306 = linalg.index 0 : index
      %307 = linalg.index 3 : index
      %308 = arith.index_cast %285 : i64 to index
      %309 = arith.cmpi eq, %308, %c1 : index
      %310 = arith.select %309, %c0, %306 : index
      %311 = arith.index_cast %286 : i64 to index
      %312 = arith.cmpi eq, %311, %c1 : index
      %313 = arith.select %312, %c0, %307 : index
      %extracted_79 = tensor.extract %reshape_78[%310, %c0, %c0, %313] : tensor<?x1x1x?xi1>
      linalg.yield %extracted_79 : i1
    } -> tensor<?x12x?x?xi1>
```

### Example of broadcast with simplified shape list:

```mlir
    %409 = linalg.generic {indexing_maps = [#map15, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%reshape_135 : tensor<?x1x1x?xi1>) outs(%408 : tensor<?x12x?x?xi1>) {
    ^bb0(%in: i1, %out: i1):
      linalg.yield %in : i1
    } -> tensor<?x12x?x?xi1>
```
2024-10-03 20:11:51 -05:00
Rob Suderman 9ab0db5789
[torch] `torch.aten.complex` operation with lowering (#3738)
Add the operation with lowering to linalg. Includes a test for
end-to-end correctness.
2024-10-03 11:09:52 -07:00
Kyle Wang f0b7ca72f5
Fixed GRU quality issues exposed by e2e tests (#3753)
Issue: https://github.com/nod-ai/SHARK-ModelDev/issues/856

Related tests:
![Screenshot 2024-10-01
175305](https://github.com/user-attachments/assets/0dc0901b-058f-427c-a596-9e806fd38836)
2024-10-02 17:00:19 -04:00
Sambhav Jain f8e4a9a3c2
[Release] Fix binary name for downstream compatibility (#3752)
As of Sep 14, the torch-mlir binary
[wheels](https://github.com/llvm/torch-mlir-release/releases/tag/dev-wheels)
got renamed to `torch-mlir-core` from `torch-mlir`:
![image](https://github.com/user-attachments/assets/152e4977-71ef-4f57-8757-6dc75f72b670)

This was an unintended side-effect of the recent change of
`TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=True`
(https://github.com/llvm/torch-mlir/pull/3711) which skips setting `NAME
= "torch-mlir"` in
[setup.py](https://github.com/llvm/torch-mlir/blob/main/setup.py#L226-L232).

To avoid having multiple downstreams fix their pip deps, this change
allows using the same `torch-mlir` name for binaries, and reserves a
separate `torch-mlir-ext` name for the (less popular) binaries with
extensions enabled.
2024-10-02 11:52:20 -07:00
Samu Tamminen a2bfe47faa
[onnx] Add IDF and TFIDF modes to TFIDF Vectorizer (#3726)
Address https://github.com/nod-ai/SHARK-Turbine/issues/833
2024-10-02 08:17:58 -05:00
Prathamesh Tagore 617c1c76ce
[torch.bind_symbolic_shape] Fix verifier for shapeSymbol detection (#3751)
The op can be valid with no attached shape symbols if they are not
required by the corresponding affine map. Fix the verifier to consider
number of arguments for both.
2024-10-02 05:55:54 -07:00
Marius Brehler b1413a6c7f
Update instructions on creating a virtual env (#3724)
The `python` command is only available on Ubuntu if the
`python-is-python3` package is installed, see
https://packages.ubuntu.com/jammy/python-is-python3 and
https://packages.ubuntu.com/jammy/all/python-is-python3/filelist. As
Python 2 isn't supported anyway, it's safe to point to `python3` here
instead.
2024-10-01 19:12:11 +02:00
Justin Ngo 5eab669c4a
[TOSA] Add legalization for aten.diagonal (#3740)
- Add lowering from Torch to TOSA for aten.diagonal
- Clean up some code
- Update xfail_sets.py with the new e2e results
- Update basic_mlir with the new op mlir test

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I99bed685455752d09ed96edd837c4dfbee152701

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
2024-09-30 08:24:31 -07:00
Yuanqiang Liu 5f74de5ba0
[Stablehlo] support aten.all.dim (#3746) 2024-09-30 15:59:27 +08:00
yyp0 eb4e59e189
[Torch] support binary_cross_entropy_with_logits decomposition (#3741) 2024-09-29 17:41:20 +08:00
jinchen a33d1232c5
[onnx] Fix onnx.Shape lowering with scalar input (#3716)
Address https://github.com/nod-ai/SHARK-Turbine/issues/826
2024-09-27 13:30:02 -07:00
Xida Ren (Cedar) 9938abf25e
AtenCumprodOp (#3737) 2024-09-26 18:17:22 -04:00
yyp0 335cf5f6d0
[stablehlo] support aten_adaptive_max_pool1d lowering (#3728) 2024-09-26 11:42:38 +08:00
Xida Ren (Cedar) aa7e77ee64
Better errmsg upon getScalarTypeForType failure (#3734)
Instead of 
`Unhandled type in getScalarTypeForType`

You now get

Unhandled type in getScalarTypeForType: (type name)
Type properties:
  Is integer: yes
  Bit width: 
...


The root cause is https://github.com/llvm/torch-mlir/issues/3720, at
least for unsigned integer issues.
2024-09-25 16:32:26 +00:00
Vinayak Dev 67732883fa
[torch] Fix unsqueezed output shape in canonicalization of AtenUnflattenIntOp (#3730)
Fixes https://github.com/iree-org/iree/issues/18562.

During canonicalization pass on `AtenUnflattenIntOp`, if the second dim
was statically equal to one, we would create an `AtenAddIntOp` to add
one to the dimension obtained from `op.getDim()`. This, when passed into
`Torch::unsqueezeTensor()`, would make it get interpreted as
non-constant, which would lead to MLIR failing an assertion when
`UnsqueezeOp` would later get lowered into `ExpandShapeOp`, as the
output of the `UnsqueezeOp` would consist of only dynamic dims.

This patch fixes this behavior, by extracting the integer value from the
dim if it was constant, and then emitting a `ConstantIntOp` from
(dim+1). This creates an output with static shape.
2024-09-24 11:45:18 -05:00
Marius Brehler e4f2bdf0db
Document requirements for `torch_mlir_e2e_test` (#3722)
This documents which CMake options must be set to be able to use
`torch_mlir_e2e_test`, required e.g. for
`projects/pt1/tools/e2e_test.sh`.

Makes progress on #3696.
Closes #3719.
2024-09-24 00:06:54 +02:00
giacs-epic 99848265c3
[onnx] Relax constraints on input tensors in `onnx.STFT` conversion to torch dialect (#3676)
- When the signal tensor is real, onnx allows its shape to be
`[batch][length]` as well as `[batch][length][1]`.
- Onnx also allows to specify `frame_length` together with `window` (not
empty), given that it matches the window size.
- Adding checks on signal and result shapes.
2024-09-23 12:09:29 +05:30