Commit Graph

120 Commits (f4840ed886f39db5bcb3bf20d37e79f8c4657746)

Author SHA1 Message Date
Ze Zhang b3942ff984
Add canonicalize pattern for aten.mul.int and aten.floordiv.int (#3680)
This PR add `floordiv` to the `PY_BUILTIN_TO_TORCH_OP`. For
`aten.mul.int` and `aten.floordiv.int` ops, we add new Canonicalization
Patterns as follow:

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.mul.int %1, %const-6
```

Will be replaced by

`torch.aten.mul.int %input, %const-30`


And 

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.floordiv.int %1, %const-5
```
Will directly return `%input`


This PR also relaxes the `float` type constraint in TorchToTosa for the
`AtenRsubScalarOp` conversion.



To test:

`cmake --build build --target check-torch-mlir-all`
2024-09-03 09:13:59 -07:00
Rob Suderman fd98476f77
[torch] Unpacking sometimes misses shape inference (#3609)
It is possible that the unpacked tensor does not match the same inferred
shapes. This is pretty common when ingesting form the `onnx` frontend.
2024-08-08 16:17:31 -07:00
Yuanqiang Liu 003b06dfa1
[Torch] enhance naryFolderHelper to support mixed dtypes (#3559)
* so that it could support like `i64 + f64 => f64`.
* also unify `aten.log`'s folder code to use `naryFolderHelper`.
2024-07-24 17:54:59 +08:00
Yuanqiang Liu aad1604046
[Torch] enhance fold of aten.squeeze.dim (#3558) 2024-07-24 14:13:48 +08:00
Yuanqiang Liu 21ad890009
[Torch] enhance fold of aten.slice.Tensor (#3557)
so that it could support folding slice with any static shape.
2024-07-23 22:53:03 +08:00
Sambhav Jain 09f502667b
`AtenTensorOp::fold` should not fold when result type is not fully specified (#3494)
In one of our downstreams, we encountered an internal assertion failure
in an intermediate pass from `AtenTensorOp::fold` invocation:
```
external/llvm-project/llvm/include/llvm/Support/Casting.h:650: decltype(auto) llvm::dyn_cast(const From &) [To = mlir::torch::Torch::NonValueTensorType, From = mlir::Type]: Assertion `detail::isPresent(Val) && "dyn_cast on a non-existent value"' failed.
```

for this snippet in the IR:
```
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,1,15360],f32>}
...
    %218 = torch.aten.size %arg1 : !torch.tensor -> !torch.list<int>
    %219 = torch.aten.tensor %218, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
```

Turns out this was
[fixed](https://github.com/llvm/torch-mlir/pull/3189/files#diff-dc8ed165c207918e606490eee3984b1ad51d7034e6aac36fc046bf47f6f03f4fR3719)
eventually (and we were on an old hash of torch-mlir). This PR submits
just the lit test for test coverage on that specific change:
```c++
OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  // lit test this
  if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
    return nullptr;
  ...
```
2024-06-24 15:22:50 -07:00
Sambhav Jain d0a818a03e
Representing Symbolic Shape Expressions in Torch Dialect (#3372)
Torch Dialect with symbolic shape expressions:
```ll
module {                                                                                                                                                                                                     
  func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {                                                                                   
    %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int                                                                                                                                    
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int                                                                                                                                   
    %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int                                                                                                                                    
    
    torch.bind_symbolic_shape %arg0, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    torch.bind_symbolic_shape %arg1, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    
    %3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                                  
    torch.bind_symbolic_shape %3, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                               
    torch.bind_symbolic_shape %4, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>                                               
    %int1 = torch.constant.int 1                                                                                                                                                                             
    %6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>                                                                                                          
    torch.bind_symbolic_shape %6, [%0, %1, %2], #affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>                                                                            
    
    return %6 : !torch.vtensor<[?,?,3],f32>                                                                                                                                                                  
  }                                                                                                                                                                                                          
}              
```

For reference, this is the TorchDynamo exported program with symbolic
shape expressions that the above Torch dialect program is imported from:
```py
ExportedProgram:                                                                                                                                                                                             
    class GraphModule(torch.nn.Module):                                                                                                                                                                      
        def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"):                                                                                                                                         
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x)                                        
            tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x);  x = None                                                                                                                               
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y)                                     
            sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y);  y = None                                                                                                                         
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1)                       
            cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1);  tanh = sigmoid = None                                                                                      
            return (cat,)                                                                                                                                                                                    
                                                                                                                                                                                                             
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])                                               
Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)} 
```

Huge credit to @stellaraccident for the inputs that helped evaluate the
various design options and arrive at the representation of choice.


- [x] Op definitions for symbolic_int and bind_symbolic_shape ops
- [x] fx_importer updates to import range constraints + create
symbolic_int ops
- [x] fx_importer changes for AffineMapAttr building + adding
bind_symbolic_shape ops
- [x] custom printer/parser for inlined AffineMap expressions in mlir
assembly
- [x] Dialect lit test
- [x] fx_importer python lit tests
- [ ] Cleanup pass to remove these ops (can add in a follow-on)
2024-06-07 04:04:03 -07:00
Aaron St George ba32b9cee7
Don't fold `aten.clone` if result isn't same type as input (#3347)
Similar to https://github.com/llvm/torch-mlir/pull/2824, we were seeing
some assertion failures after the addition checks around folders were
tightened up in LLVM: https://github.com/llvm/llvm-project/pull/75887 .
This PR essentially moves the logic that used to be applied at the LLVM
level into the folder, which seems to be the suggested fix.
2024-05-16 00:07:45 +08:00
Ze Zhang 11cd7cd9e7
Folder and Canonicalizer for PrimsConvertElementTypeOp and AtenMaxPool2dWithIndicesOp (#3272)
While playing with TorchDynamo on ResNet18. I notice following issues:

- `prims.convert_element_type` can’t be canonicalized even if the input
and the output share the same type

- `aten.max_pool2d_with_indices` is always used instead of
`aten.max_pool2d`, even if the second returned output (indices) has no
user

This PR fixes above issues by adding a folder to the
PrimsConvertElementTypeOp and a canonicalizer to the
AtenMaxPool2dWithIndicesOp


Lit test:

`cmake --build build --target check-torch-mlir-all`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-05-02 00:03:41 -07:00
Yuanqiang Liu aed2cf3351
[Torch] emit aten.__contains__.str_list and add folder (#3249) 2024-04-29 10:51:17 +08:00
Stella Laurenzo 5d4b803914 [NFC reformat] Run pre-commit on all files and format misc.
This is part 1 of ~3, formatting all miscellaneous text files and CPP files matched by a first run of pre-commit. These tend to be low change-traffic and are likely not disruptive.

Subsequent patches will format Python files and remaining CPP files.
2024-04-27 14:08:09 -07:00
Yuanqiang Liu f173a06fa7
[Torch] emit aten.ne.str and add folder (#3242) 2024-04-28 00:58:50 +08:00
Yuanqiang Liu 634a796933
[Torch] fold aten.log (#3223) 2024-04-26 10:10:02 +08:00
Yuanqiang Liu b0ba3def93
[Torch] support AtenScalarImplicitOp canonicalize with float (#3231) 2024-04-26 02:36:13 +08:00
Yuanqiang Liu fab2696489
[Torch] support aten.trunc (#3219)
decompose `trunc(x)` to `sign(x) * floor(abs(x))`
2024-04-24 14:32:33 +08:00
Xinyu Yang 790a697245
[Torch] Add folder for AtenIntOp, AtenFloatOp (#3189)
See unit test below:
```
// CHECK-LABEL:   func.func @torch.aten.tensor.float(
// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %float1.000000e01 = torch.constant.float 1.000000e+01
  %67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32>
  return %67 : !torch.vtensor<[],f32>
}

// CHECK-LABEL:   func.func @torch.aten.tensor.int(
// CHECK-NEXT: torch.vtensor.literal(dense<45> : tensor<si32>) : !torch.vtensor<[],si32>
func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> {
  %none = torch.constant.none
  %false = torch.constant.bool false 
  %int45 = torch.constant.int 45
  %67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>
  return %67 : !torch.vtensor<[],si32>
}

```
2024-04-19 22:17:06 +08:00
Yuanqiang Liu 0a581a97a7
[Torch Dialect] enhance aten.int.tensor's canonicalize (#3058)
support fold with literal vtensor.  
change it to canonicalize because this pattern will create new op.
2024-03-27 09:51:58 +08:00
Rob Suderman 0723584936
[torch] Add folder for torch.aten.*.Scalar comparisons (#3000)
This folds small version of the tensor-scalar comparison operators as
they are commonly used for shape computations. This includes le, lt, ge,
gt, eq, and ne.
2024-03-08 13:44:00 -08:00
Rob Suderman a86e89ecb5
[torch] Additional folders for shape computations (#2972)
A handful of operations are commonly used in shape calculations (slice,
concat, broadcast). Added these additional folders to better propagate
simple shape computations.
2024-03-04 11:46:49 -08:00
Rob Suderman 61f0a5facf
[torch] Add an `aten.cat` length-0 canonicalization (#2966)
If an input is length-0 along the dimension of canonicalization we can
remove the tensor from the list
2024-03-01 21:41:12 -08:00
Rob Suderman 6f3d62ab04
[torch] Fix folders and `cat` and `view` torch lowerings (#2963)
A bunch of small fixes are interlinked and trigger crashes if not
addressed as a group. This includes:

- aten view when expand from a rank-0 tensor
- slice folder with negative indices
- `aten._shape_as_tensor` folder on a rank-0 tensor
- `aten.cat` of a tensor with a length-0 tensor
2024-02-28 12:04:52 -08:00
Vivek Khandelwal d81747eadb
[MLIR][TORCH] Extend support for OnnxToLinalg lowering for Dropout and Div op (#2938)
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/451,
https://github.com/nod-ai/SHARK-Turbine/issues/452
2024-02-27 11:02:05 +05:30
Rob Suderman 135c81a416
[torch] Add folder for `prim.NumToTensor.Scalar` (#2921)
Useful for `slice` lowerings that depend on tensors made form scalars.
2024-02-19 11:55:54 -08:00
Rob Suderman e80054a3cc
[torch] Folders for `torch.aten.*.tensor` operators [add, sub, mul] (#2878)
Simple folder for limited size aten tensor operations. This is primarily
useful for shape computation folding as they unfortunately can use
`aten` operators. Add, sub, mul are common examples of these folders.
2024-02-19 10:28:23 -08:00
Rob Suderman c0f139be0f
[torch] Add `torch.aten.eq.Tensor` comparison folder (#2889)
Added a folded for a equals operator. This allows an equivalent
comparison folder, primarily for when shape computations occur small
size tensor.
2024-02-09 15:02:20 -08:00
Rob Suderman 7d33ba69ac
[torch] Folder for torch.aten.select.int for splat cases (#2890)
If the input or result is a splat value we can just constant fold the
result. This is common for shape computations and can help with shape
inference.
2024-02-09 14:02:54 -08:00
Dave Liddell 23647ab2d1
[torhc] aten.index_select folder (#2871)
Folds aten::index_select ops under the following conditions:

1. If the input and output are the same shape, the indexing operation is
a NOP, so just return the input.
2. If the input has shape <1x1x...xNx...x1> (all 1's except for one
dim), and the output shape is <1x1x...x1> (all 1's), then there is a
single index, so extract the single element value and return a tensor
with that value.

---------

Co-authored-by: Dave Liddell <dliddell@xilinx.com>
2024-02-07 16:17:15 -08:00
Xida Ren (Cedar) fc04bc7ee9
[torch] AtenSliceOp folder that produces splat results (#2869)
Includes `slice` folder and lit tests

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-02-07 19:00:46 +00:00
Xida Ren (Cedar) cc06391630
AtenSortOp Folder (#2864)
A chunk off

https://github.com/llvm/torch-mlir/pull/2856
https://github.com/llvm/torch-mlir/pull/2860

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
2024-02-06 21:12:12 +00:00
Dave Liddell 1cb14f6879
Rob's atenTensor folder (#2867)
If a tensor is initialized by a list with a single constant integer,
this folder turns it into a torch.vtensor.literal

---------

Co-authored-by: Dave Liddell <dliddell@xilinx.com>
2024-02-05 17:10:42 -08:00
Xida Ren (Cedar) 24b8c8672a
[torch] Add folders for `torch.fill`, `torch.ones`, `torch.zeros` and `aten.getItem` (#2849)
So that the CumSum Op in OPT can get the constant that it requires to be lowered to TMTensor

---------

Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-02-02 10:46:33 -08:00
Aaron St George 4c557847bd
Don't fold `aten.detach` if result isn't same type as input. (#2824)
We were seeing some assertion failures after some checks around folders
were tightened up in LLVM:
https://github.com/llvm/llvm-project/pull/75887 . This PR essentially
moves the logic that used to be applied at the LLVM level into the
folder, which seems to be the suggested fix.

I'm not sure if the IR that caused issues for us _should_ be valid?
```
%1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor
```
A better fix might be to create a verifier ensuring the result of
`aten.detach` has the same type as its operand.

---------

Co-authored-by: aaron-stgeorge <aaron.stgeorge@getcruise.com>
2024-01-30 09:45:51 -08:00
Aart Bik fe836ceebf
[torch-mlir][test] cleanup trailing whitespace in mlir files (#2806) 2024-01-25 14:24:13 -08:00
Han-Chung Wang 10acea71be
Bump LLVM to llvm/llvm-project@0cb024b (#2753)
- Add fixes for
af78e5daf0
- Add fixes for
bb6d5c2200
2024-01-15 07:12:12 -08:00
Zhekun(Josh) Zhang d67afa9e95
[Torch] Add fold rule for AtenMaskedFillTensorOp to AtenMaskedFillScalarOp (#2543) 2023-11-21 13:26:17 +08:00
Yuanqiang Liu 3ab790c50a
[Torch Dialect] add canonicalize for aten.numel (#2562) 2023-11-11 12:16:53 +08:00
Quinn Dawkins ae72eec224
Improve aten.broadcast_to folder when in strict symbol mode (#2504)
Strict symbolic shapes allow us to assume numpy-style dynamic broadcasts
never occur. This allows us to strengthen the folder for broadcasts to
cases where the rank is the same and all shapes match (including dynamic
sentinel values).
2023-10-05 09:02:10 -04:00
Quinn Dawkins 1fc4314b62
Add folder for aten.broadcast_to on unchanged static shapes (#2421) 2023-09-01 14:50:34 -04:00
JianzheXiao 17d02811d5
[Torch Dialect] add folder for aten.any.bool (#2388)
* update

* update

* update

* update

* update

* update

* update
2023-08-30 17:29:03 +08:00
Jiawei Wu 4c9d234b01
revert canonicalizer for PrimListConstructOp (#2408) 2023-08-22 09:18:39 +08:00
Jiawei Wu 4c12aceb81
[Torch-Dialect] add canonicalizer for prim::ListConstruct op (#2306)
[Torch-Dialect] add canonicalizer for prim::ListConstruct op
2023-08-08 10:28:11 +08:00
Alexandre Rames 1e468e8294 Fix canonicalization of `torch.prim.TupleUnpack`. 2023-07-20 20:08:46 +02:00
Matthias Gehre 64d7626a52
Fixes for split tensor and slice (#2314)
* RecomposeComplexOps: Remove dead slice op

* lib/Dialect/Torch/IR/TorchOps.cpp: Fold slice ops even when they are on non-value tensors

* lib/Conversion/TorchToTosa/TorchToTosa.cpp: Fix slice start/end out of range/none

* lib/Dialect/Torch/IR/TorchOps.cpp: AtenSliceTensorOp::fold: Fold slices that go from 0:int_max

* More tests for aten.split.Tensor
2023-07-20 09:53:54 +02:00
Jiawei Wu 3f843c8fd9
[torch-dialect] fix aten.type_as op's folder (#2283)
[torch-dialect] fix torch.type_as op's folder by decomposing it to prim.dtype + aten.to_dtype
2023-07-20 09:51:58 +08:00
Jiawei Wu c7fa42b7d3
[Torch Dialect] Add canonicalizer for aten.to.other op (#2273)
Canonicalize aten.to.other to prim.device + prim.dtype + aten.to.device
Co-authored-by: wujiawei.aml <wujiawei.aml@bytedance.com>
2023-06-30 09:43:08 +08:00
Yuanqiang Liu 449cfb8375
[Torch Dialect] add more scalar op folders (#2265) 2023-06-29 10:37:13 +08:00
Yuanqiang Liu 1ea2b57ab7
[Torch Dialect] add folder for aten.add (#2264)
* [Torch Dialect] add folder for aten.add

* update

* update

* update
2023-06-27 10:55:28 +08:00
Yuanqiang Liu 96b14e952e
[Torch Dialect] Support aten.device.with_index (#2254) 2023-06-23 01:07:14 +08:00
Yuanqiang Liu 7c6961bcbf
[Torch Dialect] Support aten.cuda and add canonicalizer for aten.cuda (#2231) 2023-06-14 09:56:39 +08:00
Matthias Gehre 27a3d09917
Torch: Fold RuntimeAssertOp when condition is true (#2198) 2023-06-09 19:06:25 +08:00