torch-mlir/lib/Dialect/Torch/Transforms
Xinyu Yang d4313eed4a
[Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075)
Decompose RepeatInterleaveSelfInt into simpler ops (flatten, unsqueeze, tile, reshape), as in the following Python equivalent:
```python

import torch


def my_repeat_interleave(input, repeats, dim=None):
    if dim is None:
        # Flatten the input and then repeat:
        # [N] -> [N, 1] -> [N, repeats] -> [N * repeats]
        return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
    else:
        # Normalize a negative dim so that `dim + 1` below is a valid position
        if dim < 0:
            dim += input.dim()
        # Calculate the shape after repeat
        expanded_shape = list(input.shape)
        expanded_shape[dim] *= repeats
        # Repeat along a new axis inserted right after `dim`
        repeat_shape = [1] * (input.dim() + 1)
        repeat_shape[dim + 1] = repeats
        input = input.unsqueeze(dim + 1)

        # Tile, then reshape to merge the new axis into `dim`
        tiled = torch.tile(input, repeat_shape)
        repeated = tiled.reshape(*expanded_shape)
    return repeated

```
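The `dim=None` branch relies on the fact that adding a trailing axis before tiling repeats each element in place instead of repeating the whole sequence. The minimal pure-Python sketch below illustrates that ordering on plain lists; the helper names are hypothetical and only mimic the shape transformations of the tensor ops, they are not torch-mlir or PyTorch APIs.

```python
def unsqueeze_last(xs):
    # Mimics unsqueeze(-1): [N] -> [N, 1], each element gets its own row
    return [[x] for x in xs]


def tile_rows(rows, repeats):
    # Mimics tile((1, repeats)): [N, 1] -> [N, repeats], copies along the last axis
    return [row * repeats for row in rows]


def flatten(rows):
    # Mimics flatten(): [N, repeats] -> [N * repeats] in row-major order
    return [x for row in rows for x in row]


def repeat_interleave_flat(xs, repeats):
    # Hypothetical helper: unsqueeze -> tile -> flatten on a flat list
    return flatten(tile_rows(unsqueeze_last(xs), repeats))


print(repeat_interleave_flat([1, 2, 3], 2))  # [1, 1, 2, 2, 3, 3]
# Tiling the flat list directly gives repeat/tile order instead:
print([1, 2, 3] * 2)  # [1, 2, 3, 1, 2, 3]
```

Because the tiled axis is the innermost one, row-major flattening emits all copies of an element before moving to the next element, which is exactly the interleaved order.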

The stablehlo and linalg tests pass. When testing onnx, however, something strange happened.
In torch-mlir's CI with **torch_nightly** and in my own
environment (torch==2.4.0.dev20240318+cpu), the test **passed**.
In torch-mlir's CI with **torch_stable**, it **failed**.
The failing test case is `RepeatInterleaveSelfIntNoDimModule_basic`. Since no `dim` is given, the input is flattened first, so the result
shape should be [120] (3 × 4 × 5 = 60 elements, each repeated twice).
```python
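import torch

from torch_mlir_e2e_test.annotations import annotate_args, export
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case


class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):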

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float32, True),
    ])
    def forward(self, x):
        return x.repeat_interleave(2)


@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
  Unexpected outcome summary: (onnx)
  
  ****** Failed tests - 1 tests
      FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
          @ trace item #0 - call to "forward"
          @ output of call to "forward"
          ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```
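For reference, the expected output shapes can be worked out by hand. The sketch below is a hypothetical pure-Python helper (not torch-mlir's actual shape function in `AbstractInterpLibrary.cpp`) that computes the result shape of `repeat_interleave` with an integer `repeats`. It only shows the arithmetic: the golden shape [120] for the no-dim case, and that [6, 4, 5], the shape reported by the failing onnx run, is what repeating along `dim=0` would produce. It does not diagnose why that happened.

```python
def repeat_interleave_self_int_shape(shape, repeats, dim=None):
    """Hypothetical helper: output shape of repeat_interleave(x, repeats, dim)
    for an integer `repeats`, given the input shape as a list of ints."""
    if dim is None:
        # The input is flattened first, so the result is 1-D
        numel = 1
        for size in shape:
            numel *= size
        return [numel * repeats]
    # Normalize a negative dim, then scale only that dimension
    if dim < 0:
        dim += len(shape)
    out = list(shape)
    out[dim] *= repeats
    return out


print(repeat_interleave_self_int_shape([3, 4, 5], 2))         # [120]
print(repeat_interleave_self_int_shape([3, 4, 5], 2, dim=0))  # [6, 4, 5]
```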

@rsuderman 
Would you please help me check what's wrong with my PR? Thanks a lot.
2024-04-18 06:27:51 +08:00

| File | Last commit | Date |
| --- | --- | --- |
| AbstractInterpLibrary.cpp | [Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075) | 2024-04-18 06:27:51 +08:00 |
| AdjustCallingConventions.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130) | 2024-04-11 06:47:35 -07:00 |
| CMakeLists.txt | [torch] Improve shape inference for `torch-to-linalg` path for reshapes (#3055) | 2024-03-26 12:41:40 -07:00 |
| DecomposeComplexOps.cpp | [Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075) | 2024-04-18 06:27:51 +08:00 |
| DropAbstractInterpCalculations.cpp | Update to LLVM 029313cc979ae71877b65794b1063d4e51184cc8 | 2023-03-21 04:16:20 -07:00 |
| EraseModuleInitializer.cpp | update llvm to d23516e9ad477527a9db4d06b1fa9566680ac67c (#1812) | 2023-01-23 16:34:22 -08:00 |
| FuseQuantizedOps.cpp | Adds Some Quantization Support for AtenMatmulOp (#3147) | 2024-04-15 16:06:47 -07:00 |
| GlobalizeObjectGraph.cpp | Clang format refresh (#2812) | 2024-01-29 12:59:33 -05:00 |
| InlineGlobalSlots.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130) | 2024-04-11 06:47:35 -07:00 |
| LowerToBackendContract.cpp | [Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075) | 2024-04-18 06:27:51 +08:00 |
| MatchQuantizedOps.cpp | [torch][quant] Quantized `torch.mm` for linalg with end-to-end test (#2750) | 2024-01-24 14:02:50 -08:00 |
| MaximizeValueSemantics.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130) | 2024-04-11 06:47:35 -07:00 |
| PassDetail.h | llvm: bump tag to e1318078 (#781) | 2022-04-26 12:27:51 -07:00 |
| Passes.cpp | Replace RefineTypes with dtype functions (#2105) | 2023-05-12 13:40:45 -07:00 |
| PrepareForGlobalizeObjectGraph.cpp | update llvm to d23516e9ad477527a9db4d06b1fa9566680ac67c (#1812) | 2023-01-23 16:34:22 -08:00 |
| RecomposeComplexOps.cpp | [Torch Dialect] support aten.split_with_sizes (#2431) | 2023-09-04 09:59:26 +08:00 |
| ReduceOpVariants.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130) | 2024-04-11 06:47:35 -07:00 |
| RefinePublicReturn.cpp | Support `DerefineOp` in `RefinePublicReturn`. | 2023-07-20 20:08:46 +02:00 |
| ReifyAbstractInterpCalculationsUtils.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130) | 2024-04-11 06:47:35 -07:00 |
| ReifyAbstractInterpCalculationsUtils.h | handles 2,3,4 from https://github.com/llvm/torch-mlir/issues/1963 (#1964) | 2023-03-24 21:50:01 -05:00 |
| ReifyDtypeCalculations.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130) | 2024-04-11 06:47:35 -07:00 |
| ReifyShapeCalculations.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130) | 2024-04-11 06:47:35 -07:00 |
| ScalarizeShapes.cpp | [torch] Improve shape inference for dynamic shapes (#3091) | 2024-04-02 16:19:57 -07:00 |
| SimplifyAbstractInterpCalculationsUtils.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130) | 2024-04-11 06:47:35 -07:00 |
| SimplifyAbstractInterpCalculationsUtils.h | Replace RefineTypes with dtype functions (#2105) | 2023-05-12 13:40:45 -07:00 |
| SimplifyDtypeCalculations.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130) | 2024-04-11 06:47:35 -07:00 |
| SimplifyShapeCalculations.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3130) | 2024-04-11 06:47:35 -07:00 |