[Linalg] Bring back onnx AveragePool padding asymmetric support

pull/3455/head
AmosLewis 2024-06-13 03:42:06 +00:00
parent ae6f5e8251
commit 2f2dfb7e44
2 changed files with 92 additions and 90 deletions

View File

@ -441,17 +441,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
// Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…]
// Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all
// axes x.
int64_t paddingSizeHalf = padding.size() / 2;
for (int64_t i = 0; i < paddingSizeHalf; ++i) {
// Check if onnx padding attribute is symmetric.
if (padding[i] != padding[i + paddingSizeHalf])
return rewriter.notifyMatchFailure(
binder.op, "onnx padding attribute is not symmetric");
for (int64_t i : padding) {
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
for (int64_t i : strides) {
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(

View File

@ -641,7 +641,7 @@ public:
// Case1: AtenAvgPool1d/2dOp with countIncludePad=false support.
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
auto selfType = cast<RankedTensorType>(self.getType());
const int64_t selfRank = selfType.getRank();
unsigned selfRank = selfType.getRank();
int64_t wDim = toPositiveDim(-1, selfRank);
int64_t hDim = toPositiveDim(-2, selfRank);
Value inputHeight = getDimOp(rewriter, loc, self, hDim);
@ -657,86 +657,96 @@ public:
/*indexingMaps=*/indexingMapsAvg,
/*iteratorTypes=*/iteratorTypesAvg,
[&](OpBuilder &b, Location loc, ValueRange args) {
// The algorithm for computing the divisor with
// count_include_pad is manily based on pytorch
// implementation. The following code is comment
// with pytorch code.
// https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
Value indexOh =
b.create<linalg::IndexOp>(loc, /*value=*/dimH);
Value oh = castIndexToInt64(b, loc, indexOh);
Value indexOw =
b.create<linalg::IndexOp>(loc, /*value=*/dimW);
Value ow = castIndexToInt64(b, loc, indexOw);
// int64_t ih0 = oh * dH - padH;
Value dH = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(strideInts[0]));
Value padH = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(paddingInts[0]));
Value ohDH = b.create<arith::MulIOp>(loc, oh, dH);
Value ih0 = b.create<arith::SubIOp>(loc, ohDH, padH);
// int64_t iw0 = ow * dW - padW;
Value dW = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(strideInts[1]));
Value padW = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(paddingInts[1]));
Value owDW = b.create<arith::MulIOp>(loc, ow, dW);
Value iw0 = b.create<arith::SubIOp>(loc, owDW, padW);
// int64_t ih1 = std::min(ih0 + kH, input_height + padH);
Value ih = castIndexToInt64(b, loc, inputHeight);
Value ih0KH = b.create<arith::AddIOp>(
loc, ih0, kernelSizeIntValues[0]);
Value ihPadH = b.create<arith::AddIOp>(loc, ih, padH);
Value ih1 = b.create<arith::MinSIOp>(loc, ih0KH, ihPadH);
// int64_t iw1 = std::min(iw0 + kW, input_width + padW);
Value iw = castIndexToInt64(b, loc, inputWidth);
Value iw0KW = b.create<arith::AddIOp>(
loc, iw0, kernelSizeIntValues[1]);
Value iwPadW = b.create<arith::AddIOp>(loc, iw, padW);
Value iw1 = b.create<arith::MinSIOp>(loc, iw0KW, iwPadW);
// int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
Value ih1Ih0 = b.create<arith::SubIOp>(loc, ih1, ih0);
Value iw1Iw0 = b.create<arith::SubIOp>(loc, iw1, iw0);
Value poolSize =
b.create<arith::MulIOp>(loc, ih1Ih0, iw1Iw0);
// ih0 = std::max(ih0, 0);
Value cstZero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(0));
Value ih0Clamped =
b.create<arith::MaxSIOp>(loc, ih0, cstZero);
// iw0 = std::max(iw0, 0);
Value iw0Clamped =
b.create<arith::MaxSIOp>(loc, iw0, cstZero);
// ih1 = std::min(ih1, input_height);
Value ih1Clamped = b.create<arith::MinSIOp>(loc, ih1, ih);
// iw1 = std::min(iw1, input_width);
Value iw1Clamped = b.create<arith::MinSIOp>(loc, iw1, iw);
// if (divisor_override.has_value()) {
// divisor = divisor_override.value();
// } else {
// if(count_include_pad) {
// divisor = pool_size;
// } else {
// divisor = (ih1 - ih0) * (iw1 - iw0);
// }
// }
if (countIncludePad) {
divisor = convertScalarToDtype(b, loc, poolSize,
resultElementType);
if (!isa<Torch::NoneType>(
op.getDivisorOverride().getType())) {
// AtenAvgPool2/3dOp has an optional divisor_override
// attribute while AtenAvgPool1dOp does not.
divisor = adaptor.getDivisorOverride();
} else {
Value ih1_ih0 =
b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped);
Value iw1_iw0 =
b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped);
divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
}
// AtenAvgPool2/3dOp has an optional divisor_override
// attribute while AtenAvgPool1dOp does not.
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
if (!isa<Torch::NoneType>(
op.getDivisorOverride().getType()))
divisor = adaptor.getDivisorOverride();
// The algorithm for computing the divisor with
// count_include_pad is manily based on pytorch
// implementation. The following code is comment
// with pytorch code.
// https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
Value indexOh =
b.create<linalg::IndexOp>(loc, /*value=*/dimH);
Value oh = castIndexToInt64(b, loc, indexOh);
Value indexOw =
b.create<linalg::IndexOp>(loc, /*value=*/dimW);
Value ow = castIndexToInt64(b, loc, indexOw);
// int64_t ih0 = oh * dH - padH;
Value dH = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(strideInts[0]));
Value padH = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(paddingInts[0]));
Value ohDH = b.create<arith::MulIOp>(loc, oh, dH);
Value ih0 = b.create<arith::SubIOp>(loc, ohDH, padH);
// int64_t iw0 = ow * dW - padW;
Value dW = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(strideInts[1]));
Value padW = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(paddingInts[1]));
Value owDW = b.create<arith::MulIOp>(loc, ow, dW);
Value iw0 = b.create<arith::SubIOp>(loc, owDW, padW);
// onnx average pool may pass asymmetric padding,
// so modify the padding values to now represent high
// padding.
if (paddingInts.size() == 2 * (selfRank - 2)) {
padH = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(paddingInts[2]));
padW = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(paddingInts[3]));
}
// int64_t ih1 = std::min(ih0 + kH, input_height + padH);
Value ih = castIndexToInt64(b, loc, inputHeight);
Value ih0KH = b.create<arith::AddIOp>(
loc, ih0, kernelSizeIntValues[0]);
Value ihPadH = b.create<arith::AddIOp>(loc, ih, padH);
Value ih1 = b.create<arith::MinSIOp>(loc, ih0KH, ihPadH);
// int64_t iw1 = std::min(iw0 + kW, input_width + padW);
Value iw = castIndexToInt64(b, loc, inputWidth);
Value iw0KW = b.create<arith::AddIOp>(
loc, iw0, kernelSizeIntValues[1]);
Value iwPadW = b.create<arith::AddIOp>(loc, iw, padW);
Value iw1 = b.create<arith::MinSIOp>(loc, iw0KW, iwPadW);
// int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
Value ih1Ih0 = b.create<arith::SubIOp>(loc, ih1, ih0);
Value iw1Iw0 = b.create<arith::SubIOp>(loc, iw1, iw0);
Value poolSize =
b.create<arith::MulIOp>(loc, ih1Ih0, iw1Iw0);
// ih0 = std::max(ih0, 0);
Value cstZero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(0));
Value ih0Clamped =
b.create<arith::MaxSIOp>(loc, ih0, cstZero);
// iw0 = std::max(iw0, 0);
Value iw0Clamped =
b.create<arith::MaxSIOp>(loc, iw0, cstZero);
// ih1 = std::min(ih1, input_height);
Value ih1Clamped = b.create<arith::MinSIOp>(loc, ih1, ih);
// iw1 = std::min(iw1, input_width);
Value iw1Clamped = b.create<arith::MinSIOp>(loc, iw1, iw);
// if (divisor_override.has_value()) {
// divisor = divisor_override.value();
// } else {
// if(count_include_pad) {
// divisor = pool_size;
// } else {
// divisor = (ih1 - ih0) * (iw1 - iw0);
// }
// }
if (countIncludePad) {
divisor = convertScalarToDtype(b, loc, poolSize,
resultElementType);
} else {
Value ih1_ih0 = b.create<arith::SubIOp>(loc, ih1Clamped,
ih0Clamped);
Value iw1_iw0 = b.create<arith::SubIOp>(loc, iw1Clamped,
iw0Clamped);
divisor =
b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
}
}
divisor = convertScalarToDtype(b, loc, divisor,