torch-mlir/projects
John Wu 704cfdaf08
Add aten.pool_max3d support to torch-to-linalg (#2735)
Added verification logic to the abstract_interpreter_lib_gen.py

Also made some unit tests

Initially, I thought we can use `linalg::pooling_ndhwc_max` to help
implement this problem. However, on a 5-dimensional matrix it does the
pooling on dimensions (2, 3, 4) which is not what we want. We want
pooling on dimensions (3, 4, 5).

To achieve this, we would need to lower our code using the `linalg`
dialect.


Turns out the pooling code in `linalg` looks like this.

```
func @max_pooling_ncdhw(%I: memref<?x?x?x?x?xf32>, %K: memref<3xindex>, %O: memref<?x?x?x?x?xf32>,
                        %strides: memref<3xindex>, %dilations: memref<3xindex>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %N = memref.dim %I, %c0 : memref<?x?x?x?x?xf32>
    %C = memref.dim %I, %c1 : memref<?x?x?x?x?xf32>
    %D = memref.dim %I, 2 : memref<?x?x?x?x?xf32>
    %H = memref.dim %I, 3 : memref<?x?x?x?x?xf32>
    %W = memref.dim %I, 4 : memref<?x?x?x?x?xf32>

    %kernel_d = memref.load %K[%c0] : memref<3xindex>
    %kernel_h = memref.load %K[%c1] : memref<3xindex>
    %kernel_w = memref.load %K[2] : memref<3xindex>
    %stride_d = memref.load %strides[%c0] : memref<3xindex>
    %stride_h = memref.load %strides[%c1] : memref<3xindex>
    %stride_w = memref.load %strides[2] : memref<3xindex>
    %dilation_d = memref.load %dilations[%c0] : memref<3xindex>
    %dilation_h = memref.load %dilations[%c1] : memref<3xindex>
    %dilation_w = memref.load %dilations[2] : memref<3xindex>

    linalg.generic {
        indexing_maps = [
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d * %stride_d + kd * %dilation_d, h * %stride_h + kh * %dilation_h, w * %stride_w + kw * %dilation_w)>,  // Map for input tensor
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (kd, kh, kw)>,                                              // Map for kernel tensor
            affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d, h, w)>                                            // Map for output tensor
        ],
        iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"],
        doc = "3D Max Pooling NCDHW with Strides, Dilations, and Kernel Size"
    } ins(%I, %K : memref<?x?x?x?x?xf32>, memref<3xindex>) outs(%O : memref<?x?x?x?x?xf32>) {
        ^bb0(%input_elem: f32, %kernel_elem: index, %output_elem: f32):
            %max_val = arith.maxf %input_elem, %output_elem : f32
            linalg.yield %max_val : f32
    }
    return
}

```

This was implemented based on it's source code with the adjustments
mentioned above:

4ca1b5e094/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (L5647)

Issues related to this can be found here

https://github.com/nod-ai/SHARK-Turbine/issues/324
2024-01-19 21:09:46 +05:30
..
jit_ir_common Breakup python pytorch deps (#2582) 2023-11-19 12:10:19 -08:00
ltc [torch][quant] Support quantize and dequantize for torch (#2731) 2024-01-12 19:11:14 -08:00
onnx_c_importer [onnx] Add torch-mlir-import-onnx native port as an optional tool/library. (#2694) 2023-12-27 12:13:34 -08:00
pt1 Add aten.pool_max3d support to torch-to-linalg (#2735) 2024-01-19 21:09:46 +05:30
CMakeLists.txt [onnx] Add torch-mlir-import-onnx native port as an optional tool/library. (#2694) 2023-12-27 12:13:34 -08:00