[ONNX] fix padding for `onnx.MaxPool` (#3611)

The saga of aligning onnx and torch padding conventions continues. 

```python
onnx_pads = [low_x, low_y, low_z, high_x, high_y, high_z]
torch_pads = [low_z, high_z, low_y, high_y, low_x, high_x]
```

So not only is the lexicographical ordering hierarchy swapped (low/high
x spatial-dim -> spatial-dim x low/high) but the ordering in the the
spatial-dim specification is also reversed.

This patch properly reverses the pad ordering (and actually uses the
`shuffledPadding` to pad).
pull/3617/head
zjgarvey 2024-08-07 20:34:00 -07:00 committed by GitHub
parent 6c33ab024e
commit 7f2a17e757
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 6 deletions

View File

@ -788,15 +788,14 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
llvm::SmallVector<int64_t> shuffledPadding(spatial * 2);
llvm::SmallVector<int64_t> paddedShape(operandTy.getSizes());
shuffledPadding.resize(2 * rank);
for (int i = 0; i < spatial; ++i) {
paddedShape[i + 2] += padding[i] + padding[i + spatial];
shuffledPadding[2 * i] = padding[i];
shuffledPadding[2 * i + 1] = padding[i + spatial];
shuffledPadding[2 * i] = padding[spatial - i - 1];
shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1];
}
Value shuffledPaddingList =
createConstantIntList(binder, rewriter, padding);
createConstantIntList(binder, rewriter, shuffledPadding);
Value zero;
if (isa<FloatType>(resultTypeOut.getDtype())) {
zero = rewriter.create<Torch::ConstantFloatOp>(

View File

@ -670,8 +670,8 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) ->
// CHECK-LABEL: func.func @test_maxpool_pad
func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
// CHECK: %[[INT1_1:.+]] = torch.constant.int 1
// CHECK: %[[INT2_0:.+]] = torch.constant.int 2
// CHECK: %[[INT1_1:.+]] = torch.constant.int 2
// CHECK: %[[INT2_0:.+]] = torch.constant.int 1
// CHECK: %[[INT2_1:.+]] = torch.constant.int 2
// CHECK: %[[PADI:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]], %[[INT2_0]], %[[INT2_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[MIN:.+]] = torch.constant.float -1.7976931348623157E+308