[Linalg] Bring back onnx AveragePool padding asymmetric support

pull/3455/head
AmosLewis 2024-06-13 03:42:06 +00:00
parent ae6f5e8251
commit 2f2dfb7e44
2 changed files with 92 additions and 90 deletions

View File

@ -441,17 +441,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>( cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i))); binder.getLoc(), rewriter.getI64IntegerAttr(i)));
} }
// Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…] for (int64_t i : padding) {
// Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all
// axes x.
int64_t paddingSizeHalf = padding.size() / 2;
for (int64_t i = 0; i < paddingSizeHalf; ++i) {
// Check if onnx padding attribute is symmetric.
if (padding[i] != padding[i + paddingSizeHalf])
return rewriter.notifyMatchFailure(
binder.op, "onnx padding attribute is not symmetric");
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>( cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); binder.getLoc(), rewriter.getI64IntegerAttr(i)));
} }
for (int64_t i : strides) { for (int64_t i : strides) {
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>( cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(

View File

@ -641,7 +641,7 @@ public:
// Case1: AtenAvgPool1d/2dOp with countIncludePad=false support. // Case1: AtenAvgPool1d/2dOp with countIncludePad=false support.
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) { if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
auto selfType = cast<RankedTensorType>(self.getType()); auto selfType = cast<RankedTensorType>(self.getType());
const int64_t selfRank = selfType.getRank(); unsigned selfRank = selfType.getRank();
int64_t wDim = toPositiveDim(-1, selfRank); int64_t wDim = toPositiveDim(-1, selfRank);
int64_t hDim = toPositiveDim(-2, selfRank); int64_t hDim = toPositiveDim(-2, selfRank);
Value inputHeight = getDimOp(rewriter, loc, self, hDim); Value inputHeight = getDimOp(rewriter, loc, self, hDim);
@ -657,6 +657,12 @@ public:
/*indexingMaps=*/indexingMapsAvg, /*indexingMaps=*/indexingMapsAvg,
/*iteratorTypes=*/iteratorTypesAvg, /*iteratorTypes=*/iteratorTypesAvg,
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
if (!isa<Torch::NoneType>(
op.getDivisorOverride().getType())) {
// AtenAvgPool2/3dOp has an optional divisor_override
// attribute while AtenAvgPool1dOp does not.
divisor = adaptor.getDivisorOverride();
} else {
// The algorithm for computing the divisor with // The algorithm for computing the divisor with
// count_include_pad is manily based on pytorch // count_include_pad is manily based on pytorch
// implementation. The following code is comment // implementation. The following code is comment
@ -683,6 +689,15 @@ public:
loc, rewriter.getI64IntegerAttr(paddingInts[1])); loc, rewriter.getI64IntegerAttr(paddingInts[1]));
Value owDW = b.create<arith::MulIOp>(loc, ow, dW); Value owDW = b.create<arith::MulIOp>(loc, ow, dW);
Value iw0 = b.create<arith::SubIOp>(loc, owDW, padW); Value iw0 = b.create<arith::SubIOp>(loc, owDW, padW);
// onnx average pool may pass asymmetric padding,
// so modify the padding values to now represent high
// padding.
if (paddingInts.size() == 2 * (selfRank - 2)) {
padH = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(paddingInts[2]));
padW = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(paddingInts[3]));
}
// int64_t ih1 = std::min(ih0 + kH, input_height + padH); // int64_t ih1 = std::min(ih0 + kH, input_height + padH);
Value ih = castIndexToInt64(b, loc, inputHeight); Value ih = castIndexToInt64(b, loc, inputHeight);
Value ih0KH = b.create<arith::AddIOp>( Value ih0KH = b.create<arith::AddIOp>(
@ -725,18 +740,13 @@ public:
divisor = convertScalarToDtype(b, loc, poolSize, divisor = convertScalarToDtype(b, loc, poolSize,
resultElementType); resultElementType);
} else { } else {
Value ih1_ih0 = Value ih1_ih0 = b.create<arith::SubIOp>(loc, ih1Clamped,
b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped); ih0Clamped);
Value iw1_iw0 = Value iw1_iw0 = b.create<arith::SubIOp>(loc, iw1Clamped,
b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped); iw0Clamped);
divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0); divisor =
b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
} }
// AtenAvgPool2/3dOp has an optional divisor_override
// attribute while AtenAvgPool1dOp does not.
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
if (!isa<Torch::NoneType>(
op.getDivisorOverride().getType()))
divisor = adaptor.getDivisorOverride();
} }
divisor = convertScalarToDtype(b, loc, divisor, divisor = convertScalarToDtype(b, loc, divisor,