torch-mlir/lib/Dialect/TorchConversion/Transforms
Matthias Gehre 6678e1a256
TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475)
Before this PR, a statically shaped aten.convolution would generate
dynamically shaped linalg IR, and even `-canonicalize` would not be able
to fold it back into static shapes. This PR ensures that shape
calculations are folded on construction to directly generate statically
shaped linalg IR.

We achieve that by ensuring that `arith` ops involved in computing
shapes are created via `createOrFold`, so that later uses of
`getAsOpFoldResult` see constants instead of those ops.
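
As a minimal illustration of the mechanism (a sketch, not code from the patch; `addDims` is a hypothetical helper), building a shape computation with `createOrFold` runs the `arith.addi` folder at construction time, so `getAsOpFoldResult` can recover a static attribute:

```cpp
// Sketch only: hypothetical helper showing why createOrFold matters here.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static OpFoldResult addDims(OpBuilder &b, Location loc, Value lhs, Value rhs) {
  // b.create<arith::AddIOp>(...) would always materialize an addi op, so
  // getAsOpFoldResult would see a plain Value even for constant operands.
  // createOrFold invokes the op's folder on construction: if lhs and rhs
  // are arith.constant, `sum` is itself a constant.
  Value sum = b.createOrFold<arith::AddIOp>(loc, lhs, rhs);
  // When `sum` is defined by a constant op, this returns its IntegerAttr,
  // which downstream shape logic can treat as a static dimension.
  return getAsOpFoldResult(sum);
}
```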

For example, this module
```
module {
  func.func @forward(%arg0: !torch.vtensor<[32,336,112,112],f32>,
                     %arg1: !torch.vtensor<[336,168,3,3],f32>,
                     %arg2: !torch.vtensor<[336],f32>)
                     -> !torch.vtensor<[32,336,56,56],f32> {
    %false = torch.constant.bool false
    %int2 = torch.constant.int 2
    %int1 = torch.constant.int 1
    %0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.prim.ListConstruct : () -> !torch.list<int>
    %3 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %0, %0, %false, %2, %int2
        : !torch.vtensor<[32,336,112,112],f32>, !torch.vtensor<[336,168,3,3],f32>,
          !torch.vtensor<[336],f32>, !torch.list<int>, !torch.list<int>,
          !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int
        -> !torch.vtensor<[32,336,56,56],f32>
    return %3 : !torch.vtensor<[32,336,56,56],f32>
  }
}
```
would result in
```
[...]
  %padded = tensor.pad %2 low[%14, %15, %16, %17] high[%14, %15, %16, %17] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
    } : tensor<32x336x112x112xf32> to tensor<?x?x?x?xf32>
[...]
  %45 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
    ins(%expanded, %expanded_37 : tensor<?x2x?x?x?xf32>, tensor<2x168x168x3x3xf32>)
    outs(%expanded_44 : tensor<32x2x168x?x?xf32>) -> tensor<32x2x168x?x?xf32>
[...]
```
whereas with this PR the same module lowers with all shapes static.
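
For reference, the static shapes follow directly from the input: padding of 1 on each spatial side gives 112 + 2 = 114, and a 3x3, stride-2, dilation-1 window yields floor((114 - 3) / 2) + 1 = 56. A sketch of the expected static IR (SSA names carried over from the snippet above; not verbatim compiler output):
```
[...]
  %padded = tensor.pad %2 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
    } : tensor<32x336x112x112xf32> to tensor<32x336x114x114xf32>
[...]
  %45 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
    ins(%expanded, %expanded_37 : tensor<32x2x168x114x114xf32>, tensor<2x168x168x3x3xf32>)
    outs(%expanded_44 : tensor<32x2x168x56x56xf32>) -> tensor<32x2x168x56x56xf32>
[...]
```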
2024-06-27 08:43:10 +02:00
| File | Last commit | Date |
| --- | --- | --- |
| BackendTypeConversion.cpp | TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475) | 2024-06-27 08:43:10 +02:00 |
| BackendTypeConversionPasses.cpp | [Stablehlo] support uint8 (#3367) | 2024-06-04 09:04:59 +08:00 |
| CMakeLists.txt | [NFC reformat] Run pre-commit on all files and format misc. | 2024-04-27 14:08:09 -07:00 |
| ConvertCustomQuantOp.cpp | [NFC] Remove unused header files (#3386) | 2024-05-30 14:30:36 +08:00 |
| PassDetail.h | Migrate passes in TorchConversion to use FunctionOpInterface. (#2935) | 2024-02-20 08:54:02 -08:00 |
| Passes.cpp | [Stablehlo] support uint8 (#3367) | 2024-06-04 09:04:59 +08:00 |
| UnpackQuantTensor.cpp | Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3243) | 2024-04-27 14:00:56 -07:00 |
| VerifyLinalgOnTensorsBackendContract.cpp | [NFC] Remove unused header files (#3386) | 2024-05-30 14:30:36 +08:00 |
| VerifyStablehloBackendContract.cpp | [stablehlo] verify stablehlo backend contract (#3338) | 2024-05-16 11:03:43 +08:00 |
| VerifyTosaBackendContract.cpp | [NFC] Remove unused header files (#3386) | 2024-05-30 14:30:36 +08:00 |