torch-mlir/test/Conversion/TorchToLinalg
Matthias Gehre 6678e1a256
TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475)
Before this PR, a statically shaped `aten.convolution` would generate
dynamically shaped linalg IR, and even `-canonicalize` could not fold
it back into static shapes. This PR ensures that shape calculations
are folded on construction, so that statically shaped linalg IR is
generated directly.

We achieve that by ensuring that `arith` ops involved in computing
shapes are created via `createOrFold`, so that later uses of
`getAsOpFoldResult` see constants instead of those ops.
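The mechanism can be illustrated with a small Python model. This is a hedged sketch, not the MLIR API: `Const`, `Op`, `create`, `create_or_fold`, and `get_as_op_fold_result` are hypothetical stand-ins for `arith.constant`, a materialized `arith` op, `OpBuilder::create`, `OpBuilder::createOrFold`, and `getAsOpFoldResult`. Folding is reduced here to evaluating an integer op whose operands are all constants.

```python
from dataclasses import dataclass
from typing import Union

# Hypothetical toy IR, loosely modeled on MLIR (not the real API).
@dataclass(frozen=True)
class Const:
    """Stand-in for a value produced by arith.constant."""
    value: int

@dataclass(frozen=True)
class Op:
    """Stand-in for an arith op that was materialized in the IR."""
    name: str
    operands: tuple

Value = Union[Const, Op]

# Integer semantics used when folding (floordivsi rounds toward -inf, like Python //).
_FOLDERS = {
    "addi": lambda a, b: a + b,
    "subi": lambda a, b: a - b,
    "muli": lambda a, b: a * b,
    "floordivsi": lambda a, b: a // b,
}

def create(name: str, *operands: Value) -> Value:
    """Like OpBuilder::create: always emits a new op, even for constant inputs."""
    return Op(name, operands)

def create_or_fold(name: str, *operands: Value) -> Value:
    """Like OpBuilder::createOrFold: returns a constant if all operands are constants."""
    if all(isinstance(o, Const) for o in operands):
        return Const(_FOLDERS[name](*(o.value for o in operands)))
    return Op(name, operands)

def get_as_op_fold_result(v: Value) -> Union[int, Value]:
    """Like getAsOpFoldResult: a plain int for constants, otherwise the value itself."""
    return v.value if isinstance(v, Const) else v

# Padded spatial size 112 + 2 * 1, built both ways.
eager = create("addi", Const(112), create("muli", Const(2), Const(1)))
folded = create_or_fold("addi", Const(112), create_or_fold("muli", Const(2), Const(1)))
print(get_as_op_fold_result(eager))   # an Op, so a consumer would emit a `?` dimension
print(get_as_op_fold_result(folded))  # 114, so a consumer can emit a static dimension
```

In this model, a consumer that reads the size from `eager` only sees an op and has to fall back to a dynamic dimension, while the same computation built with `create_or_fold` hands it the integer directly.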

For example,
```
module {
  func.func @forward(%arg0: !torch.vtensor<[32,336,112,112],f32>,
                     %arg1: !torch.vtensor<[336,168,3,3],f32>,
                     %arg2: !torch.vtensor<[336],f32>)
                     -> !torch.vtensor<[32,336,56,56],f32> {
    %false = torch.constant.bool false
    %int2 = torch.constant.int 2
    %int1 = torch.constant.int 1
    %0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.prim.ListConstruct  : () -> !torch.list<int>
    %3 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %0, %0, %false, %2, %int2
        : !torch.vtensor<[32,336,112,112],f32>, !torch.vtensor<[336,168,3,3],f32>, !torch.vtensor<[336],f32>, !torch.list<int>,
          !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int
        -> !torch.vtensor<[32,336,56,56],f32>
    return %3 : !torch.vtensor<[32,336,56,56],f32>
  }
}
```
would, before this PR, be lowered to
```
[...]
  %padded = tensor.pad %2 low[%14, %15, %16, %17] high[%14, %15, %16, %17] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
    } : tensor<32x336x112x112xf32> to tensor<?x?x?x?xf32>
[...]
  %45 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
    ins(%expanded, %expanded_37 : tensor<?x2x?x?x?xf32>, tensor<2x168x168x3x3xf32>)
    outs(%expanded_44 : tensor<32x2x168x?x?xf32>) -> tensor<32x2x168x?x?xf32>
[...]
```
With this PR, all shapes in the generated IR are static.
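The static sizes in this example can be checked by hand, because every input to the shape computation is a constant. The sketch below assumes PyTorch's documented output-size formula for a non-transposed convolution and applies it to the operands of the `torch.aten.convolution` call above (stride `[2, 2]`, padding `[1, 1]`, dilation `[1, 1]`, groups `2`). The helper name `conv_out_size` is hypothetical.

```python
def conv_out_size(in_size: int, kernel: int, stride: int, padding: int, dilation: int) -> int:
    """Spatial output size of a non-transposed convolution (PyTorch's documented formula)."""
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# Operands of the example: 112x112 input, 3x3 kernel, stride 2, padding 1, dilation 1.
padded = 112 + 2 * 1  # spatial size after padding by 1 on each side
h = conv_out_size(112, kernel=3, stride=2, padding=1, dilation=1)
print(padded, h)  # 114 56

# Grouped convolution: 336 input and 336 output channels split into 2 groups.
groups = 2
print(336 // groups, 336 // groups)  # 168 168, the per-group sizes in tensor<2x168x168x3x3xf32>
```

All of these values are plain integers, which is why folding them at construction time is enough to replace the `?` dimensions shown above with constants.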
2024-06-27 08:43:10 +02:00
| File | Last commit | Date |
|---|---|---|
| basic.mlir | [torch-mlir] bump stablehlo/llvm version (#3471) | 2024-06-18 16:59:53 -07:00 |
| broadcast.mlir | [TorchToLinalg] Improve broadcast lowerings in strict symbolic modes (#2505) | 2023-10-05 15:15:26 -04:00 |
| convolution.mlir | [TorchToLinalg] Fix Quantized Convolution Accumulator Type (#3459) | 2024-06-20 13:54:20 -07:00 |
| elementwise.mlir | TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475) | 2024-06-27 08:43:10 +02:00 |
| flatten.mlir | Integrate llvm-project at dabdec1001dc368373dd581cf72f37a440873ce3 (#3300) | 2024-05-08 14:43:06 -04:00 |
| gridsampler.mlir | [onnx] Gridsampler addition of nearest mode (#3320) | 2024-05-10 11:42:10 -07:00 |
| pooling.mlir | TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475) | 2024-06-27 08:43:10 +02:00 |
| resize.mlir | [ONNX] Fix resize ceil numerics and add half_pixel_symmetric support (#3443) | 2024-06-11 22:35:50 -05:00 |
| sparse.mlir | [torch-mlir] bump stablehlo/llvm version (#3471) | 2024-06-18 16:59:53 -07:00 |
| unsqueeze.mlir | Integrate llvm-project at dabdec1001dc368373dd581cf72f37a440873ce3 (#3300) | 2024-05-08 14:43:06 -04:00 |
| view.mlir | [linalg] Implement strict mode lowering for aten.view. (#3319) | 2024-05-10 13:45:50 -07:00 |
| view_strict.mlir | TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475) | 2024-06-27 08:43:10 +02:00 |