Commit Graph

236 Commits (45c85c3b34629ee75d6b3a6fd9447894ce7a8ce3)

Author SHA1 Message Date
Xinyu Yang 790a697245
[Torch] Add folder for AtenIntOp, AtenFloatOp (#3189)
See unit test below:
```
// CHECK-LABEL:   func.func @torch.aten.tensor.float(
// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %float1.000000e01 = torch.constant.float 1.000000e+01
  %67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32>
  return %67 : !torch.vtensor<[],f32>
}

// CHECK-LABEL:   func.func @torch.aten.tensor.int(
// CHECK-NEXT: torch.vtensor.literal(dense<45> : tensor<si32>) : !torch.vtensor<[],si32>
func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> {
  %none = torch.constant.none
  %false = torch.constant.bool false 
  %int45 = torch.constant.int 45
  %67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>
  return %67 : !torch.vtensor<[],si32>
}

```
2024-04-19 22:17:06 +08:00
penguin_wwy 5a98c72c7f
[StableHLO] Fix aten.clamp.Tensor in FxImporter2StableHLO (#3190)
The FX importer will pass static shapes to the Torch dialect, so it
needs to generate a StableHLO that satisfies shape inference.
2024-04-19 17:08:29 +08:00
penguin_wwy 0a6073414d
[FxImporter] Add fx importer to stablehlo e2e test config (#3183) 2024-04-18 21:29:17 -07:00
penguin_wwy 6c4f7deebb
[stablehlo] add aten.clamp.Tensor op conversion support (#3185) 2024-04-19 10:55:27 +08:00
Rob Suderman be742a937d
[onnx] Update the failure triage for onnx (#3186)
Reclassifying what the source of failures are for various bugs so we can
reprioritize what failures are common.
2024-04-18 14:58:13 -07:00
Rob Suderman 0e77de996a
[torch] Add support for `torch.view` with dynamic shapes (#3164)
We can map to `tensor.reshape` for handling multiple output dynamic
shapes. Later we can perform a more complex analysis for indentifying
expand/collapse cases from the tensor.reshape.

Initially we planned to handle this identification at the `torch` level
however it will be easier to handle once converted to core
mlir-dialects.
2024-04-18 11:47:19 -07:00
Xinyu Yang d4313eed4a
[Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075)
Decomposition RepeatInterleaveSelfInt with following ops:
```python

def my_repeat_interleave(input, repeats, dim=None):
    if dim is None:
        # Flatten the input and then repeat
        return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
    else:
        # Calculate the shape after repeat
        expanded_shape = list(input.shape)
        expanded_shape[dim] *= repeats
        # Repeat the tensor along the specified dimension
        repeat_shape = [1] * (input.dim() + 1)
        repeat_shape[dim + 1] = repeats
        input = input.unsqueeze(-1)

        # Tile and then reshape
        tiled = torch.tile(input, repeat_shape)
        # Rearrange and reshape
        repeated = tiled.reshape(*expanded_shape)
    return repeated

```

I passed the tests of stablehlo and linalg. When testing onnx, strange
things happened.
In torch-mlir's CI **torch_nightly** and my own
environment(torch==2.4.0.dev20240318+cpu), it can **pass the pass**.
In torch-mlir's CI  **torch_stable**, it **failed**.
The test case is `RepeatInterleaveSelfIntNoDimModule_basic`, the result
shape should be [120].
```python
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float32, True),
    ])
    def forward(self, x):
        return x.repeat_interleave(2)


@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
  Unexpected outcome summary: (onnx)
  
  ****** Failed tests - 1 tests
      FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
          @ trace item #0 - call to "forward"
          @ output of call to "forward"
          ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```

@rsuderman 
Would you please help me check what's wrong with my PR? Thanks a lot.
2024-04-18 06:27:51 +08:00
penguin_wwy 3aa81f78d8
[FxImporter] Replace local_scalar_dense in fx_importer (#3180) 2024-04-17 22:45:47 +08:00
Xinyu Yang d2ba956e69
[Torch] Support Aten_CastLongOp. (#3160)
By canonicalize Aten_CastLongOp into AtenToDtypeOp
2024-04-17 21:58:32 +08:00
penguin_wwy e4b11a0ab4
[FxImporter] Fix fx importer test config and clean xfail set (#3176) 2024-04-16 22:36:07 -07:00
penguin_wwy 398aeeec87
[FxImporter] Fix kwarg operands in fx importer (#3166)
Remove the `kwarg_only` limitation, for example
```
torch.add(x, 3.0, alpha=2)
```
compiled to
```
%0 = torch.aten.add.Scalar %arg0, %float3.000000e00, %int1
```
fix to
```
%0 = torch.aten.add.Scalar %arg0, %float3.000000e00, %int2
```
2024-04-16 13:17:05 -07:00
zjgarvey 7a1ad0d7c0
[TorchToLinalg] Adds Support for Remaining Quantized Matmul Cases (#3167)
The new cases added for quantized matmuls are:

1. vec-vec
2. vec-mat
3. mat-vec

each of which are now lowered to expand(s), quantized_matmul, and
collapse.
2024-04-16 09:28:28 -07:00
Vinayak Dev a0232e9ebd
[MLIR][TORCH] Add OnnxToTorch lowering for ReduceL1 Op (#3146)
Adds OnnxToTorch Lowering for the ReduceL1 op.
2024-04-16 12:24:46 +05:30
penguin_wwy 10b6062d41
[CI] Enable the tests for fx_importer in the CI (#3168)
Replace the torchdynamo e2e with the fx_importer e2e
2024-04-15 21:20:23 -07:00
Xinyu Yang ae4724763a
[Stablehlo] Enhance broadcast pattern in matmul Ops (#3161)
To pass test "MatmulStaticBroadcast_basic" in stablehlo:
```python
class MatmulStaticBroadcast(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 1, 6, 7], torch.float32, True),
        ([8, 1, 5, 7, 6], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: MatmulStaticBroadcast())
def MatmulStaticBroadcast_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 1, 6, 7), tu.rand(8, 1, 5, 7, 6))
```
2024-04-16 10:10:36 +08:00
zjgarvey 5e564b5864
Adds Some Quantization Support for AtenMatmulOp (#3147)
1. onnx.MatMulInteger now converts to aten.matmul instead of aten.mm
2. aten.matmul, for ranks >=2, now allows quantized inputs and will
lower to linalg::quantized_matmul or linalg::quantized_batch_matmul.
3. added AtenMatmulOp to the FuseQuantizeOps rewrite patters
QuantizeOperands, QuantizeTransposedOperands, and QuantizeAccumulator
4. added several tests, including some to test AtenMmOp with varying
quantization signed-ness.
5. a quantized matmul mat-vec test is added to verify the failure to
lower to linalg; cleaned of out-of-date code related to common
torch-mlir lowering xfails.
6. in debugging a real model with quantized matmuls, I found a bug on
the scalarize-shapes pass which resulted from the aten.full op folder
returning an incompatible result type. This is fixed by the small change
here to
[lib/Dialect/Torch/IR/TorchOps.cpp](https://github.com/llvm/torch-mlir/compare/main...zjgarvey:torch-mlir:MatMulIntegerFix?expand=1#diff-dc8ed165c207918e606490eee3984b1ad51d7034e6aac36fc046bf47f6f03f4f).
2024-04-15 16:06:47 -07:00
IanWood1 5708ee7ec9
Added 2 Ops: Floor divide scalar and Floor divide scalar mode (#3156)
- Added linalg lowering for `AtenFloorDivideScalarOp`
  - Needed `AtenDivScalarModeOp` for the decomp.
- Added linalg lowering for `AtenDivScalarModeOp`
- Moved linalg payload logic to `createDivModePayload()` since the logic
was nearly identical for both `AtenDivScalarModeOp` and
`AtenDivTensorModeOp`. Just a template function
 -  Added `AtenDivScalarModeOp` lowering for stablehlo
 

Pytorch's
[`torch.floor_divide()`](https://pytorch.org/docs/stable/generated/torch.floor_divide.html)
in a previous version (for a reason unknown to me) preformed a
truncation instead of "floor". The already implemented op
`AtenFloorDivideTensorOp` was done before this change. However, this
wasn't caught because our testcases only tested positive floor division.
I changed this to floor as well as adding a few test cases.
2024-04-15 13:45:10 -07:00
penguin_wwy 45eaeaaf36
[FxImporter] Add FxImporter config in e2e-test (#3151) 2024-04-12 16:07:56 -07:00
Xinan Jiang(姜曦楠) 71d90788d3
[MLIR][TORCH] Support parallel dimemsions expand/collapse (#3051)
This PR support `aten.view` with unique unknown dimension both in input
shape and output shape while the pass convert-torch-to-linalg that
lowing `aten.view` to `tensor.collapse_shape` or `tensor.expand_shape`.

Below is an example
```
func.func @test_reshape(%arg0: !torch.vtensor<[1,?,50,16],f32>) -> !torch.vtensor<[1,?,16],f32> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %int1 = torch.constant.int 1
  %int-1 = torch.constant.int -1
  %int16 = torch.constant.int 16
  %0 = torch.prim.ListConstruct %int1, %int-1, %int16 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,?,50,16],f32>, !torch.list<int> -> !torch.vtensor<[1,?,16],f32>
  return %1 : !torch.vtensor<[1,?,16],f32>
}
```
2024-04-11 10:43:03 -07:00
Rob Suderman a1fe307a76
[torch] Support implicit batch for index_put (#3128)
If there is only a single value scattered there can be an implicit batch
dimension. This includes a check for the implicit batch dimension when
reshaping the update tensor. It includes an e2e test to verify
correctness.
2024-04-11 10:18:03 -07:00
Xinyu Yang 308c45e61a
[Torch] Fix PrimListUnpackOp::getCanonicalizationPatterns (#3140)
Fix the case PrimListUnpackOp's result num is not equal to PrimList
length.
See the following example:
```python
    def forward(self, x):
        if len(x.shape) == 5:
            b0, t, c0, h0, w0 = x.shape
            b, c, h, w = torch.mul(b0, t), c0, h0, w0
        else:
            b1, c1, h1, w1 = x.shape
            b, c, h, w = b1, c1, h1, w1
        res = torch.reshape(x, [b, c, h, w])
        return res
```
Without this fix, the following error message will occur:
```
/root/torch-mlir/externals/llvm-project/mlir/lib/IR/PatternMatch.cpp:118: virtual void mlir::RewriterBase::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed.
```
2024-04-11 19:48:49 +08:00
Xinyu Yang 6524838bcb
[Torch] Add general AdaptiveAvgPool2dOp decompose support (#3111)
Previously, it could only handle the situations where outputsize == (1,
1) or outputsize == (input_H, input_W). Now it supports all situations
where input_H % output_H== 0 && input_W % output_W == 0
2024-04-11 17:02:59 +08:00
zjgarvey aa5e150313
Adds Some uint8 Quantization Fixes (#3122)
1. Changes the linalg lowering for dequantization ops to always sign
cast to float to prevent misrepresenting uint32 overflow on subtraction
with zero point.
2. Adds a basic quantized model test which only quantizes and
dequantizes and now passes with these changes in linalg and onnx
configs.
3. Changes the aten.mm lowering to allow mismatched quantized types. 
4. If a quantized matmul arg is uint8, we shift by 128 to faithfully
represent the quantization as a signed i8 quantization. This worked fine
in the AtenMmOp lowering, but I'd be happy to move it to a rewrite in
FuseQuantizedOps.cpp instead if that seems more appropriate.

With the changes 3 and 4, the QuantizedMLP_basic and
QuantizedSingleLayer_basic e2e tests now passes with the onnx config.
2024-04-10 12:36:58 -07:00
Xinyu Yang 42a16fa912
[Torch] Support Aten_CastFloatOp. (#3115)
By canonicalize Aten_CastFloatOp into AtenToDtypeOp
2024-04-09 11:06:53 +08:00
Xida Ren (Cedar) dd967eb199
[ONNX] Support onnx.LSTM (#2969)
This PR only performs a lit test. In lieu of an e2e test, https://github.com/nod-ai/SHARK-TestSuite/pull/142 makede sure that the lowering works & the numbers check out.

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-04-08 12:23:33 -07:00
Vivek Khandelwal 1d6e4c3d77
[MLIR][TORCH] Add OnnxToTorch lowering for Einsum op (#3117)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-04-08 22:38:01 +05:30
Xinyu Yang 84c24e5771
[Torch] Support Aten__And__ScalarOp (#3114) 2024-04-08 20:24:17 +08:00
Yuanqiang Liu 2c56ef9252
[Torch Dialect] canonicalize aten.sign to aten.sgn (#3112)
* `aten.sign` is a sub-set of `aten.sgn` (`aten.sgn` support complex
type).
2024-04-08 20:05:42 +08:00
Yuanqiang Liu 498ab997cd
[Stablehlo] lowering aten.log1p to stablehlo.log_plus_one (#3110) 2024-04-07 17:01:58 +08:00
Yuanqiang Liu 0a00f38a7e
[Stablehlo] add stablehlo-aggressive-simplification in e2e test (#3109)
* so that more stablehlo e2e testcases would pass.
2024-04-07 10:48:11 +08:00
Vivek Khandelwal af54d27820
[MLIR][TORCH] Fix Onnx.TopK lowering (#3103)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-04-03 22:12:48 +05:30
Vivek Khandelwal ce7d4f1660
[MLIR][TORCH] Fix Onnx.ReduceSum lowering for failing e2e tests (#3095)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-04-03 09:57:19 +05:30
Rob Suderman f97cd4893f
[torch] Improve shape inference for dynamic shapes (#3091)
Shapes can be processed as tensors to represent the set of dimensions.
As reshapes take a list of scalars this can result in a single dynamic
dimension blocking the adjacent static dimensions.

This pass attempts to de-couple tensor computations related to shapes
and propagate values to better support lowering scalar tensor
computations.
2024-04-02 16:19:57 -07:00
Yuanqiang Liu 6cbb2f7ae0
[Stablehlo] add stablehlo-canonicalize-dynamism when lowering (#3097)
so that many stablehlo e2e testcases could pass
2024-04-02 22:47:24 +08:00
Xinyu Yang ac1cd3d78a
[Torch] Support AtenDivTensorModeOp with static int input for linalg and stablehlo backend (#3088) 2024-04-02 17:28:53 +08:00
ptrifunovic98 1c8c47d483
Add complex support for aten.norm and similar operations (#3052)
Add support for complex-type input tensors for norm, vector norm, and
Frobenius norm operations.
2024-04-02 14:03:30 +05:30
Rob Suderman 0f5d5e9f4e
[stablehlo] Fix test stablehlo e2e test suite (#3093)
There is an issue with stablehlo's linalg compilation. Canonicalization
appears to cleanup the issues until we can determine what in
mlir/stablehlo is the source of the issue.
2024-04-02 12:40:00 +08:00
Rob Suderman ec4cb8be44
Bump LLVM to llvm/llvm-project@0030fc4ac7 (#3079)
Co-authored-by: Peiming Liu <peiming@google.com>
2024-04-01 16:34:59 -07:00
zjgarvey 532d297c46
[ONNX] Preliminary Work Towards Supporting QuantizedMLP_basic onnx e2e test (#3089)
See the related issues here:
[SHARK-Turbine#556](https://github.com/nod-ai/SHARK-Turbine/issues/556)

1. Adds uint8 casting to onnx.Cast op
2. Fixes an issue with onnx.DequantizeLinear when the scale comes with
shape [1].
3. Adds support for unsigned types in an AtenItemOp folder
4. Adds a simpler quantized model for easier debugging
5. Adds a fusion pass to convert [quant -> dequant -> transpose -> mm]
patterns to [transpose -> quant -> mm].
6. Moved some xfails that are still not passing, but for different
reasons than onnx.cast failures.
2024-04-01 16:21:05 -07:00
Vivek Khandelwal 6844c84702
[MLIR][Torch] Fix OnnxToLinalg lowering for AvgPool op (#3076)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-04-01 22:14:14 +05:30
Xinyu Yang 40008b025a
[Torch] Support prelu decomposition (#3069) 2024-03-29 08:05:00 +08:00
zjgarvey c19fc9ba47
[ONNX] Fixes Issue with Dynamic Dims in GlobalAveragePool -> Torch Conversion (#3053)
Two e2e tests (AdaptiveAveragePool1/2dUnitOutputSizeDynamic) were
failing due to numerics. This was as a result of passing -1 as the
kernel size in the lowering for the corresponding onnx op
GlobalAveragePool.
2024-03-28 09:43:09 -07:00
Xinyu Yang e6e7689a24
[Torch] support decompose aten.einsum with ellipsis slicing (#3056) 2024-03-27 12:42:10 -07:00
Rob Suderman 14b548f968
[torch] Improve shape inference for `torch-to-linalg` path for reshapes (#3055)
Reshaping tensors depend on directly matching individual dimensions to
their corresponding dim in the `torch.view` reshape dimensions. This
involves decoupling dynamic dimensions from their static counterparts
and support cleanup / canonicalization.
2024-03-26 12:41:40 -07:00
Vivek Khandelwal 9ae33e482e
[MLIR][TORCH] Add OnnxToTorch lowering for ops (#3049)
This commit adds the OnnxToTorch lowering for the Mish, Softplus,
HardSwish, Trilu, ThresholdedRelu op

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-03-25 20:29:07 +05:30
schnkmwt 1fcbfa87ec
Implement linalg lowering of diag_embed torch op (#2885)
This PR adds lowering of diag_embed to linalg dilect.
Tracked in https://github.com/nod-ai/SHARK-Turbine/issues/288

---------

Co-authored-by: sachink <sachink@xilinx.com>
2024-03-22 16:32:50 -07:00
zjgarvey 99b3a5f117
Converts all Adaptive Pooling Ops to Linalg (#2808)
The previous conversions for AtenAdaptiveAvgPool1dOp and
AtenAdaptiveMaxPool2dOp are refactored into a general templated
conversion that works for all of the AtenAdaptive...PoolNdOp's.

New support is added for the following ops:

1. AtenAdaptiveMaxPool1d
2. AtenAdaptiveMaxPool3d
3. AtenAdaptiveAvgPool3d

Support is also provided for passing inputs without batch dimensions.
For example, applying adaptive_avg_pool2d to an input tensor of rank 3.

After [pytorch #118162](https://github.com/pytorch/pytorch/pull/118162)
gets down to torch-mlir, I'll add a test for AdaptiveMaxPool1d with
return_indices (which will pass with that upstream fix).

---------

Co-authored-by: James Newling <james.newling@gmail.com>
2024-03-22 11:05:20 -07:00
zjgarvey 6aa481c204
[ONNX] LogSoftmax to Torch (#3024)
This PR adds support for onnx.LogSoftmax both for old versions (<13,
with axis >=0), and new versions (13).
2024-03-22 11:01:39 -07:00
Rob Suderman 3a56714bff
[torch] Fix clamp ranges on quantize_per_tensor on unsigned (#3018)
SExtValue was used for `int` and `uint` clamp values. This caused the
result to always be outputed as `zero`.
2024-03-20 13:37:47 -07:00
Xida Ren (Cedar) cb5cb506df
Fix SCF Forloop fails to convert to linalg when a tensor argument is supplied to the loop block (#3040)
Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-03-20 11:04:02 -07:00
zjgarvey 6ff71b40c8
[ONNX] onnx.DynamicQuantizeLinear to Torch (#3009)
This adds support for converting DynamicQuantizeLinear from torch-onnx
to torch.

I could not get an e2e test to pass, since there seems to be some issues
with uint8 casting somewhere lower in the pipeline. For example
compiling with IREE for llvm-cpu, I would get either the correct zero
point (if zp < 128) or the correct zero-point minus 256 (if zp >= 128).
The output tensor seems to always return a tensor of zeros, which also
occurs when running uint8 examples through QuantizeLinear.

Edit: the first problem can be resolved by casting the output back to
uint8 on output, the second problem is resolved with PR #3018
2024-03-20 10:58:25 -07:00
Abhishek-TyRnT df02692726
Dynamic size support for flatten (#3005)
Added support for dynamic shapes in `flattenusingints` op in tosa
dialect. Due to this some Argmax tests pass
This PR fixes this issue https://github.com/llvm/torch-mlir/issues/3004

The following tests pass after this PR
 ```
1. "ArgmaxIntModule_basic"
2. "ArgmaxIntModule_multiple_maxs"
3. "ArgmaxModule_basic"
```
2024-03-19 15:19:29 -07:00
Pavani Chowdary c51e2130f2
[onnx] support for lowering mod op from onnx to torch (#2859)
nod-ai/Shark-Turbine#267

---------

Authored-by: boddu.pavani@research.iiit.ac.in
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-03-18 17:54:37 +05:30
Xinan Jiang(姜曦楠) d8a52e82c2
[onnx] Fix onnx.cast cases between int32 and int64 (#2982)
2 modifications:
1. torch.int64 is enum 4 in TORCH_DTYPE_TO_INT
2. add int32 support
2024-03-15 17:14:09 +00:00
Yuanqiang Liu 4282eb9e76
[Torch Dialect] support aten.fake_quantize_per_tensor_affine (#3014) 2024-03-15 08:53:29 +08:00
Yuanqiang Liu 870e63bc3c
[Torch Dialect] support decomposition of aten.linspace (#3006) 2024-03-14 08:28:33 +08:00
ptrifunovic98 524ff99216
Implement lowering of torch.aten.linalg_cross (#2986)
Closes
[nod-ai/SHARK-Turbine#497](https://github.com/nod-ai/SHARK-Turbine/issues/497)
2024-03-13 12:17:22 -07:00
Yuanqiang Liu ad6159c7cb
[Stablehlo] lowering aten.round to stablehlo.round_nearest_even (#3011) 2024-03-12 08:58:20 +08:00
Devjiu 4b1e87ce67
[TorchDynamo] Enable Elemtwise ops for Scalar arg (#2744)
This commit provides dummy solution to support elmentwise operations
(mul, add) with scalar argument. ( op(Tensor, Scalar) )

It replaces `torch.aten.add.Tensor` with `torch.aten.add.Scalar`.
```
Unexpected outcome summary: (torchdynamo)

****** Unexpectedly Passed tests - 22 tests
    XPASS - "AddCDivModule_basic"
    XPASS - "BatchNorm1DModule_basic"
    XPASS - "BatchNorm1DStaticShapeModule_basic"
    XPASS - "BatchNorm1DWith2DInputModule_basic"
    XPASS - "BatchNorm2DModule_basic"
    XPASS - "BatchNorm3DModule_basic"
    XPASS - "ElementwiseAddScalarInt64Module_basic"
    XPASS - "ElementwiseAddScalarIntModule_basic"
    XPASS - "ElementwiseMulScalarModule_basic"
    XPASS - "ElementwiseMulScalarModule_float"
    XPASS - "ElementwiseMulScalarModule_int"
    XPASS - "GroupNormModule_basic"
    XPASS - "GroupNormNoWeightAndBiasModule_basic"
    XPASS - "MobilenetV3Module_basic"
    XPASS - "NativeBatchNorm1DModule_basic"
    XPASS - "NativeBatchNorm2DModule_basic"
    XPASS - "NativeBatchNorm3DModule_basic"
    XPASS - "NativeBatchNormNoneWeightModule_basic"
    XPASS - "NativeGroupNormBackwardModule_basic"
    XPASS - "NativeGroupNormModule_basic"
    XPASS - "ResNet18Module_basic"
    XPASS - "ResNet18StaticModule_basic"
```

And segfault for test
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic". Somehow this
change doesn't allow to use Tensors, that are not forward arguments, but
local variables of model.
e.g. `self.x = torch.tensor(..)`

See also: #2745

Signed-off-by: Dmitrii Makarenko <dmitrii.makarenko@intel.com>
2024-03-11 12:22:05 -07:00
Rob Suderman 8fb28661f9
[onnx] Fix onnx.ReduceMean lowering (#3002)
Reduce mean lowerings did not succesfully lower to `linalg` via torched.
There were two separate paths that could be consolidated to a single
simpler pass. This resulted in a significant improvement in test
coverage.
2024-03-11 11:32:53 -07:00
Rob Suderman bd7f1baa42
[onnx] Fix expand operation for dynamic shape max (#3001)
If the broadcast shape is length-1 at a dim while `?` in the input dim
then we need to broadcast to the dynamic dim. This is equivalent to
taking a max of two dimensions.
2024-03-08 16:23:07 -08:00
Rob Suderman 0723584936
[torch] Add folder for torch.aten.*.Scalar comparisons (#3000)
This folds small version of the tensor-scalar comparison operators as
they are commonly used for shape computations. This includes le, lt, ge,
gt, eq, and ne.
2024-03-08 13:44:00 -08:00
Andreas Falkenberg 551a4e45f3
[onnx] Add support for `onnx.Gemm` with no bias (#2993)
Previous gemm version required a bias vector. 
This provides an alternate path to `Torch::AtenMm`
with no bias operation.
2024-03-07 15:58:38 -08:00
Rob Suderman 1964208d19
[onnx] Fix constant pad for dynamic shape (#2989)
The current padding operation was not functional for dynamic shapes.
Updated and enabled tests so that onnx.pad tests pass.

Work TBD for reflection padding.
2024-03-07 13:29:50 -08:00
Scott Todd 7b18646def
[onnx] Handle optional arguments in Clip op pattern. (#2976)
Spec: https://onnx.ai/onnx/operators/onnx__Clip.html
2024-03-07 17:25:14 +00:00
Vivek Khandelwal 6e84752c39
build: manually update PyTorch version (#2992)
Set PyTorch and TorchVision version to nightly release 2024-03-07.
This commit also removes the deprecated constraints API:
342e7929b8

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-03-07 21:42:38 +05:30
Rob Suderman a78659742a
[onnx] Migrate `onnx.ReduceMax` to match `onnx.ReduceMin` (#2981)
This mostly copy-pastes the reduce minimum implementation to reduce max
to improve test coverage. We also improve the aten lowering for min/max
dim for unsigned types.
2024-03-06 16:48:21 -08:00
Andreas Falkenberg ea76dd12ba
[onnx][torch] Gridsampler E2E test and corrections of gridsampler (#2987)
The addition of an e2e test is actually provided in the Shark-Testsuite.
This adds 2 test cases for the gridsampler e2e test. 
Also as intended there were some items found which needed correction, so
the Gridsampler op has also a change.
2024-03-06 10:56:58 -08:00
Rob Suderman 06292d9429
[torch] Rework `aten.repeat` to use flatten and unsqueeze (#2984)
Current implementation depends on using `aten.view` which has issues
inferring tensor collapse/expand operations during the lowering to
`linalg`. Using flatten and unsqueeze better infers what the later
reshape behavior.
2024-03-06 10:19:18 -08:00
Ze Zhang aa7c9a9653
e2e support aten.linalg_norm to aten.linalg_vector_norm (#2953)
Add e2d support for `aten.linalg_norm` by decompose it to
`aten.linalg_vector_norm`.

Lowering to `aten.linalg_matrix_norm` is still unsupported.

To Test: 

`python -m e2e_testing.main -v`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-03-05 16:31:01 -08:00
Rob Suderman bc0527676b
[torch] Add support for `torch.split_with_sizes` via decompose (#2979)
Convert to individiual slices and tuple together as a list.

---------

Co-authored-by: Scott Todd <scott.todd0@gmail.com>
2024-03-05 15:01:21 -08:00
Chi_Liu 09875fabd1
[MLIR][ONNX] Add ONNX ReduceProd support (#2943)
Alternatives to https://github.com/llvm/torch-mlir/pull/2908

Fix https://github.com/nod-ai/SHARK-Turbine/issues/353
2024-03-04 11:07:03 -08:00
Rob Suderman d51e80b648
[onnx] Fix onnx.gather lowering for rank-0 indices (#2973)
We assumed rank was atleast 1 however it can be rank-0, generating an
illegal pair of flatten / unflatten operations. Corrected this.
2024-03-04 08:25:19 -08:00
Rob Suderman 61f0a5facf
[torch] Add an `aten.cat` length-0 canonicalization (#2966)
If an input is length-0 along the dimension of canonicalization we can
remove the tensor from the list
2024-03-01 21:41:12 -08:00
Rob Suderman d030bffc62
[torch] Support `aten.view` rank-0 collapse (#2965)
Collapsing to a rank-0 tensor using `aten.view` was currently bailing
out. Added the special case.
2024-03-01 12:31:07 -08:00
mmakevic 76b81e0ccd
Implement lowering of torch.aten.fmod.Tensor (#2767)
Closing https://github.com/nod-ai/SHARK-Turbine/issues/351
2024-02-29 11:22:03 +05:30
Rob Suderman ed6e75908b
Bump LLVM to llvm/llvm-project@e5ed7b6e2f (#2964) 2024-02-28 14:13:26 -08:00
Rob Suderman 6f3d62ab04
[torch] Fix folders and `cat` and `view` torch lowerings (#2963)
A bunch of small fixes are interlinked and trigger crashes if not
addressed as a group. This includes:

- aten view when expand from a rank-0 tensor
- slice folder with negative indices
- `aten._shape_as_tensor` folder on a rank-0 tensor
- `aten.cat` of a tensor with a length-0 tensor
2024-02-28 12:04:52 -08:00
Rob Suderman 08bc013fcd
[tosa] Fix TOSA batch matmul lowering to correct transpose ordering (#2959)
The corrective transpose at the end is computed incorrectly. Is it
actually computin the inverse transpose. Inverting the permutations
fixes the issue.
2024-02-28 09:46:58 -08:00
Rob Suderman 4a7a7d76f8
[onnx] Fix ReduceMean lowering to torch (#2956)
Torch lowering only supported the most recent version. Refactored the
lowering so more easily handle default values and optional operands /
attributes.
2024-02-27 22:48:07 -08:00
Abhishek-TyRnT d541779f37
Add support for torch arange float module (#2749)
Added Support for float dtype in in torch.arange in TOSA Dialect

This resolves the following issue :- 
https://github.com/llvm/torch-mlir/issues/2762

The following test cases are passing after this change

1. ArangeDtypeIntModule_basic
2. ArangeFloatModule_basic
3. ArangeNegativeStartFloatModule_basic
4. ArangeStartFloatModule_basic
5. ArangeStartNegativeStepFloatModule_basic
6. ArangeStartOutDtypeModule_basic
7. ArangeStartStepFloatModule_basic

---------

Co-authored-by: James Newling <james.newling@gmail.com>
2024-02-27 13:40:55 -08:00
Vivek Khandelwal d628b5fd06
[MLIR][TORCH] Add support for tanh approximation for Gelu op (#2941)
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/461

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-02-27 19:26:01 +05:30
Vivek Khandelwal d81747eadb
[MLIR][TORCH] Extend support for OnnxToLinalg lowering for Dropout and Div op (#2938)
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/451,
https://github.com/nod-ai/SHARK-Turbine/issues/452
2024-02-27 11:02:05 +05:30
ptrifunovic98 c5a1da1910
Implement lowering of torch.aten.norm.Scalar (#2899)
Closes
[nod-ai/SHARK-Turbine#365](https://github.com/nod-ai/SHARK-Turbine/issues/365)
2024-02-26 08:46:56 -08:00
Rob Suderman 53f6d06ab8
[onnx] Drop `ConstantOfShape` logic form importer, fix torch lowering (#2930)
There is no reason to treat `ConstantOfShape` as a specialized import
any as there exists a onnx-to-torch equivalent. Dropping the import
coding and adding support for resource conversion substantially
increases test coverage for dynamically shaped tests.
2024-02-21 21:34:43 -08:00
Rob Suderman df2aa1a369
[torch] Fixed edge conditions for strided slicing (#2929)
Strided slicing can occur with a negative stride. In these cases we need
to bound end differently. This included removing a function that was
generating bad limits.
2024-02-21 21:28:44 -08:00
Rob Suderman 13113df33e
[onnx] Enable crashing tests (#2928)
Crashing tests no longer crash, enable as either passing or xfail tests.

Co-authored-by: Xida Ren (Cedar) <cedar.ren@gmail.com>
2024-02-20 18:34:21 +00:00
Rob Suderman 13553d49c9
[onnx] Update the importer to create a `none` for missing operands (#2931)
Some operands are optional so we require a placeholder for missing
operands. We invent an `onnx.None` operation as our placeholder.
2024-02-20 09:30:30 -08:00
Rob Suderman 135c81a416
[torch] Add folder for `prim.NumToTensor.Scalar` (#2921)
Useful for `slice` lowerings that depend on tensors made form scalars.
2024-02-19 11:55:54 -08:00
Rob Suderman cea51897a5
[onnx] Simplify onnx.slice lowering (#2919)
Onnx slice lowering used arange needlessly instead of directly
constructing the constant dimension values. This makes lowerings to
linalg struggle as multiple folders are required to get what is a
constant index value.
2024-02-19 10:26:29 -08:00
aldesilv d29157b33f
OnnxToTorch support for onnx.InstanceNormalization op (#2710)
https://github.com/nod-ai/SHARK-Turbine/issues/327
2024-02-19 19:53:48 +05:30
Rob Suderman d65925a8b4
[onnx] Fix `onnx.sigmoid` for integer inputs/outputs (#2914)
Sample compilation crashes due to sigmoid with integer inputs/outputs.
This fix avoids crashing but still experiences an error.
2024-02-16 13:35:25 -08:00
Rob Suderman 7a0d0e954b
[onnx] Fix onnx.gather lowering to use torch.aten.index_select (#2913)
Onnx's gather maps directly to `torch.aten.index_select`. We should just
use that path.
2024-02-16 16:05:44 -05:00
Rob Suderman 074f112d6a
[onnx] Add testing using the `onnx` compilation using torch tests (#2795)
We can route the torch tests via `onnx` using the `torch.onnx.export`
tooling. We can then reimport, lower to torch, and compile to linalg to
validate the onnx path is working correctly.

The current implementation exposes some failures in the `onnx` path so
we cannot enable the onnx test suite yet due to segmentation faults.
2024-02-15 10:17:13 -08:00
Yuanqiang Liu f3e8199a6d
[Stablehlo] add refbackend (#2712) 2024-02-16 01:08:48 +08:00
Rob Suderman e9cdd6cbc5
[torch] Fix tm_tensor.attention for end-to-end (#2907)
Some operations include a backend matcher for specialized operations. We
map these back to generics so they appropriately match to the high
performance versions. This is done for the attention operation.
2024-02-13 21:18:01 -08:00
Avinash Sharma 9659a436d1
Add lowering support for math::AbsIOp (#2875)
There is no lowering support for math::AbsIOp, so if the operand is an
integer type, it will fail to lower to math::AbsFOp since the op operand
#0 must be floating-point-like.
2024-02-08 14:53:40 -08:00
Aart Bik 44f8f89826
[torch-mlir][sparse] add sparsification to linalg reference backend (#2887)
This adds a few passes that will ensure linalg with sparse tensors are
properly lowered to loops and can run using the ExecutionEngine for
testing (a few details on parameter passing from PyTorch still TBD)

Test results:

$ ./tools/e2e_test.sh --config linalg

Summary:
    Passed: 1144
    Expectedly Failed: 8

$ python -m e2e_testing.main --config=torchdynamo -v

Summary:
    Passed: 960
    Expectedly Failed: 163

Filed issue:
https://github.com/pytorch/pytorch/issues/119407
2024-02-08 09:37:31 -08:00
Rob Suderman 041a54ae0c
[torch] Supporting `torch.aten.mul.float` lowering to `arith` (#2833)
Simple missing scalar operation for multiply floats was missing.
2024-02-05 16:23:04 -08:00
Xida Ren (Cedar) 24b8c8672a
[torch] Add folders for `torch.fill`, `torch.ones`, `torch.zeros` and `aten.getItem` (#2849)
So that the CumSum Op in OPT can get the constant that it requires to be lowered to TMTensor

---------

Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
2024-02-02 10:46:33 -08:00
Rob Suderman 0114a570e3
[torch] Support lowering `torch.item` to `tensor.extract` (#2835)
Extracting scalar values from tensors can be implemented via a lowering
to tensor.extract.
2024-01-31 15:09:12 -08:00
Ilija Kalinić 54ef18c556
Implement lowering of torch.aten.lerp.Scalar (#2773)
Closes nod-ai/SHARK-Turbine#356
2024-01-31 09:39:38 -08:00
Yuanqiang Liu d778950f45
[Torch Dialect] add fold pattern for aten.clone (#2804) 2024-01-31 09:43:21 +08:00
Rob Suderman 25a5a22cbd
[torch] Support `torch.convolution` quantized lowering to `linalg` (#2811)
Linalg has quantized specific operations. We can lower to these
operations when there is a known zeropoint and scale operations. This
allows the `convolution` to occur with lower bitwidth's, improving the
overall performance.
2024-01-30 13:46:47 -08:00
Rob Suderman 2ef228328f
[torch] `torch.dequantize` for per channel tensors to` linalg` (#2769)
Support a lowering for dequantization for per channel tensors from
`torch` dialect to a linalg decomposition. Tested via a numerical
`torch` test.
2024-01-25 16:40:21 -08:00
Rob Suderman f6f890520b
[torch][quant] Quantized `torch.mm` for linalg with end-to-end test (#2750)
This includes custom op matching for decomposed operations and fusing
dequantization into dense operations. As a validation we compare
to the dequant+mm torch implementation.
2024-01-24 14:02:50 -08:00
Xida Ren (Cedar) ccaac85788
implement aten.conv1d, aten.conv3d, and aten.conv_tbc (#2757)
convolution with [time,batch,channel] ordering, as opposed to the
default [batch, channel, time]. Currently implementing by transposing
the input and output, but may need to get its own implementation in the
future because this is supposed to be an op that gives a speedup. This
is used by fairseq
(https://github.com/facebookresearch/fairseq/issues/172).

(in case you were wondering like me, this is different from transposed
convolution. Transposed convolution has fractional strides).

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
Co-authored-by: Frederik Harwath <frederik.harwath@amd.com>
2024-01-23 21:30:03 -08:00
John Wu 704cfdaf08
Add aten.pool_max3d support to torch-to-linalg (#2735)
Added verification logic to the abstract_interpreter_lib_gen.py

Also made some unit tests

Initially, I thought we can use `linalg::pooling_ndhwc_max` to help
implement this problem. However, on a 5-dimensional matrix it does the
pooling on dimensions (2, 3, 4) which is not what we want. We want
pooling on dimensions (3, 4, 5).

To achieve this, we would need to lower our code using the `linalg`
dialect.


Turns out the pooling code in `linalg` looks like this.

```
func @max_pooling_ncdhw(%I: memref<?x?x?x?x?xf32>, %K: memref<3xindex>, %O: memref<?x?x?x?x?xf32>,
                        %strides: memref<3xindex>, %dilations: memref<3xindex>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %N = memref.dim %I, %c0 : memref<?x?x?x?x?xf32>
    %C = memref.dim %I, %c1 : memref<?x?x?x?x?xf32>
    %D = memref.dim %I, 2 : memref<?x?x?x?x?xf32>
    %H = memref.dim %I, 3 : memref<?x?x?x?x?xf32>
    %W = memref.dim %I, 4 : memref<?x?x?x?x?xf32>

    %kernel_d = memref.load %K[%c0] : memref<3xindex>
    %kernel_h = memref.load %K[%c1] : memref<3xindex>
    %kernel_w = memref.load %K[2] : memref<3xindex>
    %stride_d = memref.load %strides[%c0] : memref<3xindex>
    %stride_h = memref.load %strides[%c1] : memref<3xindex>
    %stride_w = memref.load %strides[2] : memref<3xindex>
    %dilation_d = memref.load %dilations[%c0] : memref<3xindex>
    %dilation_h = memref.load %dilations[%c1] : memref<3xindex>
    %dilation_w = memref.load %dilations[2] : memref<3xindex>

    linalg.generic {
        indexing_maps = [
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d * %stride_d + kd * %dilation_d, h * %stride_h + kh * %dilation_h, w * %stride_w + kw * %dilation_w)>,  // Map for input tensor
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (kd, kh, kw)>,                                              // Map for kernel tensor
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d, h, w)>                                            // Map for output tensor
        ],
        iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"],
        doc = "3D Max Pooling NCDHW with Strides, Dilations, and Kernel Size"
    } ins(%I, %K : memref<?x?x?x?x?xf32>, memref<3xindex>) outs(%O : memref<?x?x?x?x?xf32>) {
        ^bb0(%input_elem: f32, %kernel_elem: index, %output_elem: f32):
            %max_val = arith.maxf %input_elem, %output_elem : f32
            linalg.yield %max_val : f32
    }
    return
}

```

This was implemented based on it's source code with the adjustments
mentioned above:

4ca1b5e094/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (L5647)

Issues related to this can be found here

https://github.com/nod-ai/SHARK-Turbine/issues/324
2024-01-19 21:09:46 +05:30
Ze Zhang 77a03f2069
torch-to-tosa lowering support for AtenLinalgVectorNormOp (#2734)
This PR add torch-to-tosa lowering support for AtenLinalgVectorNormOp

e2e test:
python -m e2e_testing.main --config=tosa

LIT tests:
cmake --build build --target tools/torch-mlir/all

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-01-18 12:32:23 -08:00
Sungsoon Cho a8538e1e3f
Decompose AtenNormalFunctionalOp into AtenRandn* and other arithmetic. (#2737) 2024-01-15 22:49:29 -08:00
lonely eagle f85e5c932b
[Torch Dialect] support aten.isneginf, aten.isposinf, aten.nan_to_num (#2743) 2024-01-16 14:29:34 +08:00
Rob Suderman dc37616d67
[torch][quant] Support quantize and dequantize for torch (#2731)
Handle both `torch.dequantize` and `torch.quantize_per_tensor` including
the op based quantization parameter tracking. This includes adding
`qint32` to torch types as it was missing during the initial type
inclusion.

For testing we only have `torch.int8` and `torch.float` types on
function boundaries as the `qint8` types require passing the scale
and zero point quantization information which is not supported yet.
2024-01-12 19:11:14 -08:00
Ze Zhang 670a99ae19
Handle torch.none type in tosa.clamp op (#2739)
This PR updates the torch-to-tosa conversion with following changes:

- Support torch.none as min/max input argument for tosa.clamp op
- Support negative value as start index for tosa.slice op
- Add tosa.logical_or lowering support

e2e test:
python -m e2e_testing.main --config=tosa

LIT tests:
cmake --build build --target tools/torch-mlir/all

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
2024-01-11 10:36:48 -08:00
Ilija Kalinić e1a86e480a
Implement lowering of torch.aten.logit (#2697)
Closes nod-ai/SHARK-Turbine#290
2024-01-11 20:25:42 +05:30
zjgarvey 07d0645f64
[RFC] general support for Adaptive Pooling Ops (#2661)
Adaptive pooling ops can only be decomposed into their non-adaptive
counterparts in trivial cases.

For example, the current decomposition for AtenAdaptiveAvgPool1dOp in
DecomposeComplexOps.cpp supports outSize = inSize (i.e., do literally
nothing), and outSize = 1 (i.e., do a batched average).

The reason adaptive pooling ops are difficult to lower to linalg is that
they are not constantly strided. They are computed by taking an input
tensor of shape (N, C, Hin), and an output size Hout, and computing the
output tensor at position (n,c, h) in the following way:

1. compute st(h) = (h*Hin)//Hout
2. compute en(h) = 1 + ((h+1)*Hin -1)//Hout
3. apply a computation (max or avg) to the slice: INPUT[n, c,
st(h):en(h)]

The provided sample implementation (for ConvertAtenAdaptiveAvgPool1dOp)
uses tensor.extract to access the input tensor inside the payload of a
linalg generic op. This is likely an unattractive use of linalg generic
ops, which is why I am asking for some more targeted feedback on the
validity of this approach before attempting to support the many other
adaptive pooling ops.

Specifically:

- Is the performance of this implementation bad enough to warrant
targeting different dialects entirely? e.g. TMtensor/linalg ext/ etc.
- If the provided implementation is of acceptable performance to the
community, then is it permissable to remove the Adaptive pooling
decompositions from DecomposeComplexOps.cpp? Based on the current
structure of the -torch-decompose-complex-ops pass, it does not seem
possible to only decompose the adaptive ops in special cases (it seems
to get stuck in an infinite loop on a match failure). I would be happy
to instead incorporate the case logic into the conversion directly, and
remove the decompositions once they are rendered completely obsolete.

As long as this approach is acceptable, I can clean up the
implementation with some helper functions, and quickly add support for
each of the remaining Adaptive pooling ops.
2024-01-09 11:14:10 -08:00
Vivek Khandelwal 690827fe52 build: manually update PyTorch version
Set PyTorch and TorchVision version to nightly release 2024-01-02.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2024-01-03 11:47:12 +05:30
Xida Ren (Cedar) 6660a26594
lower torch.aten.isinf to linalg (#2638)
Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
2023-12-28 17:20:32 -08:00
Sungsoon Cho 8e389ff2ff
Implement lowering of torch.aten.exponential (#2680)
https://github.com/llvm/torch-mlir/issues/2646

Decompose aten.exponential() into: -exp(1-x)/lambda
2023-12-27 20:33:18 -08:00
JianzheXiao 6ddeb1a6ef
[torch] Add support for aten.selu (#2640)
Add `aten.selu` operation to `torch` dialect.
2023-12-13 20:28:08 -08:00
JianzheXiao 7cf52ae73f
[Torch Dialect]Add Support for AtenGroupNormOp and AtenNativeGroupNormOp (#2591)
Co-authored-by: LiuYuanqiang <liuyuanqiang.yqliu@bytedance.com>
2023-12-13 11:05:12 +08:00
Eric Kunze f67249d34f
Sort the TOSA passing test list (#2630)
For easier tracking of issues, sort the TOSA passing list. It is still
significantly smaller then the XFAIL list would be.

Resolves #2620, at least until the xfail list gets smaller than the
passing list.

Signed-off-by: Eric Kunze <eric.kunze@arm.com>
2023-12-12 14:22:25 -08:00
JianzheXiao 96fcde4d77
[Torch Dialect] Support Einsum Op (#2230)
As title, support torch.aten.einsum op

Right now only support Static Shape, because of the known issue, the
fixed solution is here: https://github.com/llvm/torch-mlir/pull/2154

Co-authored-by: Jiawei Wu
[wujiawei.aml@bytedance.com](mailto:wujiawei.aml@bytedance.com)
2023-12-10 12:30:37 +08:00
Felix Schneider fb21a85874
[TorchToLinalg] Lower grouped conv2d to linalg Op with correct dimension ordering (#2623)
The linalg Op `linalg.conv_2d_ngchw_fgchw` had a bug where

1. Weights were accessed as G,F,C,H,W instead of as F,G,C,H,W
2. Output was accessed as N,F,G,H,W instead of as N,G,F,H,W

Now this has been fixed in
https://github.com/llvm/llvm-project/pull/73855 which broke the
torch-mlir lowering to that Op.

This patch switches lowering in torch-mlir to the newly introduced
`linalg.conv_2d_ngchw_gfchw` op which accesses weights in an order that
is compatible with PyTorch's memory layout.

Fix https://github.com/llvm/torch-mlir/issues/2622
2023-12-08 14:18:23 +01:00
Stella Laurenzo 8252656b6d
Advance llvm-project and stablehlo. (#2619)
llvm-project: bbd2b08b95fe76bea138c1b03c1cd42ed3ee04df
stablehlo: ab709fe48de88c67717abfbd7ef17425eb95ddaf

These commits were chosen in order to account for an MLIR API break from
3dbac2c007
which required a patch to stablehlo. We integrate a bit beyond that
commit to deal with some revert/reapply cycles in the intervening range
which were discovered in another downstream.

Further, it requires adaptation to the stablehlo API breaks introduced
from https://github.com/openxla/stablehlo/pull/1872 which are along for
the ride.

Since some stablehlo builders were changed to directly take int64_t
array refs, also traced that up some call stacks to eliminate some
signed/unsigned mismatches that result.

Also adds a few TOSA tests to the passing set that seem to work now.
2023-12-07 23:13:42 -08:00
Frederik Harwath 6248216dca
Add aten.min.dim to linalg lowering (#2600) 2023-12-05 07:16:35 -08:00
Vivek Khandelwal 10b5432e7d build: manually update PyTorch version
Set PyTorch and TorchVision version to nightly release 2023-12-04.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2023-12-05 13:18:47 +05:30
James Newling 1b7d6f2af9
Improve decomposition of pixel_shuffle (support dynamic shapes) (#2590)
The aten.reshape ops in the decomposition are replaced with prims.collapse 
and prims.split_dim ops, which means that the cases where the lowering of
reshape from torch to linalg which are not supported, are avoided.

Essentially, by using the collapse and split_dim ops instead of the
reshape ops, we are not "losing" the information that the reshapes do not
arbitrarily mix dimensions. Which makes lowering easy. 

3 additional tests added: 
- fully dynamic, 
- dynamic only the spatial dimensions, 
- dynamic only in the non-spatial dimensions.
2023-11-22 12:31:06 -08:00
James Newling 03e8f99730
Lowering to linalg of prims split_dim op (#2576)
Adds support for lowering to prims split_op. 

Similar design to collapse op lowering in 
https://github.com/llvm/torch-mlir/pull/2572, with some 
small differences, because the split_dim op (in pytorch) is
view-changing whereas the collapse is not. The difference 
means that 

1) it must be registered in the function Torch::isViewLikeOp
2) it must be be added to the "expected fail" set for the torch dynamo backend.
2023-11-21 07:56:09 -08:00
Yuanqiang Liu facbe5d96b
[Torch Dialect] support AtenArangeStartOutOp in ReduceOpVariants like… (#2563)
… AtenBernoulli_FloatOp

It fixing case like: `%2110 = torch.aten.arange.start_out %int1,
%int1517, %int1, %2109 : !torch.int, !torch.int, !torch.int,
!torch.tensor -> !torch.tensor`.
`aten.arange.start_out` doesn't have value semantics also, means`%2110`
is an alias for %2109.
So I decompose it to `aten.arange.start` + `torch.contents.overwrite`.  
The complex decomposition logic is target to handle cases like view and
dtype cast which I add in e2e tests.
2023-11-17 00:51:55 +08:00
James Newling e81282ae8f
Support for prims collapse op (lowering to linalg) (#2572)
Steps taken:
1) add generator code to torch_ods_gen.py, run update_torch_ods.sh
2) add (custom) shape and type inference generator code to
abstract_interp_lib_gen.py, run update_abstract_interp_lib.sh
3) Implement lowering to tensor.collapse_dims. Requires the `start` and
`end` values to be constant, else lowering fails
4) Update xfail_sets.py (append to LTC_XFAIL_SET) after running
/tools/e2e_test.sh --filter Collapse --verbose -c XX for all support
backends (XX).

Motivation: 
- Supporting the collapse operation will be useful for lowering of
pixel_shuffle (see Issue #2559)
2023-11-15 08:34:38 -08:00
Yuanqiang Liu 3ab790c50a
[Torch Dialect] add canonicalize for aten.numel (#2562) 2023-11-11 12:16:53 +08:00
James Newling b6e551c7b8
Decomposition of aten.pixel_shuffle with static input shape (#2550)
For static tests (that is when the shape is know) for example:

 ```
 @annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
 ```
 
The e2e passes. But only if the replacement op's return type is set as
undefined (optional shape and type must be explicitly made unset),
otherwise there's a error about the function return type.
 
 For dynamic cases, for example if the above is replaced with 
 
  ```
 @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
 ```

There is a failure to lower to linalg from torch ("view op explicitly
labelled as illegal"). This seems to be because the support for lowering
from torch to linalg with dynamic shapes is limited.
2023-11-08 08:52:44 -05:00
JianzheXiao a42d4c18ff
[Torch Dialect]Support aten.cosine_similarity (#2364)
As title, add support for aten.cosine_similarity, support broadcast
inputA/inputB to the same shape
2023-11-08 15:28:30 +08:00
Jiawei Wu d5ee8ee73a
[Torch Dialect] emit aten.reshape_as op and add decomposition pattern. (#2553) 2023-11-05 11:38:36 +08:00
Yuanqiang Liu 0378da0abd
[Torch Dialect] support aten.isinf (#2544)
Also fix linalg lowering from `UEQ` to `OEQ`.  
I will check other comparison's lowering later.
2023-11-04 22:26:01 +08:00
Stella Laurenzo 6961f0a247
Re-organize project structure to separate PyTorch dependencies from core project. (#2542)
This is a first step towards the structure we discussed here:
https://gist.github.com/stellaraccident/931b068aaf7fa56f34069426740ebf20

There are two primary goals:

1. Separate the core project (C++ dialects and conversions) from the
hard PyTorch dependencies. We move all such things into projects/pt1 as
a starting point since they are presently entangled with PT1-era APIs.
Additional work can be done to disentangle components from that
(specifically LTC is identified as likely ultimately living in a
`projects/ltc`).
2. Create space for native PyTorch2 Dynamo-based infra to be upstreamed
without needing to co-exist with the original TorchScript path.

Very little changes in this path with respect to build layering or
options. These can be updated in a followup without commingling
directory structure changes.

This also takes steps toward a couple of other layering enhancements:

* Removes the llvm-external-projects/torch-mlir-dialects sub-project,
collapsing it into the main tree.
* Audits and fixes up the core C++ build to account for issues found
while moving things. This is just an opportunistic pass through but
roughly ~halves the number of build actions for the project from the
high 4000's to the low 2000's.

It deviates from the discussed plan by having a `projects/` tree instead
of `compat/`. As I was thinking about it, this will better accommodate
the follow-on code movement.

Once things are roughly in place and the CI passing, followups will
focus on more in-situ fixes and cleanups.
2023-11-02 19:45:55 -07:00