mirror of https://github.com/llvm/torch-mlir
[onnx] Fix `pool` lowering for non-symmetric padding (#2837)
`torch` requires that padding be symmetric for pooling operations. To support non-symmetric padding we need to materialize the padding as a separate operation.

---------

Co-authored-by: James Newling <james.newling@gmail.com>
parent c7d7d7f004
commit 29baa813bd
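Before reading the diff itself, the shape of the fix: ONNX `pads` are laid out as [begin_0, ..., begin_{s-1}, end_0, ..., end_{s-1}] for s spatial axes, while the torch pooling ops accept only one padding value per axis. The patch folds symmetric pads into the pooling op and materializes asymmetric pads as an explicit pad op first. A minimal standalone sketch of that decision logic (plain C++ with a hypothetical helper name, not the torch-mlir API):

#include <cassert>
#include <cstdint>
#include <vector>

// Returns the per-axis padding to hand to the torch pooling op. If the pads
// are asymmetric, `explicitPad` receives the full (begins..., ends...) list
// that must be materialized as a separate pad op (e.g. aten.constant_pad_nd)
// first, and the pooling op itself then gets zero padding.
std::vector<int64_t> splitPoolPadding(const std::vector<int64_t> &pads,
                                      int64_t spatial,
                                      std::vector<int64_t> &explicitPad) {
  assert(pads.size() == static_cast<size_t>(2 * spatial));
  bool symmetric = true;
  for (int64_t i = 0; i < spatial; ++i)
    symmetric = symmetric && pads[i] == pads[i + spatial];
  explicitPad.clear();
  if (symmetric)
    return {pads.begin(), pads.begin() + spatial};
  explicitPad = pads;
  return std::vector<int64_t>(spatial, 0);
}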
@@ -262,30 +262,87 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
     if (!maybeRank)
       return rewriter.notifyMatchFailure(binder.op,
                                          "Unimplemented: unranked tensor");
-    unsigned rank = *maybeRank;
+    int64_t rank = *maybeRank;
+    int64_t spatial = rank - 2;
+
     SmallVector<int64_t> kernel, padding, strides, dilations;
     if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
       return rewriter.notifyMatchFailure(binder.op,
                                          "kernel_shape bind failure");
-    if (kernel.size() != rank - 2)
+    if (kernel.size() != static_cast<size_t>(spatial))
       return rewriter.notifyMatchFailure(
           binder.op, "kernel list size does not match the number of axes");
-    if (binder.s64IntegerArrayAttr(padding, "pads", {0}))
+    if (binder.s64IntegerArrayAttr(padding, "pads", {}))
       return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
-    if (padding.size() != 1 && padding.size() != 2 * (rank - 2))
+    if (!padding.empty() &&
+        padding.size() != static_cast<size_t>(2 * spatial))
       return rewriter.notifyMatchFailure(
           binder.op, "padding list must contain (begin,end) pair for each "
                      "spatial axis");
-    if (binder.s64IntegerArrayAttr(strides, "strides", {1}))
+    if (binder.s64IntegerArrayAttr(strides, "strides", {}))
       return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
-    if (strides.size() != 1 && strides.size() != rank - 2)
+    if (!strides.empty() && strides.size() != static_cast<size_t>(spatial))
       return rewriter.notifyMatchFailure(
           binder.op, "strides list size does not match the number of axes");
     if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
       return rewriter.notifyMatchFailure(binder.op,
                                          "dilations bind failure");
+
+    if (padding.empty())
+      padding.resize(spatial, 0);
+    if (strides.empty())
+      strides.resize(spatial, 1);
+    if (dilations.empty())
+      dilations.resize(spatial, 1);
+
+    // If the padding is symmetric we can push the padding operation to the
+    // torch operator.
+    if (padding.size() == static_cast<size_t>(2 * spatial)) {
+      bool equal = true;
+      for (int i = 0; i < spatial; ++i) {
+        equal = equal && (padding[i] == padding[i + spatial]);
+      }
+      if (equal)
+        padding.resize(spatial);
+    }
+
+    // Torch pool operators require equal padding on each side of each
+    // dimension so we materialize the padding behavior explicitly and set
+    // the padding to 0.
+    if (padding.size() == static_cast<size_t>(2 * spatial)) {
+      auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
+      llvm::SmallVector<int64_t> shuffledPadding(spatial * 2);
+      llvm::SmallVector<int64_t> paddedShape(operandTy.getSizes());
+      shuffledPadding.resize(2 * rank);
+      for (int i = 0; i < spatial; ++i) {
+        paddedShape[i + 2] += padding[i] + padding[i + spatial];
+        shuffledPadding[2 * i] = padding[i];
+        shuffledPadding[2 * i + 1] = padding[i + spatial];
+      }
+
+      Value shuffledPaddingList =
+          createConstantIntList(binder, rewriter, padding);
+      Value zero;
+      if (resultType.getDtype().isa<FloatType>()) {
+        zero = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(
+                std::numeric_limits<double>::lowest()));
+      } else if (resultType.getDtype().isa<IntegerType>()) {
+        zero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(
+                                 std::numeric_limits<int64_t>::lowest()));
+      }
+
+      auto paddedInputTy = rewriter.getType<Torch::ValueTensorType>(
+          paddedShape, operandTy.getDtype());
+      operand = rewriter.create<Torch::AtenConstantPadNdOp>(
+          binder.getLoc(), paddedInputTy, operand, shuffledPaddingList,
+          zero);
+      padding.clear();
+      padding.resize(spatial, 0);
+    }
+
     Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
     Value paddingList = createConstantIntList(binder, rewriter, padding);
     Value stridesList = createConstantIntList(binder, rewriter, strides);
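A hand-evaluation of the new materialization loop above, using the numbers from the `test_maxpool_pad` case further down (standalone C++, not part of the patch). Note the fill value chosen above is `std::numeric_limits<double>::lowest()` (or the int64 equivalent), so padded cells can never win the max reduction.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // ONNX pads = [1, 1, 2, 2] on an NCHW input of 1x64x111x111.
  std::vector<int64_t> padding = {1, 1, 2, 2};   // [begins..., ends...]
  std::vector<int64_t> paddedShape = {1, 64, 111, 111};
  int64_t spatial = 2;
  std::vector<int64_t> shuffledPadding(2 * spatial);
  for (int i = 0; i < spatial; ++i) {
    paddedShape[i + 2] += padding[i] + padding[i + spatial];
    shuffledPadding[2 * i] = padding[i];
    shuffledPadding[2 * i + 1] = padding[i + spatial];
  }
  // paddedShape is now {1, 64, 114, 114}, matching the
  // !torch.vtensor<[1,64,114,114],f32> type in the test below;
  // shuffledPadding is {1, 2, 1, 2}.
  printf("%lld %lld\n", (long long)paddedShape[2], (long long)paddedShape[3]);
}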
@@ -251,13 +251,17 @@ func.func @test_maxpool_2d_default(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !t
   // CHECK: %[[I2:.*]] = torch.constant.int 2
   // CHECK: %[[I2_1:.*]] = torch.constant.int 2
   // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[I0:.*]] = torch.constant.int 0
-  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[I1:.*]] = torch.constant.int 1
-  // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[I0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1_0]], %[[I1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_3:.*]] = torch.constant.int 1
+  // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[I1_2]], %[[I1_3]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
-  // CHECK: torch.aten.max_pool2d %arg0, %[[LIST22]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31,31],f32>
+  // CHECK: torch.aten.max_pool2d %arg0, %[[LIST22]], %[[LIST1]], %[[LIST0]], %[[LIST3]], %[[FALSE]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31,31],f32>
   %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32>
   return %0 : !torch.vtensor<[1,3,31,31],f32>
 }
@@ -269,12 +273,15 @@ func.func @test_maxpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.
   // CHECK: %[[I3:.*]] = torch.constant.int 3
   // CHECK: %[[I3_1:.*]] = torch.constant.int 3
   // CHECK: %[[LIST33:.*]] = torch.prim.ListConstruct %[[I3]], %[[I3_1]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[I0:.*]] = torch.constant.int 0
-  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[I0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[I2:.*]] = torch.constant.int 2
   // CHECK: %[[I2_1:.*]] = torch.constant.int 2
   // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[I1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[I1_0]], %[[I1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[TRUE:.*]] = torch.constant.bool true
   // CHECK: torch.aten.max_pool2d %arg0, %[[LIST33]], %[[LIST22]], %[[LIST0]], %[[LIST]], %[[TRUE]] : !torch.vtensor<[1,1,4,4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,2,2],f32>
   %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32>
@@ -289,11 +296,18 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) ->
   // CHECK: %[[I2_1:.*]] = torch.constant.int 2
   // CHECK: %[[I2_2:.*]] = torch.constant.int 2
   // CHECK: %[[LIST222:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]], %[[I2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[I0:.*]] = torch.constant.int 0
-  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[I1:.*]] = torch.constant.int 1
-  // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[I0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[I0_2:.*]] = torch.constant.int 0
+  // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[I0_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1_0]], %[[I1_1]], %[[I1_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[I1_3:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_4:.*]] = torch.constant.int 1
+  // CHECK: %[[I1_5:.*]] = torch.constant.int 1
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[I1_3]], %[[I1_4]], %[[I1_5]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
   // CHECK: torch.aten.max_pool3d %arg0, %[[LIST222]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31,31,31],f32>
   %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32>
@@ -303,21 +317,52 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) ->
 // -----

 // CHECK-LABEL: func.func @test_maxpool_pad
-func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
+func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
+  // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
+  // CHECK: %[[INT2_0:.+]] = torch.constant.int 2
+  // CHECK: %[[INT2_1:.+]] = torch.constant.int 2
+  // CHECK: %[[PADI:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]], %[[INT2_0]], %[[INT2_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[MIN:.+]] = torch.constant.float -1.7976931348623157E+308
+  // CHECK: %[[PADDED:.+]] = torch.aten.constant_pad_nd %arg0, %[[PADI]], %[[MIN]] : !torch.vtensor<[1,64,111,111],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[1,64,114,114],f32>
   // CHECK: %[[INT3:.*]] = torch.constant.int 3
   // CHECK: %[[INT3_0:.*]] = torch.constant.int 3
   // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[INT0:.*]] = torch.constant.int 0
   // CHECK: %[[INT0_1:.*]] = torch.constant.int 0
   // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[INT2:.*]] = torch.constant.int 2
   // CHECK: %[[INT2_4:.*]] = torch.constant.int 2
   // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[INT1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[EMPTY_LIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %[[PADDED]], %[[LIST]], %[[LIST3]], %[[LIST2]], %[[EMPTY_LIST]], %[[FALSE]] : !torch.vtensor<[1,64,114,114],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56,56],f32>
   // CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32>
+  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32>
   return %0 : !torch.vtensor<[1,64,56,56],f32>
 }

 // -----

 // CHECK-LABEL: func.func @test_maxpool_symmetric_pad
 func.func @test_maxpool_symmetric_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
   // CHECK: %[[INT3:.*]] = torch.constant.int 3
   // CHECK: %[[INT3_0:.*]] = torch.constant.int 3
   // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[INT1:.*]] = torch.constant.int 1
   // CHECK: %[[INT1_1:.*]] = torch.constant.int 1
-  // CHECK: %[[INT1_2:.*]] = torch.constant.int 1
-  // CHECK: %[[INT1_3:.*]] = torch.constant.int 1
-  // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_1]], %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[INT2:.*]] = torch.constant.int 2
   // CHECK: %[[INT2_4:.*]] = torch.constant.int 2
   // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[EMPTY_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[INT1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[INT1_3:.*]] = torch.constant.int 1
+  // CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false
-  // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %arg0, %[[LIST]], %[[LIST3]], %[[LIST2]], %[[EMPTY_LIST]], %[[FALSE]] : !torch.vtensor<[1,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56,56],f32>
+  // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %arg0, %[[LIST]], %[[LIST3]], %[[LIST2]], %[[DILATION]], %[[FALSE]] : !torch.vtensor<[1,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56,56],f32>
   // CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32>
   %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32>
   return %0 : !torch.vtensor<[1,64,56,56],f32>
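As a sanity check on the expected result types in the two tests above: for ceil_mode = 0, ONNX MaxPool's output size per spatial axis is floor((in + pad_begin + pad_end - ((kernel - 1) * dilation + 1)) / stride) + 1. A quick standalone evaluation with the tests' numbers (plain C++, not part of the patch):

#include <cstdint>
#include <cstdio>

int64_t poolOutSize(int64_t in, int64_t padBegin, int64_t padEnd,
                    int64_t kernel, int64_t stride, int64_t dilation) {
  // Integer division floors for the non-negative values used here.
  return (in + padBegin + padEnd - ((kernel - 1) * dilation + 1)) / stride + 1;
}

int main() {
  // test_maxpool_pad: 111 input, pads (1, 2), kernel 3, stride 2 -> 56.
  printf("%lld\n", (long long)poolOutSize(111, 1, 2, 3, 2, 1));
  // test_maxpool_symmetric_pad: 112 input, pads (1, 1), kernel 3, stride 2 -> 56.
  printf("%lld\n", (long long)poolOutSize(112, 1, 1, 3, 2, 1));
}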