[MLIR][Torch] Fix OnnxToLinalg lowering for AvgPool op (#3076)

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/3098/head
Vivek Khandelwal 2024-04-01 22:14:14 +05:30 committed by GitHub
parent 282e9b0e64
commit 6844c84702
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 14 deletions

View File

@ -310,7 +310,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return success();
});
patterns.onOp(
"AveragePool", 19,
"AveragePool", 11,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
std::string autoPad;
SmallVector<int64_t> dilation;
@ -361,7 +361,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op,
"padding list size does not match twice the number of axes");
}
if (binder.s64IntegerArrayAttr(strides, "strides", {1})) {
if (binder.s64IntegerArrayAttr(
strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1))) {
return failure();
}
if (strides.size() != 1 && strides.size() != rank - 2) {

View File

@ -114,8 +114,22 @@ static Value padInputTensor(Operation *op, ConversionPatternRewriter &rewriter,
SmallVectorImpl<int64_t> &paddingInts,
Value initValue) {
SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
SmallVector<int64_t> highPaddingIncludingNC = {0, 0};
unsigned selfRank = self.getType().cast<RankedTensorType>().getRank();
unsigned paddingIntsSize = paddingInts.size();
if (paddingIntsSize == 2 * (selfRank - 2)) {
// This condition being true means that the `paddingInts` contain seperate
// values for low padding and high padding.
for (unsigned i = 0; i < paddingIntsSize / 2; i++)
lowPaddingIncludingNC.push_back(paddingInts[i]);
for (unsigned i = paddingIntsSize / 2; i < paddingIntsSize; i++)
highPaddingIncludingNC.push_back(paddingInts[i]);
} else {
lowPaddingIncludingNC.append(paddingInts);
SmallVector<int64_t> highPaddingIncludingNC = lowPaddingIncludingNC;
highPaddingIncludingNC = lowPaddingIncludingNC;
}
if (ceilMode) {
for (int64_t i = 0; i < dimensionality; ++i) {

View File

@ -1908,18 +1908,8 @@ ONNX_XFAIL_SET = {
"LinalgNormModule_basic",
# Failure - onnx_lowering: onnx.AveragePool
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"AvgPool1dFloatModule_basic",
"AvgPool1dIntModule_basic",
"AvgPool1dStaticModule_basic",
"AvgPool2dCeilModeTrueModule_basic",
"AvgPool2dDivisorOverrideModule_basic",
"AvgPool2dFloatModule_basic",
"AvgPool2dIntModule_basic",
"AvgPool2dStaticModule_basic",
# Failure - onnx_lowering: onnx.Cast
"BucketizeTensorOutInt32RightModule_basic",