Commit Graph

295 Commits (45c85c3b34629ee75d6b3a6fd9447894ce7a8ce3)

Author SHA1 Message Date
zjgarvey 0fb8b017d8
Adds misc fixes for some padding related issues (#3528)
This patch adds a few misc pad op related changes:

1. Addresses issue <https://github.com/llvm/torch-mlir/issues/3457>
2. Addresses issue <https://github.com/llvm/torch-mlir/issues/3442>
3. Fixes the padding order for asymmetrically padded onnx.Conv ops
4. Enables passing quantization through those onnx.Conv op pre-paddings
5. Modifies the torch-to-linalg lowering of AtenReplicationPad2d op to
enable support for input rank != 4

Unfortunately, even with all of these changes, the e2e tests for the
ReplicationPad2d still fail the onnx config, since the torch export
procedure for rearranging the pad order is complicated enough that the
padding ints end up not being able to fold back to constants.
2024-07-11 20:01:45 -05:00
Ze Zhang d466d5b809
Register fake_quantize related ops (#3522)
Register `aten.fake_quantize_per_channel_affine` and
`aten.fake_quantize_per_tensor_affine.tensor_qparams` ops

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-07-05 11:02:03 -07:00
Christopher McGirr 7e6d76e997
[Torch] Fix torch.constant.int operation parsing (#3476)
Due to the custom operation parser, the print and parser were expecting
two different forms.

One having the dictionary before the value and the other after.
Following the format of the other constants ops, the constant.int will
follow the `value attr-dict` format. Updated the parser accordingly.
2024-06-28 16:06:52 +02:00
Sambhav Jain 09f502667b
`AtenTensorOp::fold` should not fold when result type is not fully specified (#3494)
In one of our downstreams, we encountered an internal assertion failure
in an intermediate pass from `AtenTensorOp::fold` invocation:
```
external/llvm-project/llvm/include/llvm/Support/Casting.h:650: decltype(auto) llvm::dyn_cast(const From &) [To = mlir::torch::Torch::NonValueTensorType, From = mlir::Type]: Assertion `detail::isPresent(Val) && "dyn_cast on a non-existent value"' failed.
```

for this snippet in the IR:
```
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,1,15360],f32>}
...
    %218 = torch.aten.size %arg1 : !torch.tensor -> !torch.list<int>
    %219 = torch.aten.tensor %218, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
```

Turns out this was
[fixed](https://github.com/llvm/torch-mlir/pull/3189/files#diff-dc8ed165c207918e606490eee3984b1ad51d7034e6aac36fc046bf47f6f03f4fR3719)
eventually (and we were on an old hash of torch-mlir). This PR submits
just the lit test for test coverage on that specific change:
```c++
OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  // lit test this
  if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
    return nullptr;
  ...
```
2024-06-24 15:22:50 -07:00
Peiming Liu ba16bad8c7
[torch-mlir] bump stablehlo/llvm version (#3471)
Update to llvm/llvm-project@5207632f86
Update to openxla/stablehlo@d41390c3a7
2024-06-18 16:59:53 -07:00
zjgarvey de28c8540b
[ONNX] add int16 quantization support (#3446)
There is currently no int16 quantization support in torch. This patch
adds a new mlir type to correspond to the missing "torch.qint16" type,
and enables lowering of quantization-related onnx ops using int16 types.

In follow-up patches, custom quantization logic for ops like
aten.matmul/aten.mm/aten.convolution may need to be revisited to allow
support for qint16. The passes in FuseQuantizedOps.cpp may also need
slight modifications.
2024-06-12 10:37:22 +05:30
Sambhav Jain d0a818a03e
Representing Symbolic Shape Expressions in Torch Dialect (#3372)
Torch Dialect with symbolic shape expressions:
```ll
module {                                                                                                                                                                                                     
  func.func @main(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {                                                                                   
    %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int                                                                                                                                    
    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int                                                                                                                                   
    %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int                                                                                                                                    
    
    torch.bind_symbolic_shape %arg0, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    torch.bind_symbolic_shape %arg1, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                          
    
    %3 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                                  
    torch.bind_symbolic_shape %3, [%0, %1], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %4 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>                                                                                                               
    torch.bind_symbolic_shape %4, [%0, %2], #affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>                                                                                             
    
    %5 = torch.prim.ListConstruct %3, %3, %4 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>                                               
    %int1 = torch.constant.int 1                                                                                                                                                                             
    %6 = torch.aten.cat %5, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>                                                                                                          
    torch.bind_symbolic_shape %6, [%0, %1, %2], #affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32>                                                                            
    
    return %6 : !torch.vtensor<[?,?,3],f32>                                                                                                                                                                  
  }                                                                                                                                                                                                          
}              
```

For reference, this is the TorchDynamo exported program with symbolic
shape expressions that the above Torch dialect program is imported from:
```py
ExportedProgram:                                                                                                                                                                                             
    class GraphModule(torch.nn.Module):                                                                                                                                                                      
        def forward(self, x: "f32[s0, s1, 3]", y: "f32[s0, s3, 3]"):                                                                                                                                         
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:31 in forward, code: a = torch.tanh(x)                                        
            tanh: "f32[s0, s1, 3]" = torch.ops.aten.tanh.default(x);  x = None                                                                                                                               
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:32 in forward, code: b = torch.sigmoid(y)                                     
            sigmoid: "f32[s0, s3, 3]" = torch.ops.aten.sigmoid.default(y);  y = None                                                                                                                         
                                                                                                                                                                                                             
            # File: /home/sambhav.jain/workspaces/cruise/src/3p/torch-mlir/test/python/fx_importer/symbolic_shape_expr_test.py:33 in forward, code: return torch.cat((a, a, b), dim=1)                       
            cat: "f32[s0, 2*s1 + s3, 3]" = torch.ops.aten.cat.default([tanh, tanh, sigmoid], 1);  tanh = sigmoid = None                                                                                      
            return (cat,)                                                                                                                                                                                    
                                                                                                                                                                                                             
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])                                               
Range constraints: {s0: ValueRanges(lower=5, upper=10, is_bool=False), s1: ValueRanges(lower=0, upper=100, is_bool=False), s3: ValueRanges(lower=0, upper=50, is_bool=False)} 
```

Huge credit to @stellaraccident for the inputs that helped evaluate the
various design options and arrive at the representation of choice.


- [x] Op definitions for symbolic_int and bind_symbolic_shape ops
- [x] fx_importer updates to import range constraints + create
symbolic_int ops
- [x] fx_importer changes for AffineMapAttr building + adding
bind_symbolic_shape ops
- [x] custom printer/parser for inlined AffineMap expressions in mlir
assembly
- [x] Dialect lit test
- [x] fx_importer python lit tests
- [ ] Cleanup pass to remove these ops (can add in a follow-on)
2024-06-07 04:04:03 -07:00
Angel Zhang 2e194e13d6
[Torch] Fix bugs for `Torch::AtenOneHotOp` (#3350)
This PR fixes the bugs for `Torch::AtenOneHotOp` by:

1) Using `Torch::kUnknownSize` as the default value for `numClasses` in
   the pattern matching stage in `DecomposeAtenOneHotOp`
2) Adding `AtenIntScalarOp` to the patterns in `TorchToArith`
3) Handling both `int` and `float` types for `off` and `on` values in
`TorchOnnxToTorch` conversion

It also includes:

1) A new test in `TorchToArith/basic.mlir`, for `torch.aten.Int.Scalar`,
and
2) A new test in `decompose-complex-ops.mlir`, for `torch.aten.one_hot`

**Dependencies**

This PR is dependent on #3334.
2024-05-22 17:19:08 +00:00
Aaron St George ba32b9cee7
Don't fold `aten.clone` if result isn't same type as input (#3347)
Similar to https://github.com/llvm/torch-mlir/pull/2824, we were seeing
some assertion failures after the addition checks around folders were
tightened up in LLVM: https://github.com/llvm/llvm-project/pull/75887 .
This PR essentially moves the logic that used to be applied at the LLVM
level into the folder, which seems to be the suggested fix.
2024-05-16 00:07:45 +08:00
zjgarvey 75d1d72059
Generalize Operand Quantization in FuseQuantizeOps (#3327)
This change enables more customization with operand quantization, and
generalizes the patterns QuantizeOperands and QuantizeTransposeOperands
to QuantizeOperandsPastCommutingOps.

This allows for passing quantization through operations which are
functionally unaffected by quantization, such as view-like ops. The
purpose of this change is to address a myriad of quantization issues
seen in quantized onnx models that have some reshape-like operations
sandwiched in between a dequant and something like a matmul (whose other
operand is immediately quantizable).
2024-05-12 20:49:59 -07:00
Vivek Khandelwal e60160d793
Revert "Decompose AtenNonzeroOp" (#3289)
Reverts llvm/torch-mlir#3281
2024-05-06 09:52:04 -07:00
Xida Ren (Cedar) 1af00e6040
Decompose AtenNonzeroOp (#3281)
This fixes some onnx lit tests not lowering to linalg in
https://github.com/nod-ai/SHARK-Turbine/issues/450
2024-05-05 21:59:25 +08:00
Ze Zhang 11cd7cd9e7
Folder and Canonicalizer for PrimsConvertElementTypeOp and AtenMaxPool2dWithIndicesOp (#3272)
While playing with TorchDynamo on ResNet18. I notice following issues:

- `prims.convert_element_type` can’t be canonicalized even if the input
and the output share the same type

- `aten.max_pool2d_with_indices` is always used instead of
`aten.max_pool2d`, even if the second returned output (indices) has no
user

This PR fixes above issues by adding a folder to the
PrimsConvertElementTypeOp and a canonicalizer to the
AtenMaxPool2dWithIndicesOp


Lit test:

`cmake --build build --target check-torch-mlir-all`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-05-02 00:03:41 -07:00
Yuanqiang Liu aed2cf3351
[Torch] emit aten.__contains__.str_list and add folder (#3249) 2024-04-29 10:51:17 +08:00
Stella Laurenzo 5d4b803914 [NFC reformat] Run pre-commit on all files and format misc.
This is part 1 of ~3, formatting all miscellaneous text files and CPP files matched by a first run of pre-commit. These tend to be low change-traffic and are likely not disruptive.

Subsequent patches will format Python files and remaining CPP files.
2024-04-27 14:08:09 -07:00
Yuanqiang Liu f173a06fa7
[Torch] emit aten.ne.str and add folder (#3242) 2024-04-28 00:58:50 +08:00
Yuanqiang Liu 634a796933
[Torch] fold aten.log (#3223) 2024-04-26 10:10:02 +08:00
Yuanqiang Liu b0ba3def93
[Torch] support AtenScalarImplicitOp canonicalize with float (#3231) 2024-04-26 02:36:13 +08:00
Yuanqiang Liu fab2696489
[Torch] support aten.trunc (#3219)
decompose `trunc(x)` to `sign(x) * floor(abs(x))`
2024-04-24 14:32:33 +08:00
Xinyu Yang 790a697245
[Torch] Add folder for AtenIntOp, AtenFloatOp (#3189)
See unit test below:
```
// CHECK-LABEL:   func.func @torch.aten.tensor.float(
// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %float1.000000e01 = torch.constant.float 1.000000e+01
  %67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32>
  return %67 : !torch.vtensor<[],f32>
}

// CHECK-LABEL:   func.func @torch.aten.tensor.int(
// CHECK-NEXT: torch.vtensor.literal(dense<45> : tensor<si32>) : !torch.vtensor<[],si32>
func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> {
  %none = torch.constant.none
  %false = torch.constant.bool false 
  %int45 = torch.constant.int 45
  %67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>
  return %67 : !torch.vtensor<[],si32>
}

```
2024-04-19 22:17:06 +08:00
Xinyu Yang 6524838bcb
[Torch] Add general AdaptiveAvgPool2dOp decompose support (#3111)
Previously, it could only handle the situations where outputsize == (1,
1) or outputsize == (input_H, input_W). Now it supports all situations
where input_H % output_H== 0 && input_W % output_W == 0
2024-04-11 17:02:59 +08:00
Yuanqiang Liu 0a581a97a7
[Torch Dialect] enhance aten.int.tensor's canonicalize (#3058)
support fold with literal vtensor.  
change it to canonicalize because this pattern will create new op.
2024-03-27 09:51:58 +08:00
Rob Suderman 14b548f968
[torch] Improve shape inference for `torch-to-linalg` path for reshapes (#3055)
Reshaping tensors depend on directly matching individual dimensions to
their corresponding dim in the `torch.view` reshape dimensions. This
involves decoupling dynamic dimensions from their static counterparts
and support cleanup / canonicalization.
2024-03-26 12:41:40 -07:00
Rob Suderman 0723584936
[torch] Add folder for torch.aten.*.Scalar comparisons (#3000)
This folds small version of the tensor-scalar comparison operators as
they are commonly used for shape computations. This includes le, lt, ge,
gt, eq, and ne.
2024-03-08 13:44:00 -08:00
Rob Suderman a86e89ecb5
[torch] Additional folders for shape computations (#2972)
A handful of operations are commonly used in shape calculations (slice,
concat, broadcast). Added these additional folders to better propagate
simple shape computations.
2024-03-04 11:46:49 -08:00
Rob Suderman 61f0a5facf
[torch] Add an `aten.cat` length-0 canonicalization (#2966)
If an input is length-0 along the dimension of canonicalization we can
remove the tensor from the list
2024-03-01 21:41:12 -08:00
Rob Suderman 6f3d62ab04
[torch] Fix folders and `cat` and `view` torch lowerings (#2963)
A bunch of small fixes are interlinked and trigger crashes if not
addressed as a group. This includes:

- aten view when expand from a rank-0 tensor
- slice folder with negative indices
- `aten._shape_as_tensor` folder on a rank-0 tensor
- `aten.cat` of a tensor with a length-0 tensor
2024-02-28 12:04:52 -08:00
Vivek Khandelwal d81747eadb
[MLIR][TORCH] Extend support for OnnxToLinalg lowering for Dropout and Div op (#2938)
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/451,
https://github.com/nod-ai/SHARK-Turbine/issues/452
2024-02-27 11:02:05 +05:30
Rob Suderman 135c81a416
[torch] Add folder for `prim.NumToTensor.Scalar` (#2921)
Useful for `slice` lowerings that depend on tensors made form scalars.
2024-02-19 11:55:54 -08:00
Rob Suderman e80054a3cc
[torch] Folders for `torch.aten.*.tensor` operators [add, sub, mul] (#2878)
Simple folder for limited size aten tensor operations. This is primarily
useful for shape computation folding as they unfortunately can use
`aten` operators. Add, sub, mul are common examples of these folders.
2024-02-19 10:28:23 -08:00
Rob Suderman e9cdd6cbc5
[torch] Fix tm_tensor.attention for end-to-end (#2907)
Some operations include a backend matcher for specialized operations. We
map these back to generics so they appropriately match to the high
performance versions. This is done for the attention operation.
2024-02-13 21:18:01 -08:00
Rob Suderman c0f139be0f
[torch] Add `torch.aten.eq.Tensor` comparison folder (#2889)
Added a folded for a equals operator. This allows an equivalent
comparison folder, primarily for when shape computations occur small
size tensor.
2024-02-09 15:02:20 -08:00
Rob Suderman 7d33ba69ac
[torch] Folder for torch.aten.select.int for splat cases (#2890)
If the input or result is a splat value we can just constant fold the
result. This is common for shape computations and can help with shape
inference.
2024-02-09 14:02:54 -08:00
Dave Liddell 23647ab2d1
[torhc] aten.index_select folder (#2871)
Folds aten::index_select ops under the following conditions:

1. If the input and output are the same shape, the indexing operation is
a NOP, so just return the input.
2. If the input has shape <1x1x...xNx...x1> (all 1's except for one
dim), and the output shape is <1x1x...x1> (all 1's), then there is a
single index, so extract the single element value and return a tensor
with that value.

---------

Co-authored-by: Dave Liddell <dliddell@xilinx.com>
2024-02-07 16:17:15 -08:00
Xida Ren (Cedar) fc04bc7ee9
[torch] AtenSliceOp folder that produces splat results (#2869)
Includes `slice` folder and lit tests

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-02-07 19:00:46 +00:00
Xida Ren (Cedar) cc06391630
AtenSortOp Folder (#2864)
A chunk off

https://github.com/llvm/torch-mlir/pull/2856
https://github.com/llvm/torch-mlir/pull/2860

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
2024-02-06 21:12:12 +00:00
Dave Liddell 1cb14f6879
Rob's atenTensor folder (#2867)
If a tensor is initialized by a list with a single constant integer,
this folder turns it into a torch.vtensor.literal

---------

Co-authored-by: Dave Liddell <dliddell@xilinx.com>
2024-02-05 17:10:42 -08:00
Rob Suderman e3faef5224
[onnx] Convert `onnx.QLinearConv` to `torch` (#2851)
Leaning on the QDQ functionality in torch we can support the QLinearConv
operation by piggybacking through `torch.Convolution`. This includes
some changes such as allowing the `onnx` rewriter to run recursively.
Doing so allows `QLinearConv` to decopmose to `onnx.Convolution` which
is then lowered to `torch`.
2024-02-05 16:09:41 -08:00
Xida Ren (Cedar) 24b8c8672a
[torch] Add folders for `torch.fill`, `torch.ones`, `torch.zeros` and `aten.getItem` (#2849)
So that the CumSum Op in OPT can get the constant that it requires to be lowered to TMTensor

---------

Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-02-02 10:46:33 -08:00
Rob Suderman 25a5a22cbd
[torch] Support `torch.convolution` quantized lowering to `linalg` (#2811)
Linalg has quantized specific operations. We can lower to these
operations when there is a known zeropoint and scale operations. This
allows the `convolution` to occur with lower bitwidth's, improving the
overall performance.
2024-01-30 13:46:47 -08:00
Aaron St George 4c557847bd
Don't fold `aten.detach` if result isn't same type as input. (#2824)
We were seeing some assertion failures after some checks around folders
were tightened up in LLVM:
https://github.com/llvm/llvm-project/pull/75887 . This PR essentially
moves the logic that used to be applied at the LLVM level into the
folder, which seems to be the suggested fix.

I'm not sure if the IR that caused issues for us _should_ be valid?
```
%1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor
```
A better fix might be to create a verifier ensuring the result of
`aten.detach` has the same type as its operand.

---------

Co-authored-by: aaron-stgeorge <aaron.stgeorge@getcruise.com>
2024-01-30 09:45:51 -08:00
Aart Bik fe836ceebf
[torch-mlir][test] cleanup trailing whitespace in mlir files (#2806) 2024-01-25 14:24:13 -08:00
Aart Bik e824fbc65c
[torch-mlir][torch] add encoding field to torch type (#2799)
This adds an encoding field to the torch type, using the interfaces for
printing, parsing, and verification. Note that although this change
prepares adding sparsity to the torch type (as illustrated by the round
trip and invalid tests), nothing in this change depends on the actual
contents of the encoding field!
2024-01-25 10:04:04 -08:00
Rob Suderman f6f890520b
[torch][quant] Quantized `torch.mm` for linalg with end-to-end test (#2750)
This includes custom op matching for decomposed operations and fusing
dequantization into dense operations. As a validation we compare
to the dequant+mm torch implementation.
2024-01-24 14:02:50 -08:00
Han-Chung Wang 10acea71be
Bump LLVM to llvm/llvm-project@0cb024b (#2753)
- Add fixes for
af78e5daf0
- Add fixes for
bb6d5c2200
2024-01-15 07:12:12 -08:00
Zhekun(Josh) Zhang d67afa9e95
[Torch] Add fold rule for AtenMaskedFillTensorOp to AtenMaskedFillScalarOp (#2543) 2023-11-21 13:26:17 +08:00
Stella Laurenzo 5eae0adff1
Breakup python pytorch deps (#2582)
This lifts the core of the jit_ir_importer and ltc out of the pt1
project, making them peers to it. As a side-effect of this layering, now
the "MLIR bits" (dialects, etc) are not commingled with the various
parts of the pt1 project, allowing pt1 and ltc to overlay cleanly onto a
more fundamental "just MLIR" Python core. Prior to this, the Python
namespace was polluted to the point that this could not happen.

That "just MLIR" Python core will be introduced in a followup, which
will create the space to upstream the FX and ONNX pure Python importers.

This primary non-NFC change to the API is:

* `torch_mlir.dialects.torch.importer.jit_ir` ->
`torch_mlir.jit_ir_importer`.

The rest is source code layering so that we can make the pt1 project
optional without losing the other features.

Progress on #2546.
2023-11-19 12:10:19 -08:00
James Newling dad1f012f6
Add verification for torch permute op (#2551)
- adds support for an optional verifier to the generated torch op
tablegen (GeneratedTorchOps.td)
- uses the above to add a verifier for the torch permute op. 

Motivation: I hit an unclear error from linalg while developing a
decomposition pass for pixel_shuffle. The error would have been clearer
if the problem had been detected earlier in the invalid aten.permute op.

Testing: new tests added. To run added tests, from the base directory
run

```
 ./build/bin/llvm-lit  test/Dialect/Torch/invalid.mlir
 ```
2023-11-15 11:47:54 -08:00
Yuanqiang Liu 3ab790c50a
[Torch Dialect] add canonicalize for aten.numel (#2562) 2023-11-11 12:16:53 +08:00
Zhekun(Josh) Zhang 88d4c475d3
[Torch] Fix mixP case for non value semantic ops (#2540)
NonValueSemantic Ops like Add_, div_, etc. expect result DType to be the
same as the first input. However, current implementation would result in
wrong result type for case like:

```python
a = torch.randn(3, 3).half() # float16
b = torch.randn(3, 3) # float32
a += b # i.e. torch.ops.aten.add_(a, b)
```
torch expects `a` to be float16, but dtype refinement would infer
float32 type, since it's replaced by `aten.add`.
2023-11-02 12:40:08 +08:00