mirror of https://github.com/llvm/torch-mlir
[MLIR][Torch] Fix OnnxToLinalg lowering for AvgPool op (#3076)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3098/head
parent
282e9b0e64
commit
6844c84702
|
@ -310,7 +310,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"AveragePool", 19,
|
||||
"AveragePool", 11,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
std::string autoPad;
|
||||
SmallVector<int64_t> dilation;
|
||||
|
@ -361,7 +361,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.op,
|
||||
"padding list size does not match twice the number of axes");
|
||||
}
|
||||
if (binder.s64IntegerArrayAttr(strides, "strides", {1})) {
|
||||
if (binder.s64IntegerArrayAttr(
|
||||
strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1))) {
|
||||
return failure();
|
||||
}
|
||||
if (strides.size() != 1 && strides.size() != rank - 2) {
|
||||
|
|
|
@ -114,8 +114,22 @@ static Value padInputTensor(Operation *op, ConversionPatternRewriter &rewriter,
|
|||
SmallVectorImpl<int64_t> &paddingInts,
|
||||
Value initValue) {
|
||||
SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
|
||||
lowPaddingIncludingNC.append(paddingInts);
|
||||
SmallVector<int64_t> highPaddingIncludingNC = lowPaddingIncludingNC;
|
||||
SmallVector<int64_t> highPaddingIncludingNC = {0, 0};
|
||||
|
||||
unsigned selfRank = self.getType().cast<RankedTensorType>().getRank();
|
||||
unsigned paddingIntsSize = paddingInts.size();
|
||||
|
||||
if (paddingIntsSize == 2 * (selfRank - 2)) {
|
||||
// This condition being true means that the `paddingInts` contain seperate
|
||||
// values for low padding and high padding.
|
||||
for (unsigned i = 0; i < paddingIntsSize / 2; i++)
|
||||
lowPaddingIncludingNC.push_back(paddingInts[i]);
|
||||
for (unsigned i = paddingIntsSize / 2; i < paddingIntsSize; i++)
|
||||
highPaddingIncludingNC.push_back(paddingInts[i]);
|
||||
} else {
|
||||
lowPaddingIncludingNC.append(paddingInts);
|
||||
highPaddingIncludingNC = lowPaddingIncludingNC;
|
||||
}
|
||||
|
||||
if (ceilMode) {
|
||||
for (int64_t i = 0; i < dimensionality; ++i) {
|
||||
|
|
|
@ -1908,18 +1908,8 @@ ONNX_XFAIL_SET = {
|
|||
"LinalgNormModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.AveragePool
|
||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||
"AvgPool1dFloatModule_basic",
|
||||
"AvgPool1dIntModule_basic",
|
||||
"AvgPool1dStaticModule_basic",
|
||||
"AvgPool2dCeilModeTrueModule_basic",
|
||||
"AvgPool2dDivisorOverrideModule_basic",
|
||||
"AvgPool2dFloatModule_basic",
|
||||
"AvgPool2dIntModule_basic",
|
||||
"AvgPool2dStaticModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.Cast
|
||||
"BucketizeTensorOutInt32RightModule_basic",
|
||||
|
|
Loading…
Reference in New Issue