From 2f2dfb7e44d1bb31420ccf52784d7b5ab7b5e6f2 Mon Sep 17 00:00:00 2001
From: AmosLewis
Date: Thu, 13 Jun 2024 03:42:06 +0000
Subject: [PATCH] [Linalg] Bring back onnx AveragePool padding asymmetric support

---
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp |  12 +-
 lib/Conversion/TorchToLinalg/Pooling.cpp   | 170 +++++++++---------
 2 files changed, 92 insertions(+), 90 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index dcb28129a..d0ff6e973 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -441,17 +441,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
                 binder.getLoc(), rewriter.getI64IntegerAttr(i)));
           }
-          // Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…]
-          // Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all
-          // axes x.
-          int64_t paddingSizeHalf = padding.size() / 2;
-          for (int64_t i = 0; i < paddingSizeHalf; ++i) {
-            // Check if onnx padding attribute is symmetric.
-            if (padding[i] != padding[i + paddingSizeHalf])
-              return rewriter.notifyMatchFailure(
-                  binder.op, "onnx padding attribute is not symmetric");
+          for (int64_t i : padding) {
             cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
+                binder.getLoc(), rewriter.getI64IntegerAttr(i)));
           }
           for (int64_t i : strides) {
             cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
                 binder.getLoc(), rewriter.getI64IntegerAttr(i)));
diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index d80f3d427..caa8d9a7f 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -641,7 +641,7 @@ public:
     // Case1: AtenAvgPool1d/2dOp with countIncludePad=false support.
     if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
       auto selfType = cast<RankedTensorType>(self.getType());
-      const int64_t selfRank = selfType.getRank();
+      unsigned selfRank = selfType.getRank();
       int64_t wDim = toPositiveDim(-1, selfRank);
       int64_t hDim = toPositiveDim(-2, selfRank);
       Value inputHeight = getDimOp(rewriter, loc, self, hDim);
       Value inputWidth = getDimOp(rewriter, loc, self, wDim);
@@ -657,86 +657,96 @@ public:
                  /*indexingMaps=*/indexingMapsAvg,
                  /*iteratorTypes=*/iteratorTypesAvg,
                  [&](OpBuilder &b, Location loc, ValueRange args) {
-                    // The algorithm for computing the divisor with
-                    // count_include_pad is manily based on pytorch
-                    // implementation. The following code is comment
-                    // with pytorch code.
-                    // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
-                    Value indexOh =
-                        b.create<linalg::IndexOp>(loc, /*value=*/dimH);
-                    Value oh = castIndexToInt64(b, loc, indexOh);
-                    Value indexOw =
-                        b.create<linalg::IndexOp>(loc, /*value=*/dimW);
-                    Value ow = castIndexToInt64(b, loc, indexOw);
-
-                    // int64_t ih0 = oh * dH - padH;
-                    Value dH = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(strideInts[0]));
-                    Value padH = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(paddingInts[0]));
-                    Value ohDH = b.create<arith::MulIOp>(loc, oh, dH);
-                    Value ih0 = b.create<arith::SubIOp>(loc, ohDH, padH);
-                    // int64_t iw0 = ow * dW - padW;
-                    Value dW = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(strideInts[1]));
-                    Value padW = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(paddingInts[1]));
-                    Value owDW = b.create<arith::MulIOp>(loc, ow, dW);
-                    Value iw0 = b.create<arith::SubIOp>(loc, owDW, padW);
-                    // int64_t ih1 = std::min(ih0 + kH, input_height + padH);
-                    Value ih = castIndexToInt64(b, loc, inputHeight);
-                    Value ih0KH = b.create<arith::AddIOp>(
-                        loc, ih0, kernelSizeIntValues[0]);
-                    Value ihPadH = b.create<arith::AddIOp>(loc, ih, padH);
-                    Value ih1 = b.create<arith::MinSIOp>(loc, ih0KH, ihPadH);
-                    // int64_t iw1 = std::min(iw0 + kW, input_width + padW);
-                    Value iw = castIndexToInt64(b, loc, inputWidth);
-                    Value iw0KW = b.create<arith::AddIOp>(
-                        loc, iw0, kernelSizeIntValues[1]);
-                    Value iwPadW = b.create<arith::AddIOp>(loc, iw, padW);
-                    Value iw1 = b.create<arith::MinSIOp>(loc, iw0KW, iwPadW);
-                    // int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
-                    Value ih1Ih0 = b.create<arith::SubIOp>(loc, ih1, ih0);
-                    Value iw1Iw0 = b.create<arith::SubIOp>(loc, iw1, iw0);
-                    Value poolSize =
-                        b.create<arith::MulIOp>(loc, ih1Ih0, iw1Iw0);
-                    // ih0 = std::max(ih0, 0);
-                    Value cstZero = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(0));
-                    Value ih0Clamped =
-                        b.create<arith::MaxSIOp>(loc, ih0, cstZero);
-                    // iw0 = std::max(iw0, 0);
-                    Value iw0Clamped =
-                        b.create<arith::MaxSIOp>(loc, iw0, cstZero);
-                    // ih1 = std::min(ih1, input_height);
-                    Value ih1Clamped = b.create<arith::MinSIOp>(loc, ih1, ih);
-                    // iw1 = std::min(iw1, input_width);
-                    Value iw1Clamped = b.create<arith::MinSIOp>(loc, iw1, iw);
-                    // if (divisor_override.has_value()) {
-                    //   divisor = divisor_override.value();
-                    // } else {
-                    //   if(count_include_pad) {
-                    //     divisor = pool_size;
-                    //   } else {
-                    //     divisor = (ih1 - ih0) * (iw1 - iw0);
-                    //   }
-                    // }
-                    if (countIncludePad) {
-                      divisor = convertScalarToDtype(b, loc, poolSize,
-                                                     resultElementType);
+                    if (!isa<Torch::NoneType>(
+                            op.getDivisorOverride().getType())) {
+                      // AtenAvgPool2/3dOp has an optional divisor_override
+                      // attribute while AtenAvgPool1dOp does not.
+                      divisor = adaptor.getDivisorOverride();
                     } else {
-                      Value ih1_ih0 =
-                          b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped);
-                      Value iw1_iw0 =
-                          b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped);
-                      divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
-                    }
-                    // AtenAvgPool2/3dOp has an optional divisor_override
-                    // attribute while AtenAvgPool1dOp does not.
-                    if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
-                      if (!isa<Torch::NoneType>(
-                              op.getDivisorOverride().getType()))
-                        divisor = adaptor.getDivisorOverride();
+                      // The algorithm for computing the divisor with
+                      // count_include_pad is mainly based on the pytorch
+                      // implementation. The following code is commented
+                      // with the corresponding pytorch code.
+                      // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
+                      Value indexOh =
+                          b.create<linalg::IndexOp>(loc, /*value=*/dimH);
+                      Value oh = castIndexToInt64(b, loc, indexOh);
+                      Value indexOw =
+                          b.create<linalg::IndexOp>(loc, /*value=*/dimW);
+                      Value ow = castIndexToInt64(b, loc, indexOw);
+
+                      // int64_t ih0 = oh * dH - padH;
+                      Value dH = rewriter.create<arith::ConstantOp>(
+                          loc, rewriter.getI64IntegerAttr(strideInts[0]));
+                      Value padH = rewriter.create<arith::ConstantOp>(
+                          loc, rewriter.getI64IntegerAttr(paddingInts[0]));
+                      Value ohDH = b.create<arith::MulIOp>(loc, oh, dH);
+                      Value ih0 = b.create<arith::SubIOp>(loc, ohDH, padH);
+                      // int64_t iw0 = ow * dW - padW;
+                      Value dW = rewriter.create<arith::ConstantOp>(
+                          loc, rewriter.getI64IntegerAttr(strideInts[1]));
+                      Value padW = rewriter.create<arith::ConstantOp>(
+                          loc, rewriter.getI64IntegerAttr(paddingInts[1]));
+                      Value owDW = b.create<arith::MulIOp>(loc, ow, dW);
+                      Value iw0 = b.create<arith::SubIOp>(loc, owDW, padW);
+                      // Onnx average pool may pass asymmetric padding, so
+                      // overwrite padH/padW with the high (end) padding
+                      // values before computing the window upper bounds.
+                      if (paddingInts.size() == 2 * (selfRank - 2)) {
+                        padH = rewriter.create<arith::ConstantOp>(
+                            loc, rewriter.getI64IntegerAttr(paddingInts[2]));
+                        padW = rewriter.create<arith::ConstantOp>(
+                            loc, rewriter.getI64IntegerAttr(paddingInts[3]));
+                      }
+                      // int64_t ih1 = std::min(ih0 + kH, input_height + padH);
+                      Value ih = castIndexToInt64(b, loc, inputHeight);
+                      Value ih0KH = b.create<arith::AddIOp>(
+                          loc, ih0, kernelSizeIntValues[0]);
+                      Value ihPadH = b.create<arith::AddIOp>(loc, ih, padH);
+                      Value ih1 = b.create<arith::MinSIOp>(loc, ih0KH, ihPadH);
+                      // int64_t iw1 = std::min(iw0 + kW, input_width + padW);
+                      Value iw = castIndexToInt64(b, loc, inputWidth);
+                      Value iw0KW = b.create<arith::AddIOp>(
+                          loc, iw0, kernelSizeIntValues[1]);
+                      Value iwPadW = b.create<arith::AddIOp>(loc, iw, padW);
+                      Value iw1 = b.create<arith::MinSIOp>(loc, iw0KW, iwPadW);
+                      // int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
+                      Value ih1Ih0 = b.create<arith::SubIOp>(loc, ih1, ih0);
+                      Value iw1Iw0 = b.create<arith::SubIOp>(loc, iw1, iw0);
+                      Value poolSize =
+                          b.create<arith::MulIOp>(loc, ih1Ih0, iw1Iw0);
+                      // ih0 = std::max(ih0, 0);
+                      Value cstZero = rewriter.create<arith::ConstantOp>(
+                          loc, rewriter.getI64IntegerAttr(0));
+                      Value ih0Clamped =
+                          b.create<arith::MaxSIOp>(loc, ih0, cstZero);
+                      // iw0 = std::max(iw0, 0);
+                      Value iw0Clamped =
+                          b.create<arith::MaxSIOp>(loc, iw0, cstZero);
+                      // ih1 = std::min(ih1, input_height);
+                      Value ih1Clamped = b.create<arith::MinSIOp>(loc, ih1, ih);
+                      // iw1 = std::min(iw1, input_width);
+                      Value iw1Clamped = b.create<arith::MinSIOp>(loc, iw1, iw);
+                      // if (divisor_override.has_value()) {
+                      //   divisor = divisor_override.value();
+                      // } else {
+                      //   if(count_include_pad) {
+                      //     divisor = pool_size;
+                      //   } else {
+                      //     divisor = (ih1 - ih0) * (iw1 - iw0);
+                      //   }
+                      // }
+                      if (countIncludePad) {
+                        divisor = convertScalarToDtype(b, loc, poolSize,
+                                                       resultElementType);
+                      } else {
+                        Value ih1_ih0 = b.create<arith::SubIOp>(loc, ih1Clamped,
+                                                                ih0Clamped);
+                        Value iw1_iw0 = b.create<arith::SubIOp>(loc, iw1Clamped,
+                                                                iw0Clamped);
+                        divisor =
+                            b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
+                      }
+                    }
                     divisor = convertScalarToDtype(b, loc, divisor,