Commit Graph

1033 Commits (main)

Author SHA1 Message Date
Christopher McGirr 7e6d76e997
[Torch] Fix torch.constant.int operation parsing (#3476)
Due to the custom operation parser, the print and parser were expecting
two different forms.

One having the dictionary before the value and the other after.
Following the format of the other constants ops, the constant.int will
follow the `value attr-dict` format. Updated the parser accordingly.
2024-06-28 16:06:52 +02:00
Aart Bik 1f73895f93
[torch-mlir] bump to llvm/llvm-project@9b78ddf3b2 (#3491)
This bump triggered an upstream assert. Includes a WAR for #3506.

Also includes several things I needed to do to repro:

* When TORCH_MLIR_TEST_CONCURRENCY=1, test runs will be printed.
* Added TORCH_MLIR_TEST_VERBOSE=1 handling to enable verbose mode
(useful on CI).

---------

Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
2024-06-27 19:28:02 -07:00
Matthias Gehre 6678e1a256
TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475)
Before this PR, a statically shaped aten.convolution would generate
dynamically shaped linalg IR, and even `-canonicalize` would not be able
to fold it back into static shapes. This PR ensure that shape
calculations are folded on construction to directly generate statically
shaped linalg IR.

We achieve that by ensuring that `arith` ops involved in computing
shapes are created via `createOrFold`, so that later uses of
`getAsOpFoldResult` see constants instead of those ops.

For example
```
module {
  func.func @forward(%arg0: !torch.vtensor<[32,336,112,112],f32>,
                        %arg1: !torch.vtensor<[336,168,3,3],f32>, 
                        %arg2: !torch.vtensor<[336],f32>) 
                        -> !torch.vtensor<[32,336,56,56],f32> {
    %false = torch.constant.bool false
    %int2 = torch.constant.int 2
    %int1 = torch.constant.int 1
    %0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.prim.ListConstruct  : () -> !torch.list<int>
    %3 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %0, %0, %false, %2, %int2 
    : !torch.vtensor<[32,336,112,112],f32>, !torch.vtensor<[336,168,3,3],f32>, !torch.vtensor<[336],f32>, !torch.list<int>,
      !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int
   -> !torch.vtensor<[32,336,56,56],f32>
    return %3 : !torch.vtensor<[32,336,56,56],f32>
  }
}
```
would result in
```
[...]
  %padded = tensor.pad %2 low[%14, %15, %16, %17] high[%14, %15, %16, %17] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
    } : tensor<32x336x112x112xf32> to tensor<?x?x?x?xf32>
[...]
  %45 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
    ins(%expanded, %expanded_37 : tensor<?x2x?x?x?xf32>, tensor<2x168x168x3x3xf32>)
    outs(%expanded_44 : tensor<32x2x168x?x?xf32>) -> tensor<32x2x168x?x?xf32>
[...]
```
and with this PR all shapes are static.
2024-06-27 08:43:10 +02:00
zjgarvey d2bc70f188
[TorchToLinalg][ONNX] Add Basic Determinant Support (#3481)
This adds support for a few ops:

- torch.linalg_det
- torch._linalg_det (if the LU and pivot returns are unused)
- onnx.Det

An scf loop is used, since the row reduction algorithm applied here has
some loop-carried dependencies.
The current support being added here is very basic, and only works if no
permutations are required during row reduction, and assumes the matrices
are non-singular.
2024-06-25 13:34:19 -05:00
zjgarvey 368fabf0c1
[ONNX] Basic Support for DeformConv (#3469)
This adds a torchvision op to torch-mlir and a path from onnx.DeformConv
to torchvision.deform_conv2d.

I'm not implementing the torch->linalg lowering for the torchvision op
yet, but posting this PR to get feedback on some of the choices being
made here and to flesh out the onnx frontend a bit.
2024-06-25 12:16:51 -05:00
zjgarvey e346c911f7
[ONNX] Add basic support for RoiAlign (#3493)
This adds an onnx->torch conversion for onnx.RoiAlign into
torchvision.roi_align or torchvision.roi_pool, and adds those two
torchvision ops to torch-mlir.
2024-06-25 11:02:45 -05:00
Vinayak Dev 02340408b7
[torch] Add OnnxToTorch lowering for Onnx.STFT op (#3492)
Adds OnnxToTorch lowering for `Onnx.STFT` op.
2024-06-25 19:00:45 +05:30
Branko Trifkovic 98c6971a01
Implement lowering of torch.aten.triu_indices (#3451)
Closes
[nod-ai/SHARK-Turbine/issues/709](https://github.com/nod-ai/SHARK-Turbine/issues/709)

---------

Co-authored-by: Branko Trifkovic <branko.trifkovic@syrmia.com>
2024-06-21 16:16:38 -07:00
Matthias Gehre acd57a3520
Support fake_quantize_per_tensor_affine_cachemask (#3477)
Add a new op with shape/dtypes and decompose into
`fake_quantize_per_tensor_affine` when the second result is unused.

The xfail_set change is on ONNX because torch cannot export this op to
ONNX.
2024-06-21 07:15:31 +00:00
zjgarvey 694210f429
[TorchToLinalg] Fix Quantized Convolution Accumulator Type (#3459)
1. truncates zero-points to i32
2. modifies the default accumulator type for i8 from i64 to i32. 
3. now uses the input dtype to infer accumulator dtype.
2024-06-20 13:54:20 -07:00
Xinyu Yang c7d52f63b4
[stablehlo] add aten::_int_mm lowering (#3474)
as title
2024-06-20 16:10:31 +08:00
Branko Trifkovic 676fa8cc09
Implement lowering of torch.aten.renorm (#3388)
Closes
[nod-ai/SHARK-Turbine/issues/689](https://github.com/nod-ai/SHARK-Turbine/issues/689)

---------

Co-authored-by: Branko Trifkovic <branko.trifkovic@syrmia.com>
2024-06-17 10:40:57 -07:00
ptrifunovic98 4555629246
Implement lowering of torch.aten.kthvalue (#3360)
Closes
[nod-ai/SHARK-Turbine#620](https://github.com/nod-ai/SHARK-Turbine/issues/620)
2024-06-15 11:18:39 +05:30
Arham Khan 09c988046c
[ONNX] Add OnnxToTorch lowering for Onnx.NegativeLogLikelihoodLoss Op (#3380)
This implements the Onnx.NegativeLogLikelihoodLoss op using the
signature provided
[here](https://onnx.ai/onnx/operators/onnx__NegativeLogLikelihoodLoss.html)
by replacing it with a `NLLLossForward` op.

Additionally, I included a helper function `get_loss_reduction_enum` to
convert from a string `reduction` parameter to the corresponding
intended integer value since this is an operation that will be reused
for any loss function module. This differs from `get_reduction_enum` in
`TorchUpstream.cpp` which handles the `reduce` parameter from
`scatter_reduce` type operations.
2024-06-14 22:01:11 +05:30
Xinyu Yang 6f94c7b0aa
[Torch] Add support for Meshgrid (#3462) 2024-06-14 23:59:08 +08:00
Vinayak Dev 39d882f7c9
[torch] Add OnnxToTorch lowering for the Col2Im op (#3424)
Adds OnnxToTorch lowering for the `onnx.Col2Im` op.
2024-06-13 08:42:06 +00:00
Lei Zhang 77d7f64472
Update to llvm/llvm-proect@27ac46e6be (2024-6-12) (#3454)
This would require to bump stablehlo at the same time.
2024-06-12 19:34:01 -07:00
zjgarvey de28c8540b
[ONNX] add int16 quantization support (#3446)
There is currently no int16 quantization support in torch. This patch
adds a new mlir type to correspond to the missing "torch.qint16" type,
and enables lowering of quantization-related onnx ops using int16 types.

In follow-up patches, custom quantization logic for ops like
aten.matmul/aten.mm/aten.convolution may need to be revisited to allow
support for qint16. The passes in FuseQuantizedOps.cpp may also need
slight modifications.
2024-06-12 10:37:22 +05:30
Yuanqiang Liu 689efc8917
[Torch] fix toBuiltinTensor() (#3415)
* Let `toBuiltinTensor()` reflects the original dtype of
`!torch.vtensor`.
* Backend handles dtype conversion themselves.
2024-06-08 09:36:32 +08:00
Rob Suderman 75af64fc12
[torch] Add support for f8 types for linalg conversion (#3436)
Linalg conversion requires mapping for f8 types
2024-06-07 13:59:38 -07:00
Sambhav Jain d0a818a03e
Representing Symbolic Shape Expressions in Torch Dialect (#3372)
Torch Dialect with symbolic shape expressions:
```ll
module {                                                                                                                                                                                                     
  func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {                                                                                   
    %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int                                                                                                                                    
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int                                                                                                                                   
    %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int                                                                                                                                    
    
    torch.bind_symbolic_shape %arg0, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    torch.bind_symbolic_shape %arg1, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    
    %3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                                  
    torch.bind_symbolic_shape %3, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                               
    torch.bind_symbolic_shape %4, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>                                               
    %int1 = torch.constant.int 1                                                                                                                                                                             
    %6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>                                                                                                          
    torch.bind_symbolic_shape %6, [%0, %1, %2], #affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>                                                                            
    
    return %6 : !torch.vtensor<[?,?,3],f32>                                                                                                                                                                  
  }                                                                                                                                                                                                          
}              
```

For reference, this is the TorchDynamo exported program with symbolic
shape expressions that the above Torch dialect program is imported from:
```py
ExportedProgram:                                                                                                                                                                                             
    class GraphModule(torch.nn.Module):                                                                                                                                                                      
        def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"):                                                                                                                                         
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x)                                        
            tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x);  x = None                                                                                                                               
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y)                                     
            sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y);  y = None                                                                                                                         
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1)                       
            cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1);  tanh = sigmoid = None                                                                                      
            return (cat,)                                                                                                                                                                                    
                                                                                                                                                                                                             
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])                                               
Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)} 
```

Huge credit to @stellaraccident for the inputs that helped evaluate the
various design options and arrive at the representation of choice.


- [x] Op definitions for symbolic_int and bind_symbolic_shape ops
- [x] fx_importer updates to import range constraints + create
symbolic_int ops
- [x] fx_importer changes for AffineMapAttr building + adding
bind_symbolic_shape ops
- [x] custom printer/parser for inlined AffineMap expressions in mlir
assembly
- [x] Dialect lit test
- [x] fx_importer python lit tests
- [ ] Cleanup pass to remove these ops (can add in a follow-on)
2024-06-07 04:04:03 -07:00
Vivek Khandelwal 72837fbb3d
build: manually update PyTorch version (#3340)
Set PyTorch and TorchVision version to nightly release 2024-05-14.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-06 22:23:40 +05:30
Vivek Khandelwal 661be2d5b0
[MLIR][Torch] Add TorchToLinalg lowering for AtenAvgPool3dOp (#3030)
This commit also fixes the average pool op' test failing for
OnnxToLinalg lowering.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-04 22:12:34 +05:30
Yuanqiang Liu 50f7103098
[Stablehlo] support uint8 (#3367)
Support lowering unsigned integer type to stablehlo as discussed in
https://github.com/llvm/torch-mlir/pull/2184.

The things I do in this PR:
1. create `setupBackendTypeConversionForStablehlo()`,
`createFuncBackendTypeConversionForStablehloPass` and
`createFinalizingBackendTypeConversionForStablehloPass`.
2. remove `InferTypeOpInterface` from `torch_c.to_builtin_tensor`,
because it's different result type between linalg backend and stablehlo
backend:
```
// linalg backend
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
    %c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xi8>
    %0 = tensor.empty() : tensor<3xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<3xi8>) outs(%0 : tensor<3xf32>) {
    ^bb0(%in: i8, %out: f32):
      %2 = arith.uitofp %in : i8 to f32
      linalg.yield %2 : f32
    } -> tensor<3xf32>
    return %1 : tensor<3xf32>
}
// stablehlo backend
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
    %c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xui8>
    %0 = stablehlo.convert %arg0 : (tensor<3xui8> -> tensor<3xf32>
    return %0 : tensor<3xf32>
}
```
3. fix stablehlo and linalg's conversion
2024-06-04 09:04:59 +08:00
zjgarvey 8995c90879
[TorchToLinalg] add support for quantized group conv (#3341)
This addresses 7 of the model failures I'm seeing in the test suite. See
[Shark-Turbine issue
#566](https://github.com/nod-ai/SHARK-Turbine/issues/566).

Need the op ```linalg.conv_2d_ngchw_gfchw_q``` to be added upstream
before merging this. See [llvm-project PR #92136
](https://github.com/llvm/llvm-project/pull/92136).

A small additional expansion to operand quantization is included in this
patch to address a model failure that occurs when unblocking the
quantized group convolutions in one of these onnx models.
2024-06-03 21:57:44 +05:30
Vivek Khandelwal 6382dbbcc0
[ONNX] Add OnnxToTorch lowering for SpaceToDepth op (#3393)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-06-03 20:29:39 +05:30
Xinyu Yang 285b087a5d
[Torch] Emit rrelu and decompose it (#3250)
as title
2024-06-03 19:25:52 +08:00
Xinyu Yang 267052df2a
[Torch] decompose AtenLerpTensorOp (#3251)
as title
2024-06-03 15:25:09 +08:00
Xinyu Yang 23b53050de
[Torch]Support conv_transpose1d and conv_transpose3d (#3286)
1. Support conv_transpose1d and conv_transpose3d
2. Fix bugs of convertTransposedConv func in
lib/Conversion/TorchToStablehlo/Linear.cpp
2024-06-03 15:11:12 +08:00
Rob Suderman afca88a058
[NFC] Change to *cast instead of .*cast variants (#3405)
Member casts have been deprecated. Changing over a bunch of the member
cast calls to the global templated variants to remove deprecation
warnings.
2024-05-30 23:45:13 -07:00
Yuanqiang Liu 4e05e2cd1e
[Torch] support recompose of aten.split.with_sizes and aten.tensor_sp… (#3401)
…lit.sections

* support recompose to aten.split.with_sizes and
aten.tensor_split.sections
* fix recompose of aten.chunk
2024-05-31 09:56:47 +08:00
zjgarvey 074098d20c
Modifies onnx resize lowering to fix numerical issues (#3381)
Updates:

- some unsupported modes are now going to report a match failure for
unsupported coordinate transformation modes.
- fixes a bug that was introduced in the last patch for resize (my
bad...)
- uses actual x and y coordinates for computing weights in bilinear
interpolation (rather than eps modified values)
- slightly simplifies the bilinear interpolation payload for readability
and performance
- passes coordinate transformation mode information from an onnx.Resize
op to the mode string for the aten._interpolate op. This allows us to
perform custom logic in the torch->linalg lowering to support
onnx.Resize options without losing the default behaviors of the
interpolate op.
2024-05-30 20:34:37 -04:00
penguin_wwy 1f544c37d0
[NFC] Remove unused header files (#3386) 2024-05-30 14:30:36 +08:00
Yuanqiang Liu e0a5adb1db
[Torch] fix aten.linear's decomposition (#3391)
* support aten.linear with more rank.
2024-05-27 15:49:50 +08:00
Yuanqiang Liu 5bb1a65ec9
[Stablehlo] refactor reduction lowering and support aten.amin (#3383)
* implement detailed lowering template pattern
`ConvertAtenReduceAllDimsOp` and `ConvertAtenReduceKeepDimOp`
* support `aten.amin`'s lowering.
2024-05-23 20:40:20 +08:00
Angel Zhang 2e194e13d6
[Torch] Fix bugs for `Torch::AtenOneHotOp` (#3350)
This PR fixes the bugs for `Torch::AtenOneHotOp` by:

1) Using `Torch::kUnknownSize` as the default value for `numClasses` in
   the pattern matching stage in `DecomposeAtenOneHotOp`
2) Adding `AtenIntScalarOp` to the patterns in `TorchToArith`
3) Handling both `int` and `float` types for `off` and `on` values in
`TorchOnnxToTorch` conversion

It also includes:

1) A new test in `TorchToArith/basic.mlir`, for `torch.aten.Int.Scalar`,
and
2) A new test in `decompose-complex-ops.mlir`, for `torch.aten.one_hot`

**Dependencies**

This PR is dependent on #3334.
2024-05-22 17:19:08 +00:00
Xinyu Yang 4d7cdba4bf
[Torch] eliminate "getWithLeastStaticInformation" in DecomposeAtenTriuOp (#3330)
I am trying to eliminate 'getWithLeastStaticInformation' in
DecomposeAtenTriuOp. Could you provide me with some suggestions?
@qingyunqu @zjgarvey 
See issue https://github.com/llvm/torch-mlir/issues/3312
2024-05-22 23:16:57 +08:00
Sambhav Jain 6e485574e5
[Pipeline] Use dedicated simplification pipeline for TorchDynamo frontend (#3376)
Discord Thread:
https://discord.com/channels/636084430946959380/1238330633328005243

## Context: 

[This](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/fx.py#L61)
was updated to support e2e tests for the TorchDynamo frontend in
Torch-MLIR, where we run FX decompositions and import the FX IR to
generate Torch dialect, followed by
`torch-function-to-torch-backend-pipeline`, skipping only the shape/type
refinement for now. However, we should be able to skip many of the torch
simplification passes, as depicted in the [frontend
roadmap](https://github.com/llvm/torch-mlir/blob/main/docs/images/roadmap_frontend.png).

Based on IREE's TorchDynamo
[pipeline](https://github.com/iree-org/iree/blob/main/compiler/plugins/input/Torch/InputConversion/Passes.cpp#L29),
the only two passes we seem to require are: `ReduceOpVariantsPass` and
`DecomposeComplexOpsPass`. This is inline with our findings as well
based on initial exploration.

This PR creates a dedicated frontend simplification pipeline for
TorchDynamo / FX Importer which calls only `ReduceOpVariantsPass` and
`DecomposeComplexOpsPass`. We rely on the e2e fx_importer tests to
ensure we're not regressing by removing many of the passes that were
historically needed for TorchScript.

One notable change here is that we do not call the
`LowerToBackendContractPass` anymore, which used to call
`TorchSimplificationPipeline` iteratively until VerifyBackendContract
was clean. Some of this was required for the shape/type refinement to
converge, which seems a non-issue for Dynamo frontend. Do we anticipate
this (the iterative invocation of TorchSimplificationPipeline followed
by VerifyBackendContract) to be worth retaining in the Dynamo frontend
pipeline? If so, I can make those changes, PLMK.
2024-05-22 05:23:18 -07:00
Yuanqiang Liu 8814d0ae64
[Torch] emit aten.dot and canonicalize it to aten.matmul (#3361)
* canonicalize `aten.dot` to `aten.matmul`
2024-05-18 22:45:14 +08:00
Xinyu Yang 7faba75696
[Torch] Decompose AtenMaskedScatterOp (#3353)
Co-authored-by: Yuanqiang Liu <liuyuanqiang.yqliu@bytedance.com>
2024-05-16 15:27:25 +08:00
Xinyu Yang a9edefb3cf
[Torch] Fix AtenSliceTensorOp::fold (#3345) 2024-05-16 11:42:43 +08:00
penguin_wwy 405f884522
[stablehlo] verify stablehlo backend contract (#3338) 2024-05-16 11:03:43 +08:00
Peiming Liu ccb772cd0f
[sparse] propagate sparsity properly when decompose torch operations. (#3318) 2024-05-15 10:09:27 -07:00
Aaron St George ba32b9cee7
Don't fold `aten.clone` if result isn't same type as input (#3347)
Similar to https://github.com/llvm/torch-mlir/pull/2824, we were seeing
some assertion failures after the addition checks around folders were
tightened up in LLVM: https://github.com/llvm/llvm-project/pull/75887 .
This PR essentially moves the logic that used to be applied at the LLVM
level into the folder, which seems to be the suggested fix.
2024-05-16 00:07:45 +08:00
Xinyu Yang 6b95dd461d
[Torch] Fix PrimNumToTensorScalarOp::fold (#3339)
In constant folding progress, a new constant op will be created
according to the origin op's result type.

See the code in TorchDialect.cpp.

```cpp
Operation *TorchDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  if (auto integerType = dyn_cast<Torch::IntType>(type))
    return builder.create<Torch::ConstantIntOp>(loc, cast<IntegerAttr>(value));

  if (auto floatType = dyn_cast<Torch::FloatType>(type))
    return builder.create<Torch::ConstantFloatOp>(loc, cast<FloatAttr>(value));

  if (auto numberType = dyn_cast<Torch::NumberType>(type)) {
    if (auto floatValue = dyn_cast<mlir::FloatAttr>(value)) {
      return builder.create<Torch::ConstantNumberOp>(loc, floatValue);
    } else if (auto intValue = dyn_cast<mlir::IntegerAttr>(value)) {
      return builder.create<Torch::ConstantNumberOp>(loc, intValue);
    }
  }

  if (isa<Torch::BoolType>(type)) {
    return builder.create<Torch::ConstantBoolOp>(loc, cast<IntegerAttr>(value));
  }

  if (isa<Torch::NoneType>(type))
    return builder.create<ConstantNoneOp>(loc);

  if (auto stringAttr = dyn_cast<StringAttr>(value))
    return builder.create<ConstantStrOp>(loc, stringAttr);

  if (auto elementsAttr = dyn_cast<ElementsAttr>(value)) {
    // Only !torch.vtensor can be constant folded. !torch.tensor has
    // non-trivial aliasing semantics which prevent deduplicating it.
    assert(isa<ValueTensorType>(type) && "should be a vtensor type!");
    return builder.create<ValueTensorLiteralOp>(loc, elementsAttr);
  }

  return nullptr;
}
```
So when the op has a tensor result type, it must be "ValueTensorType"
due to the **assert** statement. However, many fold methods in
TorchOps.cpp only have a judgment of "BaseTensorType".
2024-05-15 20:54:19 +08:00
zjgarvey 911e723581
Expands Q Commuting Ops (#3332)
After running the model tests in SHARK-TestSuite, I noticed a few model
failures due to half-fusion.

Notably, RDN_pytorch_vaiq_int8 had a depth=5 convolution chain with
multiple AtenViewOp's.
2024-05-13 11:01:53 -07:00
zjgarvey 75d1d72059
Generalize Operand Quantization in FuseQuantizeOps (#3327)
This change enables more customization with operand quantization, and
generalizes the patterns QuantizeOperands and QuantizeTransposeOperands
to QuantizeOperandsPastCommutingOps.

This allows for passing quantization through operations which are
functionally unaffected by quantization, such as view-like ops. The
purpose of this change is to address a myriad of quantization issues
seen in quantized onnx models that have some reshape-like operations
sandwiched in between a dequant and something like a matmul (whose other
operand is immediately quantizable).
2024-05-12 20:49:59 -07:00
NeverRaR 1d4859699b
MaxPool1d lowering to linalg (#3295)
Co-authored-by: root <root@i32b01216.sqa.eu95>
2024-05-10 22:05:26 +05:30
penguin_wwy 64b59c7fc3
[FxImporter] Eliminate the dependency on the refinement pass (#3309) 2024-05-10 02:44:36 +08:00
aldesilv ec6d7aa5d2
OnnxToTorch lowering resize op (#3013)
https://github.com/nod-ai/SHARK-Turbine/issues/358
adds a lowering from onnx to linalg for bilinear and nearest resize with
support for using scales or sizes to get resize shape. uses coordinate
transform half pixel for bilinear mode and asymmetrical for nearest
mode. See
https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize. Added
two passes -- one for bilinear and the other for nearest.
2024-05-08 21:35:03 +00:00
Jiawei Wu 346a536c9f
[Torch Dialect] decompose all index_put-like op to aten.index_put.hacked_twin for stricter semantics (#3071)
This PR decomposes all index_put-like op to aten.index_put.hacked_twin for stricter semantics, i.e., no None index in indices argument.
2024-05-08 22:44:57 +08:00
Xinyu Yang abef114c0c
[torch] emit aten.Softshrink and aten.Hardshrink (#3248)
as title
2024-05-08 15:20:45 +08:00
Vivek Khandelwal e60160d793
Revert "Decompose AtenNonzeroOp" (#3289)
Reverts llvm/torch-mlir#3281
2024-05-06 09:52:04 -07:00
Xida Ren (Cedar) 1af00e6040
Decompose AtenNonzeroOp (#3281)
This fixes some onnx lit tests not lowering to linalg in
https://github.com/nod-ai/SHARK-Turbine/issues/450
2024-05-05 21:59:25 +08:00
Ze Zhang 11cd7cd9e7
Folder and Canonicalizer for PrimsConvertElementTypeOp and AtenMaxPool2dWithIndicesOp (#3272)
While playing with TorchDynamo on ResNet18. I notice following issues:

- `prims.convert_element_type` can’t be canonicalized even if the input
and the output share the same type

- `aten.max_pool2d_with_indices` is always used instead of
`aten.max_pool2d`, even if the second returned output (indices) has no
user

This PR fixes above issues by adding a folder to the
PrimsConvertElementTypeOp and a canonicalizer to the
AtenMaxPool2dWithIndicesOp


Lit test:

`cmake --build build --target check-torch-mlir-all`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-05-02 00:03:41 -07:00
Xida Ren (Cedar) 315dc6c3e3
[torch] `aten.eye` should use dynamic dims when no static dims are available (#3202)
Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-04-30 17:41:03 +00:00
Xinyu Yang f32ada993d
[Stablehlo] Improve the lowering of pool op in stablehlo (#3259)
1. Handle case stride == None
2. add avgpool3d maxpool1d  maxpool3d lowering
2024-05-01 00:06:13 +08:00
Rob Suderman db6721084a
Integrate LLVM at llvm/llvm-project@593f6fdcb4 (#3260) 2024-04-29 12:01:40 -07:00
Vivek Khandelwal b1e2241479
[ONNX] Fix Onnx.Selu lowering and canonicalizer for IntImplicit op (#3221)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-04-29 04:00:01 +00:00
Yuanqiang Liu aed2cf3351
[Torch] emit aten.__contains__.str_list and add folder (#3249) 2024-04-29 10:51:17 +08:00
Xinyu Yang 5684dc0441
[Torch] emit aten.celu and decompose it (#3247)
CELU(x)=max(0,x)+min(0,α∗(exp(x/α)−1))
2024-04-28 17:23:40 +08:00
Yuanqiang Liu 46c0f3cad0
[Torch] emit aten.log_sigmoid and decompose it to log(sigmoid) (#3246) 2024-04-28 11:47:43 +08:00
Stella Laurenzo 5d4b803914 [NFC reformat] Run pre-commit on all files and format misc.
This is part 1 of ~3, formatting all miscellaneous text files and CPP files matched by a first run of pre-commit. These tend to be low change-traffic and are likely not disruptive.

Subsequent patches will format Python files and remaining CPP files.
2024-04-27 14:08:09 -07:00
penguin_wwy 6679728c56
Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3243)
Like #3130, gradually replace the deprecated code

https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecated
2024-04-27 14:00:56 -07:00
Yuanqiang Liu f173a06fa7
[Torch] emit aten.ne.str and add folder (#3242) 2024-04-28 00:58:50 +08:00
Yuanqiang Liu 634a796933
[Torch] fold aten.log (#3223) 2024-04-26 10:10:02 +08:00
Aart Bik 2eac8a992f
[torch-mlir][sparse] sparse tensor dialect is a legal dialect (#3227) 2024-04-26 02:36:42 +08:00
Yuanqiang Liu b0ba3def93
[Torch] support AtenScalarImplicitOp canonicalize with float (#3231) 2024-04-26 02:36:13 +08:00
Yuanqiang Liu fab2696489
[Torch] support aten.trunc (#3219)
decompose `trunc(x)` to `sign(x) * floor(abs(x))`
2024-04-24 14:32:33 +08:00
Xinyu Yang 4da3d714cc
[Torch] Support AtenProdOp on linalg and stablehlo (#3215) 2024-04-24 11:14:04 +08:00
zjgarvey a8ba865fca
[torch] Adds Quantization Support for `aten.relu` (#3177)
A choice was made to quantize the return type of Relu with a scale and
zero point copied from the input's quantization scheme. With this
choice, the torch-to-linalg conversion of quantized Relu essentially
computes max(input, zeroPoint) in the elementwise payload.
2024-04-23 11:01:36 -07:00
Yuanqiang Liu db3842f2e8
[Stablehlo] support lowering sinh & cosh to stablehlo (#3213) 2024-04-23 19:54:58 +08:00
penguin_wwy e5bdd71baf
[Torch] Emit and decompose prims.iota op (#3132) 2024-04-21 19:45:01 -07:00
Xinyu Yang 790a697245
[Torch] Add folder for AtenIntOp, AtenFloatOp (#3189)
See unit test below:
```
// CHECK-LABEL:   func.func @torch.aten.tensor.float(
// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %float1.000000e01 = torch.constant.float 1.000000e+01
  %67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32>
  return %67 : !torch.vtensor<[],f32>
}

// CHECK-LABEL:   func.func @torch.aten.tensor.int(
// CHECK-NEXT: torch.vtensor.literal(dense<45> : tensor<si32>) : !torch.vtensor<[],si32>
func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> {
  %none = torch.constant.none
  %false = torch.constant.bool false 
  %int45 = torch.constant.int 45
  %67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>
  return %67 : !torch.vtensor<[],si32>
}

```
2024-04-19 22:17:06 +08:00
Xinyu Yang d4313eed4a
[Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075)
Decomposition RepeatInterleaveSelfInt with following ops:
```python

def my_repeat_interleave(input, repeats, dim=None):
    if dim is None:
        # Flatten the input and then repeat
        return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
    else:
        # Calculate the shape after repeat
        expanded_shape = list(input.shape)
        expanded_shape[dim] *= repeats
        # Repeat the tensor along the specified dimension
        repeat_shape = [1] * (input.dim() + 1)
        repeat_shape[dim + 1] = repeats
        input = input.unsqueeze(-1)

        # Tile and then reshape
        tiled = torch.tile(input, repeat_shape)
        # Rearrange and reshape
        repeated = tiled.reshape(*expanded_shape)
    return repeated

```

I passed the tests of stablehlo and linalg. When testing onnx, strange
things happened.
In torch-mlir's CI **torch_nightly** and my own
environment(torch==2.4.0.dev20240318+cpu), it can **pass the pass**.
In torch-mlir's CI  **torch_stable**, it **failed**.
The test case is `RepeatInterleaveSelfIntNoDimModule_basic`, the result
shape should be [120].
```python
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float32, True),
    ])
    def forward(self, x):
        return x.repeat_interleave(2)


@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
  Unexpected outcome summary: (onnx)
  
  ****** Failed tests - 1 tests
      FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
          @ trace item #0 - call to "forward"
          @ output of call to "forward"
          ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```

@rsuderman 
Would you please help me check what's wrong with my PR? Thanks a lot.
2024-04-18 06:27:51 +08:00
Xinyu Yang d2ba956e69
[Torch] Support Aten_CastLongOp. (#3160)
By canonicalize Aten_CastLongOp into AtenToDtypeOp
2024-04-17 21:58:32 +08:00
zjgarvey 5e564b5864
Adds Some Quantization Support for AtenMatmulOp (#3147)
1. onnx.MatMulInteger now converts to aten.matmul instead of aten.mm
2. aten.matmul, for ranks >=2, now allows quantized inputs and will
lower to linalg::quantized_matmul or linalg::quantized_batch_matmul.
3. added AtenMatmulOp to the FuseQuantizeOps rewrite patters
QuantizeOperands, QuantizeTransposedOperands, and QuantizeAccumulator
4. added several tests, including some to test AtenMmOp with varying
quantization signed-ness.
5. a quantized matmul mat-vec test is added to verify the failure to
lower to linalg; cleaned of out-of-date code related to common
torch-mlir lowering xfails.
6. in debugging a real model with quantized matmuls, I found a bug on
the scalarize-shapes pass which resulted from the aten.full op folder
returning an incompatible result type. This is fixed by the small change
here to
[lib/Dialect/Torch/IR/TorchOps.cpp](https://github.com/llvm/torch-mlir/compare/main...zjgarvey:torch-mlir:MatMulIntegerFix?expand=1#diff-dc8ed165c207918e606490eee3984b1ad51d7034e6aac36fc046bf47f6f03f4f).
2024-04-15 16:06:47 -07:00
IanWood1 5708ee7ec9
Added 2 Ops: Floor divide scalar and Floor divide scalar mode (#3156)
- Added linalg lowering for `AtenFloorDivideScalarOp`
  - Needed `AtenDivScalarModeOp` for the decomp.
- Added linalg lowering for `AtenDivScalarModeOp`
- Moved linalg payload logic to `createDivModePayload()` since the logic
was nearly identical for both `AtenDivScalarModeOp` and
`AtenDivTensorModeOp`. Just a template function
 -  Added `AtenDivScalarModeOp` lowering for stablehlo
 

Pytorch's
[`torch.floor_divide()`](https://pytorch.org/docs/stable/generated/torch.floor_divide.html)
in a previous version (for a reason unknown to me) preformed a
truncation instead of "floor". The already implemented op
`AtenFloorDivideTensorOp` was done before this change. However, this
wasn't caught because our testcases only tested positive floor division.
I changed this to floor as well as adding a few test cases.
2024-04-15 13:45:10 -07:00
zjgarvey 197ef4224b
Avoid Type Mismatch in Slice Folder (#3154)
Fixes issue #3153
2024-04-12 11:43:45 -07:00
penguin_wwy d4a30b7e67
Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130)
We should prefer functional style as the method style is deprecated
https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecated
(https://mlir.llvm.org/deprecation/)
2024-04-11 06:47:35 -07:00
Xinyu Yang 308c45e61a
[Torch] Fix PrimListUnpackOp::getCanonicalizationPatterns (#3140)
Fix the case PrimListUnpackOp's result num is not equal to PrimList
length.
See the following example:
```python
    def forward(self, x):
        if len(x.shape) == 5:
            b0, t, c0, h0, w0 = x.shape
            b, c, h, w = torch.mul(b0, t), c0, h0, w0
        else:
            b1, c1, h1, w1 = x.shape
            b, c, h, w = b1, c1, h1, w1
        res = torch.reshape(x, [b, c, h, w])
        return res
```
Without this fix, the following error message will occur:
```
/root/torch-mlir/externals/llvm-project/mlir/lib/IR/PatternMatch.cpp:118: virtual void mlir::RewriterBase::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed.
```
2024-04-11 19:48:49 +08:00
Xinyu Yang 6524838bcb
[Torch] Add general AdaptiveAvgPool2dOp decompose support (#3111)
Previously, it could only handle the situations where outputsize == (1,
1) or outputsize == (input_H, input_W). Now it supports all situations
where input_H % output_H== 0 && input_W % output_W == 0
2024-04-11 17:02:59 +08:00
Yuanqiang Liu 88533b1968
[Stablehlo] fix aten.arange's lowering to stablehlo (#3138)
* promote to f64 to do division, avoid division on i64 (floor div)
* refactor torch-to-stablehlo-pipeline
2024-04-11 15:55:56 +08:00
Xinyu Yang 5eb0cf9104
[Torch] Add decompose of AtenToPrimDeviceOp (#3131)
As device information isn't relevant to torch-mlir
2024-04-10 22:26:48 +08:00
Xinyu Yang 42a16fa912
[Torch] Support Aten_CastFloatOp. (#3115)
By canonicalize Aten_CastFloatOp into AtenToDtypeOp
2024-04-09 11:06:53 +08:00
Xinyu Yang 84c24e5771
[Torch] Support Aten__And__ScalarOp (#3114) 2024-04-08 20:24:17 +08:00
Yuanqiang Liu 2c56ef9252
[Torch Dialect] canonicalize aten.sign to aten.sgn (#3112)
* `aten.sign` is a sub-set of `aten.sgn` (`aten.sgn` support complex
type).
2024-04-08 20:05:42 +08:00
Yuanqiang Liu 43d54efd14
[cmake] link TorchMLIRTorchConversionPasses to TorchMLIRConversionPasses (#3113)
* as that `TorchMLIRTorchConversionPasses` missing dependencies of
`TorchMLIRTorchToStablehlo` and `TorchMLIRTorchToTensor`.
* use `TorchMLIRConversionPasses` instead of scattered targets.
2024-04-08 14:44:34 +08:00
Vivek Khandelwal 7e778e2179
build: manually update PyTorch version (#3094)
Set PyTorch and TorchVision version to nightly release 2024-04-01.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-04-03 10:48:37 +05:30
Rob Suderman f97cd4893f
[torch] Improve shape inference for dynamic shapes (#3091)
Shapes can be processed as tensors to represent the set of dimensions.
As reshapes take a list of scalars this can result in a single dynamic
dimension blocking the adjacent static dimensions.

This pass attempts to de-couple tensor computations related to shapes
and propagate values to better support lowering scalar tensor
computations.
2024-04-02 16:19:57 -07:00
zjgarvey 40e762ca42
Adds result types to a prelu decomp (#3098)
This adds explicit result types instead of relying on shape/dtype
computations.

Solves a regression issue with IREE: #3092
2024-04-02 11:41:56 -07:00
Yuanqiang Liu 6cbb2f7ae0
[Stablehlo] add stablehlo-canonicalize-dynamism when lowering (#3097)
so that many stablehlo e2e testcases could pass
2024-04-02 22:47:24 +08:00
zjgarvey 532d297c46
[ONNX] Preliminary Work Towards Supporting QuantizedMLP_basic onnx e2e test (#3089)
See the related issues here:
[SHARK-Turbine#556](https://github.com/nod-ai/SHARK-Turbine/issues/556)

1. Adds uint8 casting to onnx.Cast op
2. Fixes an issue with onnx.DequantizeLinear when the scale comes with
shape [1].
3. Adds support for unsigned types in an AtenItemOp folder
4. Adds a simpler quantized model for easier debugging
5. Adds a fusion pass to convert [quant -> dequant -> transpose -> mm]
patterns to [transpose -> quant -> mm].
6. Moved some xfails that are still not passing, but for different
reasons than onnx.cast failures.
2024-04-01 16:21:05 -07:00
Thomas Dietert 3c33dbd987
[MLIR][Torch] Canonicalize torch.from_i1 and torch.to_i1 (#3067)
When lowering `torch.aten.convolution`, it is expected that the
'transposed' argument is a torch.constant operation. In some cases, the
argument was a `from_i1` operation converting an `arith.constant`
operation into a torch.bool. This is not wrong semantically, but instead
of generalizing the legality of the `torch.aten.convolution` op, we
canonicalize `arith.constant` ops followed by `from_i1` ops to
`torch.bool` ops.

For example:
```
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.convolution'(0x124705b90) {
  %33 = "torch.aten.convolution"(%arg0, %20, %21, %31, %29, %30, %19, %32, %0) : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.vtensor<[1,10,24,24],f32>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : unimplemented: only constant transposed supported.      <-- Resolved by this PR
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : not a supported Scalar to Tensor like op
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : not a supported elementwise op
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : not a supported reduce op
  } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
<stdin>:21:11: error: failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal
    %17 = torch.operator "onnx.Conv"(%arg0, %0, %1) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [5 : si64, 5 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>) -> !torch.vtensor<[1,10,24,24],f32> 
          ^
<stdin>:21:11: note: see current operation: %33 = "torch.aten.convolution"(%arg0, %20, %21, %31, %29, %30, %19, %32, %0) : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.vtensor<[1,10,24,24],f32>
```

Additionally, we require the canonicalization of `to_i1` operating on a
torch.constant bool to an `arith.constant ... : i1` for the e2e tests to
pass successfully.
2024-04-01 14:25:51 -07:00
Xinyu Yang da88efad89
[Torch] Fix bug of DecomposeAtenSelectIntOp (#3087)
Fix bug of DecomposeAtenSelectIntOp. Because it may use resultTy when
resultTy has not been inferred.

```
    auto resultTy = op.getType().cast<BaseTensorType>();
    if (sliceTy.getSizes().size() == resultTy.getSizes().size()) {
      rewriter.replaceOp(op, slice);
      return success();
    }

```

So I add restriction.
2024-04-01 21:25:02 +08:00
Xinyu Yang 40008b025a
[Torch] Support prelu decomposition (#3069) 2024-03-29 08:05:00 +08:00
Xinyu Yang e6e7689a24
[Torch] support decompose aten.einsum with ellipsis slicing (#3056) 2024-03-27 12:42:10 -07:00
Yuanqiang Liu 0a581a97a7
[Torch Dialect] enhance aten.int.tensor's canonicalize (#3058)
support fold with literal vtensor.  
change it to canonicalize because this pattern will create new op.
2024-03-27 09:51:58 +08:00
Rob Suderman 14b548f968
[torch] Improve shape inference for `torch-to-linalg` path for reshapes (#3055)
Reshaping tensors depend on directly matching individual dimensions to
their corresponding dim in the `torch.view` reshape dimensions. This
involves decoupling dynamic dimensions from their static counterparts
and support cleanup / canonicalization.
2024-03-26 12:41:40 -07:00
schnkmwt 1fcbfa87ec
Implement linalg lowering of diag_embed torch op (#2885)
This PR adds lowering of diag_embed to linalg dilect.
Tracked in https://github.com/nod-ai/SHARK-Turbine/issues/288

---------

Co-authored-by: sachink <sachink@xilinx.com>
2024-03-22 16:32:50 -07:00
zjgarvey 99b3a5f117
Converts all Adaptive Pooling Ops to Linalg (#2808)
The previous conversions for AtenAdaptiveAvgPool1dOp and
AtenAdaptiveMaxPool2dOp are refactored into a general templated
conversion that works for all of the AtenAdaptive...PoolNdOp's.

New support is added for the following ops:

1. AtenAdaptiveMaxPool1d
2. AtenAdaptiveMaxPool3d
3. AtenAdaptiveAvgPool3d

Support is also provided for passing inputs without batch dimensions.
For example, applying adaptive_avg_pool2d to an input tensor of rank 3.

After [pytorch #118162](https://github.com/pytorch/pytorch/pull/118162)
gets down to torch-mlir, I'll add a test for AdaptiveMaxPool1d with
return_indices (which will pass with that upstream fix).

---------

Co-authored-by: James Newling <james.newling@gmail.com>
2024-03-22 11:05:20 -07:00
Yuanqiang Liu 8b96727d0d
[Stablehlo] lowering chlo to stablehlo in torch-to-stablehlo pipeline (#3037)
as that stablehlo is better than chlo as the boundary between frontend
compiler and backend compiler.
2024-03-19 21:18:54 +08:00
Yuanqiang Liu 4282eb9e76
[Torch Dialect] support aten.fake_quantize_per_tensor_affine (#3014) 2024-03-15 08:53:29 +08:00
Nithin Meganathan 798bfd7dff
Adds accumulator types in TorchToLinalg for `AtenMmOp` and `AtenConvolutionOp` (#3027) 2024-03-14 16:40:40 -07:00
Yuanqiang Liu 870e63bc3c
[Torch Dialect] support decomposition of aten.linspace (#3006) 2024-03-14 08:28:33 +08:00
Yuanqiang Liu 43c6996a31
[Torch Dialect] add folder for aten.ceil and unify patterns of ceil, … (#3010)
…floor, round
2024-03-14 07:41:58 +08:00
ptrifunovic98 524ff99216
Implement lowering of torch.aten.linalg_cross (#2986)
Closes
[nod-ai/SHARK-Turbine#497](https://github.com/nod-ai/SHARK-Turbine/issues/497)
2024-03-13 12:17:22 -07:00
Nithin Meganathan 5ecc1d5c0d
Align softmax accumulation types with Torch's CUDA implementation (#2996) 2024-03-12 15:07:45 -07:00
Rob Suderman e78c99e74e
[torch] Update folders for splat operators (#3012)
Splat operators required the output is 1-D. This was not a required
restriction and was loosened to 2d.
2024-03-11 16:45:49 -04:00
Yuanqiang Liu 229ca3a9e1
[Torch Dialect] emit aten::mul and add folder (#3007) 2024-03-11 19:59:34 +08:00
Rob Suderman 0723584936
[torch] Add folder for torch.aten.*.Scalar comparisons (#3000)
This folds small version of the tensor-scalar comparison operators as
they are commonly used for shape computations. This includes le, lt, ge,
gt, eq, and ne.
2024-03-08 13:44:00 -08:00
Rob Suderman a78659742a
[onnx] Migrate `onnx.ReduceMax` to match `onnx.ReduceMin` (#2981)
This mostly copy-pastes the reduce minimum implementation to reduce max
to improve test coverage. We also improve the aten lowering for min/max
dim for unsigned types.
2024-03-06 16:48:21 -08:00
Rob Suderman 06292d9429
[torch] Rework `aten.repeat` to use flatten and unsqueeze (#2984)
Current implementation depends on using `aten.view` which has issues
inferring tensor collapse/expand operations during the lowering to
`linalg`. Using flatten and unsqueeze better infers what the later
reshape behavior.
2024-03-06 10:19:18 -08:00
Ze Zhang aa7c9a9653
e2e support aten.linalg_norm to aten.linalg_vector_norm (#2953)
Add e2d support for `aten.linalg_norm` by decompose it to
`aten.linalg_vector_norm`.

Lowering to `aten.linalg_matrix_norm` is still unsupported.

To Test: 

`python -m e2e_testing.main -v`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-03-05 16:31:01 -08:00
Rob Suderman bc0527676b
[torch] Add support for `torch.split_with_sizes` via decompose (#2979)
Convert to individiual slices and tuple together as a list.

---------

Co-authored-by: Scott Todd <scott.todd0@gmail.com>
2024-03-05 15:01:21 -08:00
Rob Suderman a86e89ecb5
[torch] Additional folders for shape computations (#2972)
A handful of operations are commonly used in shape calculations (slice,
concat, broadcast). Added these additional folders to better propagate
simple shape computations.
2024-03-04 11:46:49 -08:00
Rob Suderman 19d4888278
[torch] Make torch.aten.unflatten lower directly to linalg (#2971)
Existing lowering via aten.view does not work as well for dynamic shapes
as the lowering to tensor.expand must re-infer dynamic shape matching.
Better to directly lower.
2024-03-04 10:17:42 -08:00
Rob Suderman 61f0a5facf
[torch] Add an `aten.cat` length-0 canonicalization (#2966)
If an input is length-0 along the dimension of canonicalization we can
remove the tensor from the list
2024-03-01 21:41:12 -08:00
mmakevic 76b81e0ccd
Implement lowering of torch.aten.fmod.Tensor (#2767)
Closing https://github.com/nod-ai/SHARK-Turbine/issues/351
2024-02-29 11:22:03 +05:30
Rob Suderman e48fe45886
[onnx] Import `onnx` import to pass remaining tests (#2951)
Finish supporting importing the vast majority of `onnx` operations. This
includes:
- region support
- region value inherentance
- `torch.string` support
- `torch.list` support
- `torch.optional` support
2024-02-28 12:18:02 -08:00
Rob Suderman 6f3d62ab04
[torch] Fix folders and `cat` and `view` torch lowerings (#2963)
A bunch of small fixes are interlinked and trigger crashes if not
addressed as a group. This includes:

- aten view when expand from a rank-0 tensor
- slice folder with negative indices
- `aten._shape_as_tensor` folder on a rank-0 tensor
- `aten.cat` of a tensor with a length-0 tensor
2024-02-28 12:04:52 -08:00
Rob Suderman 73b6df9007
[torch] Fix DecomposeAtenInstanceNorm decomposition (#2960)
The decomposition only suports a NCHW lowering however the operation can
support arbitrary spatial dimensions. Updated the lowering to better
support spatial dimensions.
2024-02-28 10:27:19 -08:00
Rob Suderman 4a7a7d76f8
[onnx] Fix ReduceMean lowering to torch (#2956)
Torch lowering only supported the most recent version. Refactored the
lowering so more easily handle default values and optional operands /
attributes.
2024-02-27 22:48:07 -08:00
Rob Suderman e30a083aff
[torch] Rework lowering to tm_tensor.scatter to stop serialization (#2940)
We collapsed and broadcasted scatter indices to a single element
version. We should instead upport `tm_tensor.scatter`s support for
multiple indices and the implicitly broadcasted behavior. This avoids
the serialization and materializing a needlessly large indices tensor.
2024-02-27 11:46:57 -08:00
Vivek Khandelwal d81747eadb
[MLIR][TORCH] Extend support for OnnxToLinalg lowering for Dropout and Div op (#2938)
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/451,
https://github.com/nod-ai/SHARK-Turbine/issues/452
2024-02-27 11:02:05 +05:30
ptrifunovic98 c5a1da1910
Implement lowering of torch.aten.norm.Scalar (#2899)
Closes
[nod-ai/SHARK-Turbine#365](https://github.com/nod-ai/SHARK-Turbine/issues/365)
2024-02-26 08:46:56 -08:00
Andreas Falkenberg 55dc8deb92
[torch] GridSample TorchToLinalg lowering (#2883)
Lowers `torch.grid_sample` to the equilvalent `linalg` representation.
2024-02-23 09:14:38 -08:00
Rob Suderman df2aa1a369
[torch] Fixed edge conditions for strided slicing (#2929)
Strided slicing can occur with a negative stride. In these cases we need
to bound end differently. This included removing a function that was
generating bad limits.
2024-02-21 21:28:44 -08:00
Stella Laurenzo 4446fa00d8
Migrate passes in TorchConversion to use FunctionOpInterface. (#2935)
This enables better re-use in downstreams which use different func
implementations and should have no impact on those that don't except in
opt pipelines if using the old form. With interfaces, explicit pipelines
via `--pass-pipeline=` must be used.
2024-02-20 08:54:02 -08:00
Rob Suderman 135c81a416
[torch] Add folder for `prim.NumToTensor.Scalar` (#2921)
Useful for `slice` lowerings that depend on tensors made form scalars.
2024-02-19 11:55:54 -08:00
Rob Suderman e80054a3cc
[torch] Folders for `torch.aten.*.tensor` operators [add, sub, mul] (#2878)
Simple folder for limited size aten tensor operations. This is primarily
useful for shape computation folding as they unfortunately can use
`aten` operators. Add, sub, mul are common examples of these folders.
2024-02-19 10:28:23 -08:00
aldesilv d29157b33f
OnnxToTorch support for onnx.InstanceNormalization op (#2710)
https://github.com/nod-ai/SHARK-Turbine/issues/327
2024-02-19 19:53:48 +05:30
Ze Zhang f3b38e5d12
DecomposeComplexOps: update parseEquation to skip space char for AtenEinsumOp op (#2910)
Just a minor update to skip the space char if included in the equation
string

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-02-14 18:18:11 -08:00
Vivek Khandelwal d6d1a173dc
[MLIR][Torch] Add OnnxToTorch and TorchToLinalg support for trig ops (#2903)
This commit adds the OnnxToTorch lowering for cosh, acosh, asin, asinh,
and atanh op.
This commit also adds the TorchToLinalg lowering for acosh, asin, asinh,
and atanh op.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-02-14 11:58:09 +05:30
Rob Suderman e9cdd6cbc5
[torch] Fix tm_tensor.attention for end-to-end (#2907)
Some operations include a backend matcher for specialized operations. We
map these back to generics so they appropriately match to the high
performance versions. This is done for the attention operation.
2024-02-13 21:18:01 -08:00
Scott Todd d6e1d836ca
Drop torch attributes at the end of backend conversion. (#2876)
Fixes https://github.com/llvm/torch-mlir/issues/2866

Some backends / downstream projects expect that a "fully converted"
program has no remaining ops or attributes from the original dialect(s).
2024-02-13 14:32:02 -08:00
Rob Suderman c0f139be0f
[torch] Add `torch.aten.eq.Tensor` comparison folder (#2889)
Added a folded for a equals operator. This allows an equivalent
comparison folder, primarily for when shape computations occur small
size tensor.
2024-02-09 15:02:20 -08:00
Rob Suderman 7d33ba69ac
[torch] Folder for torch.aten.select.int for splat cases (#2890)
If the input or result is a splat value we can just constant fold the
result. This is common for shape computations and can help with shape
inference.
2024-02-09 14:02:54 -08:00
Franz Haniel 4cc62aeb24
Implement trace (#2790)
The lowering decomposes AtenTraceOp into an AtenDiagonalOp followed by
AtenSumOp.

The progress is tracked in
https://github.com/nod-ai/SHARK-Turbine/issues/333.

---------

Co-authored-by: Franz Haniel <franz.haniel@amd.com>
2024-02-09 08:00:24 -08:00
Rob Suderman a8aad2a5ab
[torch] Add `torch.aten.where.*` folders (#2886)
Where operation can be statically computed when involving splats of
known value. Added handling these cases with multiple tests.
2024-02-07 19:43:31 -05:00
Dave Liddell 23647ab2d1
[torhc] aten.index_select folder (#2871)
Folds aten::index_select ops under the following conditions:

1. If the input and output are the same shape, the indexing operation is
a NOP, so just return the input.
2. If the input has shape <1x1x...xNx...x1> (all 1's except for one
dim), and the output shape is <1x1x...x1> (all 1's), then there is a
single index, so extract the single element value and return a tensor
with that value.

---------

Co-authored-by: Dave Liddell <dliddell@xilinx.com>
2024-02-07 16:17:15 -08:00
mmakevic 32dbf99ce2
Implement lowering of torch.aten.all.dim (#2873)
Lowering of torch.aten.all.dim to linalg.

Per PyTorch documentation:

> This function matches the behaviour of NumPy in returning output of
dtype bool for all supported dtypes except uint8. For uint8 the dtype of
output is uint8 itself.

Since there is no support for ui8 in torch-mlir currently
(https://github.com/llvm/torch-mlir/pull/1384#issuecomment-1260011334)
implementation returns failure for that case.
2024-02-07 12:34:52 -08:00
Xida Ren (Cedar) fc04bc7ee9
[torch] AtenSliceOp folder that produces splat results (#2869)
Includes `slice` folder and lit tests

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-02-07 19:00:46 +00:00
Xida Ren (Cedar) cc06391630
AtenSortOp Folder (#2864)
A chunk off

https://github.com/llvm/torch-mlir/pull/2856
https://github.com/llvm/torch-mlir/pull/2860

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
2024-02-06 21:12:12 +00:00
Dave Liddell 1cb14f6879
Rob's atenTensor folder (#2867)
If a tensor is initialized by a list with a single constant integer,
this folder turns it into a torch.vtensor.literal

---------

Co-authored-by: Dave Liddell <dliddell@xilinx.com>
2024-02-05 17:10:42 -08:00
Rob Suderman e3faef5224
[onnx] Convert `onnx.QLinearConv` to `torch` (#2851)
Leaning on the QDQ functionality in torch we can support the QLinearConv
operation by piggybacking through `torch.Convolution`. This includes
some changes such as allowing the `onnx` rewriter to run recursively.
Doing so allows `QLinearConv` to decopmose to `onnx.Convolution` which
is then lowered to `torch`.
2024-02-05 16:09:41 -08:00
Xida Ren (Cedar) 24b8c8672a
[torch] Add folders for `torch.fill`, `torch.ones`, `torch.zeros` and `aten.getItem` (#2849)
So that the CumSum Op in OPT can get the constant that it requires to be lowered to TMTensor

---------

Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-02-02 10:46:33 -08:00
Rob Suderman 0114a570e3
[torch] Support lowering `torch.item` to `tensor.extract` (#2835)
Extracting scalar values from tensors can be implemented via a lowering
to tensor.extract.
2024-01-31 15:09:12 -08:00
Ilija Kalinić 54ef18c556
Implement lowering of torch.aten.lerp.Scalar (#2773)
Closes nod-ai/SHARK-Turbine#356
2024-01-31 09:39:38 -08:00
Yuanqiang Liu d778950f45
[Torch Dialect] add fold pattern for aten.clone (#2804) 2024-01-31 09:43:21 +08:00