torch-mlir/test/Conversion/TorchToStablehlo
Matthias Gehre 6678e1a256
TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475)
Before this PR, a statically shaped aten.convolution would generate
dynamically shaped linalg IR, and even `-canonicalize` would not be able
to fold it back into static shapes. This PR ensures that shape
calculations are folded on construction to directly generate statically
shaped linalg IR.

We achieve that by ensuring that `arith` ops involved in computing
shapes are created via `createOrFold`, so that later uses of
`getAsOpFoldResult` see constants instead of those ops.
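
The effect of folding on construction can be illustrated with a toy model in Python. This is only a sketch of the idea and not the MLIR API: `Constant`, `Dynamic`, `AddOp`, `create`, `create_or_fold` and `static_size` are hypothetical names invented for this illustration. `static_size` plays the role of a consumer that only recognizes a value as static when it is directly a constant.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class Constant:
    """A value known at construction time (hypothetical stand-in for a constant op)."""
    value: int


@dataclass(frozen=True)
class Dynamic:
    """A value only known at runtime."""
    name: str


@dataclass(frozen=True)
class AddOp:
    """An unfolded addition of two values."""
    lhs: "Value"
    rhs: "Value"


Value = Union[Constant, Dynamic, AddOp]


def create(lhs: Value, rhs: Value) -> Value:
    """Always materialize an AddOp, even when both operands are constants."""
    return AddOp(lhs, rhs)


def create_or_fold(lhs: Value, rhs: Value) -> Value:
    """Return a Constant when both operands are constants, else an AddOp."""
    if isinstance(lhs, Constant) and isinstance(rhs, Constant):
        return Constant(lhs.value + rhs.value)
    return AddOp(lhs, rhs)


def static_size(value: Value) -> Optional[int]:
    """A consumer that sees a static size only if the value is directly a Constant."""
    return value.value if isinstance(value, Constant) else None


height, total_padding = Constant(112), Constant(2)
print(static_size(create(height, total_padding)))          # None: treated as dynamic
print(static_size(create_or_fold(height, total_padding)))  # 114: stays static
```

In this model, the unfolded `AddOp` hides the fact that its result is known, which mirrors how a result type can end up with `?` dimensions even though every input is static.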

For example
```
module {
  func.func @forward(%arg0: !torch.vtensor<[32,336,112,112],f32>,
                        %arg1: !torch.vtensor<[336,168,3,3],f32>, 
                        %arg2: !torch.vtensor<[336],f32>) 
                        -> !torch.vtensor<[32,336,56,56],f32> {
    %false = torch.constant.bool false
    %int2 = torch.constant.int 2
    %int1 = torch.constant.int 1
    %0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.prim.ListConstruct  : () -> !torch.list<int>
    %3 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %0, %0, %false, %2, %int2 
    : !torch.vtensor<[32,336,112,112],f32>, !torch.vtensor<[336,168,3,3],f32>, !torch.vtensor<[336],f32>, !torch.list<int>,
      !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int
   -> !torch.vtensor<[32,336,56,56],f32>
    return %3 : !torch.vtensor<[32,336,56,56],f32>
  }
}
```
would result in the following before this PR
```
[...]
  %padded = tensor.pad %2 low[%14, %15, %16, %17] high[%14, %15, %16, %17] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
    } : tensor<32x336x112x112xf32> to tensor<?x?x?x?xf32>
[...]
  %45 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
    ins(%expanded, %expanded_37 : tensor<?x2x?x?x?xf32>, tensor<2x168x168x3x3xf32>)
    outs(%expanded_44 : tensor<32x2x168x?x?xf32>) -> tensor<32x2x168x?x?xf32>
[...]
```
and with this PR all shapes are static.
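
As a cross-check of which static shapes one would expect for the example above, the arithmetic can be sketched in Python. This assumes the standard convolution output-size formula and reads the operands of `torch.aten.convolution` as stride `[2, 2]`, padding `[1, 1]`, dilation `[1, 1]` and 2 groups; the helper names `conv_out_dim` and `static_conv_shapes` are hypothetical and not part of torch-mlir.

```python
def conv_out_dim(size, kernel, stride, padding, dilation):
    """Standard convolution output size for one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def static_conv_shapes(input_shape, weight_shape, stride, padding, dilation, groups):
    """Compute the shapes that appear as '?' in the dynamically shaped IR."""
    n, c, h, w = input_shape
    f, c_per_group, kh, kw = weight_shape
    # Padding is applied symmetrically to the two spatial dimensions only.
    padded = (n, c, h + 2 * padding[0], w + 2 * padding[1])
    out_h = conv_out_dim(h, kh, stride[0], padding[0], dilation[0])
    out_w = conv_out_dim(w, kw, stride[1], padding[1], dilation[1])
    return {
        "padded": padded,
        # Grouped convolution splits the channel dimension into (groups, channels per group).
        "expanded_input": (n, groups, c // groups, padded[2], padded[3]),
        "expanded_weight": (groups, f // groups, c_per_group, kh, kw),
        "expanded_output": (n, groups, f // groups, out_h, out_w),
    }


shapes = static_conv_shapes(
    input_shape=(32, 336, 112, 112),
    weight_shape=(336, 168, 3, 3),
    stride=(2, 2),
    padding=(1, 1),
    dilation=(1, 1),
    groups=2,
)
for name, shape in shapes.items():
    print(name, "x".join(str(d) for d in shape))
```

Under these assumptions the output spatial size is `(112 + 2 - 2 - 1) // 2 + 1 = 56`, which agrees with the `[32,336,56,56]` result type of the function.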
2024-06-27 08:43:10 +02:00
| File | Last commit | Date |
|---|---|---|
| basic.mlir | [torch-mlir] bump stablehlo/llvm version (#3471) | 2024-06-18 16:59:53 -07:00 |
| elementwise.mlir | TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475) | 2024-06-27 08:43:10 +02:00 |
| gather.mlir | [torch-mlir] bump stablehlo/llvm version (#3471) | 2024-06-18 16:59:53 -07:00 |
| linear.mlir | TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475) | 2024-06-27 08:43:10 +02:00 |
| lit.local.cfg | [test] rename TorchToMhlo to TorchToStablehlo (#1995) | 2023-04-03 18:41:25 -07:00 |
| pooling.mlir | [Stablehlo] Bump stablehlo to ab92adeda9119a6c3914cd42367b0a2b70765e91 (#3285) | 2024-05-05 19:56:12 +08:00 |
| rng.mlir | [Stablehlo] fix aten.randn's lowering with f32 element type (#3329) | 2024-05-11 17:40:04 +08:00 |
| scatter.mlir | [torch-mlir] bump stablehlo/llvm version (#3471) | 2024-06-18 16:59:53 -07:00 |
| view_like.mlir | TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475) | 2024-06-27 08:43:10 +02:00 |