[onnx] Fix `pool` lowering for non-symmetric padding (#2837)

`torch` requires that padding be symmetric for pooling operations. To
support non-symmetric padding we need to materialize the pad as a
separate operation.
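For intuition, here is a minimal standalone sketch of that decision (plain C++ with illustrative values; this is not the rewriter code in the diff): ONNX pools carry `pads = [begin_0, ..., begin_{n-1}, end_0, ..., end_{n-1}]`, while torch pools take a single per-axis pad applied to both sides, so the pads can be forwarded only when each begin equals its matching end.

// Editor's sketch (standard C++ only, illustrative values): decide whether
// ONNX-style pads can be forwarded to a torch pool op or must be
// materialized as a separate pad operation.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // ONNX layout: [begin_0, ..., begin_{n-1}, end_0, ..., end_{n-1}].
  std::vector<int64_t> pads = {1, 1, 2, 2};
  size_t spatial = pads.size() / 2;

  bool symmetric = true;
  for (size_t i = 0; i < spatial; ++i)
    symmetric = symmetric && pads[i] == pads[i + spatial];

  if (symmetric)
    std::cout << "forward pads[0..spatial) to the torch pool op\n";
  else
    std::cout << "emit an explicit pad (lowest value for max pool), "
                 "then pool with padding = 0\n";
}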

---------

Co-authored-by: James Newling <james.newling@gmail.com>
Rob Suderman 2024-02-01 14:35:21 -08:00 committed by GitHub
parent c7d7d7f004
commit 29baa813bd
2 changed files with 128 additions and 26 deletions


@@ -262,30 +262,87 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
     if (!maybeRank)
       return rewriter.notifyMatchFailure(binder.op,
                                          "Unimplemented: unranked tensor");
-    unsigned rank = *maybeRank;
+    int64_t rank = *maybeRank;
+    int64_t spatial = rank - 2;
     SmallVector<int64_t> kernel, padding, strides, dilations;
     if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
       return rewriter.notifyMatchFailure(binder.op,
                                          "kernel_shape bind failure");
-    if (kernel.size() != rank - 2)
+    if (kernel.size() != static_cast<size_t>(spatial))
       return rewriter.notifyMatchFailure(
           binder.op, "kernel list size does not match the number of axes");
-    if (binder.s64IntegerArrayAttr(padding, "pads", {0}))
+    if (binder.s64IntegerArrayAttr(padding, "pads", {}))
       return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
-    if (padding.size() != 1 && padding.size() != 2 * (rank - 2))
+    if (!padding.empty() &&
+        padding.size() != static_cast<size_t>(2 * spatial))
       return rewriter.notifyMatchFailure(
           binder.op, "padding list must contain (begin,end) pair for each "
                      "spatial axis");
-    if (binder.s64IntegerArrayAttr(strides, "strides", {1}))
+    if (binder.s64IntegerArrayAttr(strides, "strides", {}))
       return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
-    if (strides.size() != 1 && strides.size() != rank - 2)
+    if (!strides.empty() && strides.size() != static_cast<size_t>(spatial))
       return rewriter.notifyMatchFailure(
           binder.op, "strides list size does not match the number of axes");
     if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
       return rewriter.notifyMatchFailure(binder.op,
                                          "dilations bind failure");
+    if (padding.empty())
+      padding.resize(spatial, 0);
+    if (strides.empty())
+      strides.resize(spatial, 1);
+    if (dilations.empty())
+      dilations.resize(spatial, 1);
+    // If the padding is symmetric we can push the padding operation to the
+    // torch operator.
+    if (padding.size() == static_cast<size_t>(2 * spatial)) {
+      bool equal = true;
+      for (int i = 0; i < spatial; ++i) {
+        equal = equal && (padding[i] == padding[i + spatial]);
+      }
+      if (equal)
+        padding.resize(spatial);
+    }
+    // Torch pool operators require equal padding on each side of each
+    // dimension so we materialize the padding behavior explicitly and set
+    // the padding to 0.
+    if (padding.size() == static_cast<size_t>(2 * spatial)) {
+      auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
+      llvm::SmallVector<int64_t> shuffledPadding(spatial * 2);
+      llvm::SmallVector<int64_t> paddedShape(operandTy.getSizes());
+      shuffledPadding.resize(2 * rank);
+      for (int i = 0; i < spatial; ++i) {
+        paddedShape[i + 2] += padding[i] + padding[i + spatial];
+        shuffledPadding[2 * i] = padding[i];
+        shuffledPadding[2 * i + 1] = padding[i + spatial];
+      }
+      Value shuffledPaddingList =
+          createConstantIntList(binder, rewriter, padding);
+      Value zero;
+      if (resultType.getDtype().isa<FloatType>()) {
+        zero = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(
+                std::numeric_limits<double>::lowest()));
+      } else if (resultType.getDtype().isa<IntegerType>()) {
+        zero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(
+                                 std::numeric_limits<int64_t>::lowest()));
+      }
+      auto paddedInputTy = rewriter.getType<Torch::ValueTensorType>(
+          paddedShape, operandTy.getDtype());
+      operand = rewriter.create<Torch::AtenConstantPadNdOp>(
+          binder.getLoc(), paddedInputTy, operand, shuffledPaddingList,
+          zero);
+      padding.clear();
+      padding.resize(spatial, 0);
+    }
     Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
     Value paddingList = createConstantIntList(binder, rewriter, padding);
     Value stridesList = createConstantIntList(binder, rewriter, strides);
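To make the shape bookkeeping in the branch above concrete, a small standalone sketch (plain C++, with values taken from the `test_maxpool_pad` case below; again not the rewriter code) of how `paddedShape` grows by begin+end on each spatial axis:

// Editor's sketch of the paddedShape / shuffledPadding bookkeeping above,
// using the values from the test_maxpool_pad case below.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> shape = {1, 64, 111, 111}; // NCHW input
  std::vector<int64_t> padding = {1, 1, 2, 2};    // begins {1,1}, ends {2,2}
  int64_t spatial = 2;

  std::vector<int64_t> shuffled(2 * spatial);
  for (int64_t i = 0; i < spatial; ++i) {
    shape[i + 2] += padding[i] + padding[i + spatial]; // 111 + 1 + 2 = 114
    shuffled[2 * i] = padding[i];               // begin of axis i
    shuffled[2 * i + 1] = padding[i + spatial]; // end of axis i
  }
  // shape is now {1, 64, 114, 114}, matching the constant_pad_nd result
  // type checked in the test below.
  for (int64_t d : shape)
    std::cout << d << ' ';
  std::cout << '\n';
}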


@@ -251,13 +251,17 @@ func.func @test_maxpool_2d_default(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !t
   // CHECK: %[[I2:.*]] = torch.constant.int 2
   // CHECK: %[[I2_1:.*]] = torch.constant.int 2
   // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[I0:.*]] = torch.constant.int 0
-  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[I1:.*]] = torch.constant.int 1
-  // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[I0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1_0]], %[[I1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_3:.*]] = torch.constant.int 1
+  // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[I1_2]], %[[I1_3]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
-  // CHECK: torch.aten.max_pool2d %arg0, %[[LIST22]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31,31],f32>
+  // CHECK: torch.aten.max_pool2d %arg0, %[[LIST22]], %[[LIST1]], %[[LIST0]], %[[LIST3]], %[[FALSE]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31,31],f32>
   %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32>
   return %0 : !torch.vtensor<[1,3,31,31],f32>
 }
@@ -269,12 +273,15 @@ func.func @test_maxpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.
   // CHECK: %[[I3:.*]] = torch.constant.int 3
   // CHECK: %[[I3_1:.*]] = torch.constant.int 3
   // CHECK: %[[LIST33:.*]] = torch.prim.ListConstruct %[[I3]], %[[I3_1]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[I0:.*]] = torch.constant.int 0
-  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[I0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[I2:.*]] = torch.constant.int 2
   // CHECK: %[[I2_1:.*]] = torch.constant.int 2
   // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[I1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[I1_0]], %[[I1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[TRUE:.*]] = torch.constant.bool true
   // CHECK: torch.aten.max_pool2d %arg0, %[[LIST33]], %[[LIST22]], %[[LIST0]], %[[LIST]], %[[TRUE]] : !torch.vtensor<[1,1,4,4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,2,2],f32>
   %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32>
@@ -289,11 +296,18 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) ->
   // CHECK: %[[I2_1:.*]] = torch.constant.int 2
   // CHECK: %[[I2_2:.*]] = torch.constant.int 2
   // CHECK: %[[LIST222:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]], %[[I2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[I0:.*]] = torch.constant.int 0
-  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[I1:.*]] = torch.constant.int 1
-  // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[I0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_2:.*]] = torch.constant.int 0
+  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[I0_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1_0]], %[[I1_1]], %[[I1_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I1_3:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_4:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_5:.*]] = torch.constant.int 1
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[I1_3]], %[[I1_4]], %[[I1_5]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
   // CHECK: torch.aten.max_pool3d %arg0, %[[LIST222]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31,31,31],f32>
   %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32>
@@ -303,21 +317,52 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) ->
 // -----
 // CHECK-LABEL: func.func @test_maxpool_pad
-func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
+func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
+  // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
+  // CHECK: %[[INT2_0:.+]] = torch.constant.int 2
+  // CHECK: %[[INT2_1:.+]] = torch.constant.int 2
+  // CHECK: %[[PADI:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]], %[[INT2_0]], %[[INT2_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[MIN:.+]] = torch.constant.float -1.7976931348623157E+308
+  // CHECK: %[[PADDED:.+]] = torch.aten.constant_pad_nd %arg0, %[[PADI]], %[[MIN]] : !torch.vtensor<[1,64,111,111],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[1,64,114,114],f32>
+  // CHECK: %[[INT3:.*]] = torch.constant.int 3
+  // CHECK: %[[INT3_0:.*]] = torch.constant.int 3
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[INT0:.*]] = torch.constant.int 0
+  // CHECK: %[[INT0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[INT2:.*]] = torch.constant.int 2
+  // CHECK: %[[INT2_4:.*]] = torch.constant.int 2
+  // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[INT1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[EMPTY_LIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %[[PADDED]], %[[LIST]], %[[LIST3]], %[[LIST2]], %[[EMPTY_LIST]], %[[FALSE]] : !torch.vtensor<[1,64,114,114],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56,56],f32>
+  // CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32>
+  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32>
+  return %0 : !torch.vtensor<[1,64,56,56],f32>
+}
+// -----
+// CHECK-LABEL: func.func @test_maxpool_symmetric_pad
+func.func @test_maxpool_symmetric_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
   // CHECK: %[[INT3:.*]] = torch.constant.int 3
   // CHECK: %[[INT3_0:.*]] = torch.constant.int 3
   // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[INT1:.*]] = torch.constant.int 1
   // CHECK: %[[INT1_1:.*]] = torch.constant.int 1
-  // CHECK: %[[INT1_2:.*]] = torch.constant.int 1
-  // CHECK: %[[INT1_3:.*]] = torch.constant.int 1
-  // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_1]], %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[INT2:.*]] = torch.constant.int 2
   // CHECK: %[[INT2_4:.*]] = torch.constant.int 2
   // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[EMPTY_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[INT1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[INT1_3:.*]] = torch.constant.int 1
+  // CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
-  // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %arg0, %[[LIST]], %[[LIST3]], %[[LIST2]], %[[EMPTY_LIST]], %[[FALSE]] : !torch.vtensor<[1,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56,56],f32>
+  // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %arg0, %[[LIST]], %[[LIST3]], %[[LIST2]], %[[DILATION]], %[[FALSE]] : !torch.vtensor<[1,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56,56],f32>
   // CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32>
   %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32>
   return %0 : !torch.vtensor<[1,64,56,56],f32>
 }
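As a sanity check on the `[1,64,56,56]` result types in both tests, the standard floor-mode pooling output-size arithmetic (an editor's standalone sketch, not part of the commit; with dilation 1 the effective kernel extent equals the kernel size):

// Editor's sketch: floor-mode pooling output size, dilation 1.
// out = floor((in + pad_begin + pad_end - kernel) / stride) + 1
#include <cstdint>
#include <iostream>

int64_t poolOut(int64_t in, int64_t padBegin, int64_t padEnd,
                int64_t kernel, int64_t stride) {
  return (in + padBegin + padEnd - kernel) / stride + 1;
}

int main() {
  // test_maxpool_symmetric_pad: 112 input, pads 1/1, kernel 3, stride 2.
  std::cout << poolOut(112, 1, 1, 3, 2) << '\n'; // 56
  // test_maxpool_pad: 111 input, pads 1/2, kernel 3, stride 2.
  std::cout << poolOut(111, 1, 2, 3, 2) << '\n'; // 56
}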