mirror of https://github.com/llvm/torch-mlir
704cfdaf08
Added verification logic to `abstract_interpreter_lib_gen.py`.
Also added some unit tests.
Initially, I thought we could use `linalg::pooling_ndhwc_max` to
implement this. However, on a 5-dimensional tensor it pools over
dimensions (2, 3, 4), counting from 1 (D, H, W in the NDHWC layout),
which is not what we want. We want pooling over dimensions (3, 4, 5),
that is, D, H, W in the NCDHW layout.
To achieve this, we would need to lower our code using the `linalg`
dialect.
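To make the layout mismatch concrete, here is a small NumPy sketch. It is not part of this commit; the shapes and variable names are hypothetical and chosen only for illustration. Note that NumPy axes are 0-based, so the dimensions called (3, 4, 5) above are axes (2, 3, 4) here, and (2, 3, 4) become axes (1, 2, 3).

```python
import numpy as np

# Hypothetical small shapes, chosen only for illustration.
N, C, D, H, W = 1, 2, 4, 4, 4
x_ncdhw = np.arange(N * C * D * H * W, dtype=np.float32).reshape(N, C, D, H, W)

# The same data viewed in NDHWC (channels-last) layout.
x_ndhwc = np.transpose(x_ncdhw, (0, 2, 3, 4, 1))

# Global max over the spatial dims, written for each layout.
# 0-based axes: NCDHW -> (2, 3, 4); NDHWC -> (1, 2, 3).
pooled_ncdhw = x_ncdhw.max(axis=(2, 3, 4))  # shape (N, C)
pooled_ndhwc = x_ndhwc.max(axis=(1, 2, 3))  # shape (N, C)

# Applying the NDHWC axes to NCDHW data reduces over C, D, H instead,
# mixing channels and leaving W untouched.
wrong = x_ncdhw.max(axis=(1, 2, 3))  # shape (N, W)

print(pooled_ncdhw.shape, pooled_ndhwc.shape, wrong.shape)
```

Both layout-aware reductions agree, while reusing the NDHWC axes on NCDHW data yields a result of a different shape and meaning.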
It turns out the pooling code in `linalg` looks roughly like this:
```
func @max_pooling_ncdhw(%I: memref<?x?x?x?x?xf32>, %K: memref<3xindex>, %O: memref<?x?x?x?x?xf32>,
                        %strides: memref<3xindex>, %dilations: memref<3xindex>) {
  // Index constants: memref.dim and memref.load take SSA index operands.
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  // Input extents in NCDHW order.
  %N = memref.dim %I, %c0 : memref<?x?x?x?x?xf32>
  %C = memref.dim %I, %c1 : memref<?x?x?x?x?xf32>
  %D = memref.dim %I, %c2 : memref<?x?x?x?x?xf32>
  %H = memref.dim %I, %c3 : memref<?x?x?x?x?xf32>
  %W = memref.dim %I, %c4 : memref<?x?x?x?x?xf32>
  // Kernel size, stride and dilation for each of D, H, W.
  %kernel_d = memref.load %K[%c0] : memref<3xindex>
  %kernel_h = memref.load %K[%c1] : memref<3xindex>
  %kernel_w = memref.load %K[%c2] : memref<3xindex>
  %stride_d = memref.load %strides[%c0] : memref<3xindex>
  %stride_h = memref.load %strides[%c1] : memref<3xindex>
  %stride_w = memref.load %strides[%c2] : memref<3xindex>
  %dilation_d = memref.load %dilations[%c0] : memref<3xindex>
  %dilation_h = memref.load %dilations[%c1] : memref<3xindex>
  %dilation_w = memref.load %dilations[%c2] : memref<3xindex>
  // Illustrative pseudocode: an affine_map attribute cannot reference SSA
  // values such as %stride_d, so valid IR needs compile-time constant
  // strides and dilations in the input map.
  linalg.generic {
    indexing_maps = [
      affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d * %stride_d + kd * %dilation_d, h * %stride_h + kh * %dilation_h, w * %stride_w + kw * %dilation_w)>, // Map for input tensor
      affine_map<(n, c, d, h, w, kd, kh, kw) -> (kd, kh, kw)>, // Map for kernel tensor
      affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d, h, w)> // Map for output tensor
    ],
    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"],
    doc = "3D Max Pooling NCDHW with Strides, Dilations, and Kernel Size"
  } ins(%I, %K : memref<?x?x?x?x?xf32>, memref<3xindex>) outs(%O : memref<?x?x?x?x?xf32>) {
  ^bb0(%input_elem: f32, %kernel_elem: index, %output_elem: f32):
    // Running maximum over the kernel window; %kernel_elem is unused.
    %max_val = arith.maxf %input_elem, %output_elem : f32
    linalg.yield %max_val : f32
  }
  return
}
```
This was implemented based on its source code, with the adjustments
mentioned above.
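For readers less familiar with MLIR, the computation that the `linalg.generic` sketch describes can be cross-checked against a plain NumPy reference. The following is a minimal sketch and not code from this commit: the function name `max_pool3d_ncdhw` is hypothetical, and it assumes no padding and floor-mode output sizes. The input index `d * stride_d + kd * dilation_d` from the input indexing map appears here as a strided slice.

```python
import numpy as np

def max_pool3d_ncdhw(x, kernel, strides, dilations):
    """Reference 3D max pooling over an NCDHW array (assumes no padding)."""
    n, c, d, h, w = x.shape
    kd, kh, kw = kernel
    sd, sh, sw = strides
    dd, dh, dw = dilations
    # Output extent per spatial dim: (in - dilation * (k - 1) - 1) // stride + 1
    od = (d - dd * (kd - 1) - 1) // sd + 1
    oh = (h - dh * (kh - 1) - 1) // sh + 1
    ow = (w - dw * (kw - 1) - 1) // sw + 1
    # Start from -inf so the first comparison always takes the input value.
    out = np.full((n, c, od, oh, ow), -np.inf, dtype=x.dtype)
    # The three kernel loops play the role of the "reduction" iterators.
    for i in range(kd):
        for j in range(kh):
            for k in range(kw):
                # Input index = out_index * stride + kernel_index * dilation.
                window = x[:, :,
                           i * dd : i * dd + (od - 1) * sd + 1 : sd,
                           j * dh : j * dh + (oh - 1) * sh + 1 : sh,
                           k * dw : k * dw + (ow - 1) * sw + 1 : sw]
                out = np.maximum(out, window)
    return out

# Example: 2x2x2 kernel, stride 2, no dilation, on a 1x2x4x4x4 input.
x = np.arange(2 * 4 * 4 * 4, dtype=np.float32).reshape(1, 2, 4, 4, 4)
y = max_pool3d_ncdhw(x, (2, 2, 2), (2, 2, 2), (1, 1, 1))
print(y.shape)  # (1, 2, 2, 2, 2)
```

The batch and channel axes are carried through unchanged, which is the NCDHW behavior the commit is after. A result from the lowered op can be compared against this reference on small inputs.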