Commit Graph

387 Commits (b1413a6c7fbfcbe7036a12b788f0cdfce729a105)

Author SHA1 Message Date
Srinath Avadhanula 0a788e0467
Decompose aten.fmod into aten.mul,sub,div etc. (#3689)
As titled, create a new decomposition for `aten.fmod.Tensor` to
`aten.div`, `aten.trunc`, `aten.mul` and `aten.sub`. Note that we only
use `aten.trunc` for floating point operations. This further gets
decomposed to `aten.where` etc. by other existing decompositions.

This decomposition now makes TOSA pass for a simple model with
`aten.fmod` while it makes `stablehlo` fail. For now, we disallow this
decomposition for `stablehlo`

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
2024-09-09 09:00:11 -07:00
Ze Zhang b3942ff984
Add canonicalize pattern for aten.mul.int and aten.floordiv.int (#3680)
This PR add `floordiv` to the `PY_BUILTIN_TO_TORCH_OP`. For
`aten.mul.int` and `aten.floordiv.int` ops, we add new Canonicalization
Patterns as follow:

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.mul.int %1, %const-6
```

Will be replaced by

`torch.aten.mul.int %input, %const-30`


And 

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.floordiv.int %1, %const-5
```
Will directly return `%input`


This PR also relaxes the `float` type constraint in TorchToTosa for the
`AtenRsubScalarOp` conversion.



To test:

`cmake --build build --target check-torch-mlir-all`
2024-09-03 09:13:59 -07:00
Muhammad Abubakar 98e08023bb
Bump llvm to f9031f00f2c9 (#3672)
As title

---------

Co-authored-by: Muhammad Abubakar <jane.doe@getcruise.com>
2024-08-28 11:29:10 -07:00
Rob Suderman f9766c89f6
[onnx] Handle `torch.aten` for inner product case (#3634)
The following case was failing to lower for einsum. This fixes up the
inner product issue.
2024-08-24 11:41:25 -07:00
Vivek Khandelwal 0a86deb59a
build: manually update PyTorch version (#3627)
Set PyTorch and TorchVision version to nightly release 2024-08-18.
This commit also updates the `scaled_dot_product_attention` op. 
A new attribute `enable_gqa` has been added. As of now, only the
default value for the same is supported.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-08-19 12:03:56 +05:30
Felix Schneider 0314188dbe
[torch] Basic support for per-channel quantized graphs (#3623)
This patch adds basic support for lowering graphs with per-channel
quantization. Per-channel quantized ops have to be excluded from
`FuseQuantizedOps` for now but can be used in QDQ quantized form.

Using this patch, we're able to import and execute (on the linalg
backend) graphs with per-channel quantization applied using the "new"
PyTorch 2.0 Export Quantization.
2024-08-10 15:51:09 +02:00
Rob Suderman fd98476f77
[torch] Unpacking sometimes misses shape inference (#3609)
It is possible that the unpacked tensor does not match the same inferred
shapes. This is pretty common when ingesting form the `onnx` frontend.
2024-08-08 16:17:31 -07:00
yyp0 22cd4441e7
[Torch] Add support for static uneven divisible AdaptiveAvgPool2d (#3566)
The static uneven divisible AdaptiveAvgPool2d means that although the
input size is not an integer multiple of ouput size, but the kernel and
stride size can also be fixed (not dynamic). The derivation logic of
kernel and stride size is consistent with
torch/_decomp/decomposations.py:adaptive_avg_pool2d as described in the
following:

1. Stride Size
Firstly , derive the start index in each reduce operation according to
the output size (`n`), `start_index = ([0, 1, ..., n - 1] * input_size)
// output_size`. For each index `k`, if `k * (input_size % output_size)
< output_size`, then the current and previous stride keeps the same as
`input_size // output_size`. So suppose `(n-1) * (input_size %
output_size) < output_size`, the stride in the whole AdaptiveAvgPool2d
process keeps static, as `input_size // output_size`.

2. Kernel Size
torch/_decomp/decomposations.py:adaptive_avg_pool2d calculates a static
kernel size when the input/output sizes satisfy either of the two
conditions, `input_size % output_size == 0` or `output_size %
(input_size % output_size) == 0`. Here if `input_size % output_size ==
0`, then the kernel size equals `input_size // output_size`, otherwise
`input_size // output_size + 1.`
2024-08-01 11:37:53 +08:00
Rob Suderman 7f475e174e
Add extf-trunc f32-f64-f32 ellision (#3579)
Torch has all scalars represented as i64 and f64 types which results in
extraneous trunc-extf commands. We can rework this by elliding
widen-narrow cases away.
2024-07-31 16:50:00 -07:00
Yuanqiang Liu 003b06dfa1
[Torch] enhance naryFolderHelper to support mixed dtypes (#3559)
* so that it could support like `i64 + f64 => f64`.
* also unify `aten.log`'s folder code to use `naryFolderHelper`.
2024-07-24 17:54:59 +08:00
Yuanqiang Liu aad1604046
[Torch] enhance fold of aten.squeeze.dim (#3558) 2024-07-24 14:13:48 +08:00
Ze Zhang d1e172f418
Register fake_quantize_cachemask ops and add their decompose patterns (#3556)
Test:

`cmake --build build --target check-torch-mlir-all`
2024-07-23 11:33:12 -07:00
Yuanqiang Liu 21ad890009
[Torch] enhance fold of aten.slice.Tensor (#3557)
so that it could support folding slice with any static shape.
2024-07-23 22:53:03 +08:00
zjgarvey 0fb8b017d8
Adds misc fixes for some padding related issues (#3528)
This patch adds a few misc pad op related changes:

1. Addresses issue <https://github.com/llvm/torch-mlir/issues/3457>
2. Addresses issue <https://github.com/llvm/torch-mlir/issues/3442>
3. Fixes the padding order for asymmetrically padded onnx.Conv ops
4. Enables passing quantization through those onnx.Conv op pre-paddings
5. Modifies the torch-to-linalg lowering of AtenReplicationPad2d op to
enable support for input rank != 4

Unfortunately, even with all of these changes, the e2e tests for the
ReplicationPad2d still fail the onnx config, since the torch export
procedure for rearranging the pad order is complicated enough that the
padding ints end up not being able to fold back to constants.
2024-07-11 20:01:45 -05:00
Ze Zhang d466d5b809
Register fake_quantize related ops (#3522)
Register `aten.fake_quantize_per_channel_affine` and
`aten.fake_quantize_per_tensor_affine.tensor_qparams` ops

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-07-05 11:02:03 -07:00
Christopher McGirr 7e6d76e997
[Torch] Fix torch.constant.int operation parsing (#3476)
Due to the custom operation parser, the print and parser were expecting
two different forms.

One having the dictionary before the value and the other after.
Following the format of the other constants ops, the constant.int will
follow the `value attr-dict` format. Updated the parser accordingly.
2024-06-28 16:06:52 +02:00
Sambhav Jain 09f502667b
`AtenTensorOp::fold` should not fold when result type is not fully specified (#3494)
In one of our downstreams, we encountered an internal assertion failure
in an intermediate pass from `AtenTensorOp::fold` invocation:
```
external/llvm-project/llvm/include/llvm/Support/Casting.h:650: decltype(auto) llvm::dyn_cast(const From &) [To = mlir::torch::Torch::NonValueTensorType, From = mlir::Type]: Assertion `detail::isPresent(Val) && "dyn_cast on a non-existent value"' failed.
```

for this snippet in the IR:
```
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,1,15360],f32>}
...
    %218 = torch.aten.size %arg1 : !torch.tensor -> !torch.list<int>
    %219 = torch.aten.tensor %218, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
```

Turns out this was
[fixed](https://github.com/llvm/torch-mlir/pull/3189/files#diff-dc8ed165c207918e606490eee3984b1ad51d7034e6aac36fc046bf47f6f03f4fR3719)
eventually (and we were on an old hash of torch-mlir). This PR submits
just the lit test for test coverage on that specific change:
```c++
OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  // lit test this
  if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
    return nullptr;
  ...
```
2024-06-24 15:22:50 -07:00
Peiming Liu ba16bad8c7
[torch-mlir] bump stablehlo/llvm version (#3471)
Update to llvm/llvm-project@5207632f86
Update to openxla/stablehlo@d41390c3a7
2024-06-18 16:59:53 -07:00
zjgarvey de28c8540b
[ONNX] add int16 quantization support (#3446)
There is currently no int16 quantization support in torch. This patch
adds a new mlir type to correspond to the missing "torch.qint16" type,
and enables lowering of quantization-related onnx ops using int16 types.

In follow-up patches, custom quantization logic for ops like
aten.matmul/aten.mm/aten.convolution may need to be revisited to allow
support for qint16. The passes in FuseQuantizedOps.cpp may also need
slight modifications.
2024-06-12 10:37:22 +05:30
Sambhav Jain d0a818a03e
Representing Symbolic Shape Expressions in Torch Dialect (#3372)
Torch Dialect with symbolic shape expressions:
```ll
module {                                                                                                                                                                                                     
  func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {                                                                                   
    %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int                                                                                                                                    
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int                                                                                                                                   
    %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int                                                                                                                                    
    
    torch.bind_symbolic_shape %arg0, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    torch.bind_symbolic_shape %arg1, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    
    %3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                                  
    torch.bind_symbolic_shape %3, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                               
    torch.bind_symbolic_shape %4, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>                                               
    %int1 = torch.constant.int 1                                                                                                                                                                             
    %6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>                                                                                                          
    torch.bind_symbolic_shape %6, [%0, %1, %2], #affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>                                                                            
    
    return %6 : !torch.vtensor<[?,?,3],f32>                                                                                                                                                                  
  }                                                                                                                                                                                                          
}              
```

For reference, this is the TorchDynamo exported program with symbolic
shape expressions that the above Torch dialect program is imported from:
```py
ExportedProgram:                                                                                                                                                                                             
    class GraphModule(torch.nn.Module):                                                                                                                                                                      
        def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"):                                                                                                                                         
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x)                                        
            tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x);  x = None                                                                                                                               
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y)                                     
            sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y);  y = None                                                                                                                         
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1)                       
            cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1);  tanh = sigmoid = None                                                                                      
            return (cat,)                                                                                                                                                                                    
                                                                                                                                                                                                             
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])                                               
Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)} 
```

Huge credit to @stellaraccident for the inputs that helped evaluate the
various design options and arrive at the representation of choice.


- [x] Op definitions for symbolic_int and bind_symbolic_shape ops
- [x] fx_importer updates to import range constraints + create
symbolic_int ops
- [x] fx_importer changes for AffineMapAttr building + adding
bind_symbolic_shape ops
- [x] custom printer/parser for inlined AffineMap expressions in mlir
assembly
- [x] Dialect lit test
- [x] fx_importer python lit tests
- [ ] Cleanup pass to remove these ops (can add in a follow-on)
2024-06-07 04:04:03 -07:00
Angel Zhang 2e194e13d6
[Torch] Fix bugs for `Torch::AtenOneHotOp` (#3350)
This PR fixes the bugs for `Torch::AtenOneHotOp` by:

1) Using `Torch::kUnknownSize` as the default value for `numClasses` in
   the pattern matching stage in `DecomposeAtenOneHotOp`
2) Adding `AtenIntScalarOp` to the patterns in `TorchToArith`
3) Handling both `int` and `float` types for `off` and `on` values in
`TorchOnnxToTorch` conversion

It also includes:

1) A new test in `TorchToArith/basic.mlir`, for `torch.aten.Int.Scalar`,
and
2) A new test in `decompose-complex-ops.mlir`, for `torch.aten.one_hot`

**Dependencies**

This PR is dependent on #3334.
2024-05-22 17:19:08 +00:00
Aaron St George ba32b9cee7
Don't fold `aten.clone` if result isn't same type as input (#3347)
Similar to https://github.com/llvm/torch-mlir/pull/2824, we were seeing
some assertion failures after the addition checks around folders were
tightened up in LLVM: https://github.com/llvm/llvm-project/pull/75887 .
This PR essentially moves the logic that used to be applied at the LLVM
level into the folder, which seems to be the suggested fix.
2024-05-16 00:07:45 +08:00
zjgarvey 75d1d72059
Generalize Operand Quantization in FuseQuantizeOps (#3327)
This change enables more customization with operand quantization, and
generalizes the patterns QuantizeOperands and QuantizeTransposeOperands
to QuantizeOperandsPastCommutingOps.

This allows for passing quantization through operations which are
functionally unaffected by quantization, such as view-like ops. The
purpose of this change is to address a myriad of quantization issues
seen in quantized onnx models that have some reshape-like operations
sandwiched in between a dequant and something like a matmul (whose other
operand is immediately quantizable).
2024-05-12 20:49:59 -07:00
Benoit Jacob bce800a3f4
Integrate llvm-project at dabdec1001dc368373dd581cf72f37a440873ce3 (#3300)
Co-authored-by: Jacques Pienaar <jpienaar@google.com>
2024-05-08 14:43:06 -04:00
Vivek Khandelwal e60160d793
Revert "Decompose AtenNonzeroOp" (#3289)
Reverts llvm/torch-mlir#3281
2024-05-06 09:52:04 -07:00
Xida Ren (Cedar) 1af00e6040
Decompose AtenNonzeroOp (#3281)
This fixes some onnx lit tests not lowering to linalg in
https://github.com/nod-ai/SHARK-Turbine/issues/450
2024-05-05 21:59:25 +08:00
Ze Zhang 11cd7cd9e7
Folder and Canonicalizer for PrimsConvertElementTypeOp and AtenMaxPool2dWithIndicesOp (#3272)
While playing with TorchDynamo on ResNet18. I notice following issues:

- `prims.convert_element_type` can’t be canonicalized even if the input
and the output share the same type

- `aten.max_pool2d_with_indices` is always used instead of
`aten.max_pool2d`, even if the second returned output (indices) has no
user

This PR fixes above issues by adding a folder to the
PrimsConvertElementTypeOp and a canonicalizer to the
AtenMaxPool2dWithIndicesOp


Lit test:

`cmake --build build --target check-torch-mlir-all`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-05-02 00:03:41 -07:00
Yuanqiang Liu aed2cf3351
[Torch] emit aten.__contains__.str_list and add folder (#3249) 2024-04-29 10:51:17 +08:00
Stella Laurenzo 5d4b803914 [NFC reformat] Run pre-commit on all files and format misc.
This is part 1 of ~3, formatting all miscellaneous text files and CPP files matched by a first run of pre-commit. These tend to be low change-traffic and are likely not disruptive.

Subsequent patches will format Python files and remaining CPP files.
2024-04-27 14:08:09 -07:00
Yuanqiang Liu f173a06fa7
[Torch] emit aten.ne.str and add folder (#3242) 2024-04-28 00:58:50 +08:00
Yuanqiang Liu 634a796933
[Torch] fold aten.log (#3223) 2024-04-26 10:10:02 +08:00
Yuanqiang Liu b0ba3def93
[Torch] support AtenScalarImplicitOp canonicalize with float (#3231) 2024-04-26 02:36:13 +08:00
Yuanqiang Liu fab2696489
[Torch] support aten.trunc (#3219)
decompose `trunc(x)` to `sign(x) * floor(abs(x))`
2024-04-24 14:32:33 +08:00
Xinyu Yang 790a697245
[Torch] Add folder for AtenIntOp, AtenFloatOp (#3189)
See unit test below:
```
// CHECK-LABEL:   func.func @torch.aten.tensor.float(
// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %float1.000000e01 = torch.constant.float 1.000000e+01
  %67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32>
  return %67 : !torch.vtensor<[],f32>
}

// CHECK-LABEL:   func.func @torch.aten.tensor.int(
// CHECK-NEXT: torch.vtensor.literal(dense<45> : tensor<si32>) : !torch.vtensor<[],si32>
func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> {
  %none = torch.constant.none
  %false = torch.constant.bool false 
  %int45 = torch.constant.int 45
  %67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>
  return %67 : !torch.vtensor<[],si32>
}

```
2024-04-19 22:17:06 +08:00
Xinyu Yang 6524838bcb
[Torch] Add general AdaptiveAvgPool2dOp decompose support (#3111)
Previously, it could only handle the situations where outputsize == (1,
1) or outputsize == (input_H, input_W). Now it supports all situations
where input_H % output_H== 0 && input_W % output_W == 0
2024-04-11 17:02:59 +08:00
Thomas Dietert 3c33dbd987
[MLIR][Torch] Canonicalize torch.from_i1 and torch.to_i1 (#3067)
When lowering `torch.aten.convolution`, it is expected that the
'transposed' argument is a torch.constant operation. In some cases, the
argument was a `from_i1` operation converting an `arith.constant`
operation into a torch.bool. This is not wrong semantically, but instead
of generalizing the legality of the `torch.aten.convolution` op, we
canonicalize `arith.constant` ops followed by `from_i1` ops to
`torch.bool` ops.

For example:
```
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.convolution'(0x124705b90) {
  %33 = "torch.aten.convolution"(%arg0, %20, %21, %31, %29, %30, %19, %32, %0) : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.vtensor<[1,10,24,24],f32>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : unimplemented: only constant transposed supported.      <-- Resolved by this PR
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : not a supported Scalar to Tensor like op
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : not a supported elementwise op
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : not a supported reduce op
  } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
<stdin>:21:11: error: failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal
    %17 = torch.operator "onnx.Conv"(%arg0, %0, %1) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [5 : si64, 5 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>) -> !torch.vtensor<[1,10,24,24],f32> 
          ^
<stdin>:21:11: note: see current operation: %33 = "torch.aten.convolution"(%arg0, %20, %21, %31, %29, %30, %19, %32, %0) : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.vtensor<[1,10,24,24],f32>
```

Additionally, we require the canonicalization of `to_i1` operating on a
torch.constant bool to an `arith.constant ... : i1` for the e2e tests to
pass successfully.
2024-04-01 14:25:51 -07:00
Yuanqiang Liu 0a581a97a7
[Torch Dialect] enhance aten.int.tensor's canonicalize (#3058)
support fold with literal vtensor.  
change it to canonicalize because this pattern will create new op.
2024-03-27 09:51:58 +08:00
Rob Suderman 14b548f968
[torch] Improve shape inference for `torch-to-linalg` path for reshapes (#3055)
Reshaping tensors depend on directly matching individual dimensions to
their corresponding dim in the `torch.view` reshape dimensions. This
involves decoupling dynamic dimensions from their static counterparts
and support cleanup / canonicalization.
2024-03-26 12:41:40 -07:00
Rob Suderman 0723584936
[torch] Add folder for torch.aten.*.Scalar comparisons (#3000)
This folds small version of the tensor-scalar comparison operators as
they are commonly used for shape computations. This includes le, lt, ge,
gt, eq, and ne.
2024-03-08 13:44:00 -08:00
Rob Suderman a86e89ecb5
[torch] Additional folders for shape computations (#2972)
A handful of operations are commonly used in shape calculations (slice,
concat, broadcast). Added these additional folders to better propagate
simple shape computations.
2024-03-04 11:46:49 -08:00
Rob Suderman 61f0a5facf
[torch] Add an `aten.cat` length-0 canonicalization (#2966)
If an input is length-0 along the dimension of canonicalization we can
remove the tensor from the list
2024-03-01 21:41:12 -08:00
Rob Suderman 6f3d62ab04
[torch] Fix folders and `cat` and `view` torch lowerings (#2963)
A bunch of small fixes are interlinked and trigger crashes if not
addressed as a group. This includes:

- aten view when expand from a rank-0 tensor
- slice folder with negative indices
- `aten._shape_as_tensor` folder on a rank-0 tensor
- `aten.cat` of a tensor with a length-0 tensor
2024-02-28 12:04:52 -08:00
Rob Suderman e30a083aff
[torch] Rework lowering to tm_tensor.scatter to stop serialization (#2940)
We collapsed and broadcasted scatter indices to a single element
version. We should instead upport `tm_tensor.scatter`s support for
multiple indices and the implicitly broadcasted behavior. This avoids
the serialization and materializing a needlessly large indices tensor.
2024-02-27 11:46:57 -08:00
Vivek Khandelwal d81747eadb
[MLIR][TORCH] Extend support for OnnxToLinalg lowering for Dropout and Div op (#2938)
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/451,
https://github.com/nod-ai/SHARK-Turbine/issues/452
2024-02-27 11:02:05 +05:30
Stella Laurenzo 4446fa00d8
Migrate passes in TorchConversion to use FunctionOpInterface. (#2935)
This enables better re-use in downstreams which use different func
implementations and should have no impact on those that don't except in
opt pipelines if using the old form. With interfaces, explicit pipelines
via `--pass-pipeline=` must be used.
2024-02-20 08:54:02 -08:00
Rob Suderman 135c81a416
[torch] Add folder for `prim.NumToTensor.Scalar` (#2921)
Useful for `slice` lowerings that depend on tensors made form scalars.
2024-02-19 11:55:54 -08:00
Rob Suderman e80054a3cc
[torch] Folders for `torch.aten.*.tensor` operators [add, sub, mul] (#2878)
Simple folder for limited size aten tensor operations. This is primarily
useful for shape computation folding as they unfortunately can use
`aten` operators. Add, sub, mul are common examples of these folders.
2024-02-19 10:28:23 -08:00
Rob Suderman e9cdd6cbc5
[torch] Fix tm_tensor.attention for end-to-end (#2907)
Some operations include a backend matcher for specialized operations. We
map these back to generics so they appropriately match to the high
performance versions. This is done for the attention operation.
2024-02-13 21:18:01 -08:00
Scott Todd d6e1d836ca
Drop torch attributes at the end of backend conversion. (#2876)
Fixes https://github.com/llvm/torch-mlir/issues/2866

Some backends / downstream projects expect that a "fully converted"
program has no remaining ops or attributes from the original dialect(s).
2024-02-13 14:32:02 -08:00
Rob Suderman c0f139be0f
[torch] Add `torch.aten.eq.Tensor` comparison folder (#2889)
Added a folded for a equals operator. This allows an equivalent
comparison folder, primarily for when shape computations occur small
size tensor.
2024-02-09 15:02:20 -08:00