mirror of https://github.com/llvm/torch-mlir
[ONNX] fix padding for `onnx.MaxPool` (#3611)
The saga of aligning onnx and torch padding conventions continues. ```python onnx_pads = [low_x, low_y, low_z, high_x, high_y, high_z] torch_pads = [low_z, high_z, low_y, high_y, low_x, high_x] ``` So not only is the lexicographical ordering hierarchy swapped (low/high x spatial-dim -> spatial-dim x low/high), but the ordering within the spatial-dim specification is also reversed. This patch properly reverses the pad ordering (and actually uses the `shuffledPadding` to pad). pull/3617/head
parent
6c33ab024e
commit
7f2a17e757
|
@ -788,15 +788,14 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
|
||||
llvm::SmallVector<int64_t> shuffledPadding(spatial * 2);
|
||||
llvm::SmallVector<int64_t> paddedShape(operandTy.getSizes());
|
||||
shuffledPadding.resize(2 * rank);
|
||||
for (int i = 0; i < spatial; ++i) {
|
||||
paddedShape[i + 2] += padding[i] + padding[i + spatial];
|
||||
shuffledPadding[2 * i] = padding[i];
|
||||
shuffledPadding[2 * i + 1] = padding[i + spatial];
|
||||
shuffledPadding[2 * i] = padding[spatial - i - 1];
|
||||
shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1];
|
||||
}
|
||||
|
||||
Value shuffledPaddingList =
|
||||
createConstantIntList(binder, rewriter, padding);
|
||||
createConstantIntList(binder, rewriter, shuffledPadding);
|
||||
Value zero;
|
||||
if (isa<FloatType>(resultTypeOut.getDtype())) {
|
||||
zero = rewriter.create<Torch::ConstantFloatOp>(
|
||||
|
|
|
@ -670,8 +670,8 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) ->
|
|||
// CHECK-LABEL: func.func @test_maxpool_pad
|
||||
func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
|
||||
// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT1_1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT2_0:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[INT1_1:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[INT2_0:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT2_1:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[PADI:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]], %[[INT2_0]], %[[INT2_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[MIN:.+]] = torch.constant.float -1.7976931348623157E+308
|
||||
|
|
Loading…
Reference in New Issue