mirror of https://github.com/llvm/torch-mlir
[MLIR][Torch] Fix OnnxToLinalg lowering for AvgPool op (#3076)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3098/head
parent
282e9b0e64
commit
6844c84702
|
@ -310,7 +310,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
"AveragePool", 19,
|
"AveragePool", 11,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
std::string autoPad;
|
std::string autoPad;
|
||||||
SmallVector<int64_t> dilation;
|
SmallVector<int64_t> dilation;
|
||||||
|
@ -361,7 +361,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.op,
|
binder.op,
|
||||||
"padding list size does not match twice the number of axes");
|
"padding list size does not match twice the number of axes");
|
||||||
}
|
}
|
||||||
if (binder.s64IntegerArrayAttr(strides, "strides", {1})) {
|
if (binder.s64IntegerArrayAttr(
|
||||||
|
strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1))) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
if (strides.size() != 1 && strides.size() != rank - 2) {
|
if (strides.size() != 1 && strides.size() != rank - 2) {
|
||||||
|
|
|
@ -114,8 +114,22 @@ static Value padInputTensor(Operation *op, ConversionPatternRewriter &rewriter,
|
||||||
SmallVectorImpl<int64_t> &paddingInts,
|
SmallVectorImpl<int64_t> &paddingInts,
|
||||||
Value initValue) {
|
Value initValue) {
|
||||||
SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
|
SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
|
||||||
|
SmallVector<int64_t> highPaddingIncludingNC = {0, 0};
|
||||||
|
|
||||||
|
unsigned selfRank = self.getType().cast<RankedTensorType>().getRank();
|
||||||
|
unsigned paddingIntsSize = paddingInts.size();
|
||||||
|
|
||||||
|
if (paddingIntsSize == 2 * (selfRank - 2)) {
|
||||||
|
// This condition being true means that the `paddingInts` contain seperate
|
||||||
|
// values for low padding and high padding.
|
||||||
|
for (unsigned i = 0; i < paddingIntsSize / 2; i++)
|
||||||
|
lowPaddingIncludingNC.push_back(paddingInts[i]);
|
||||||
|
for (unsigned i = paddingIntsSize / 2; i < paddingIntsSize; i++)
|
||||||
|
highPaddingIncludingNC.push_back(paddingInts[i]);
|
||||||
|
} else {
|
||||||
lowPaddingIncludingNC.append(paddingInts);
|
lowPaddingIncludingNC.append(paddingInts);
|
||||||
SmallVector<int64_t> highPaddingIncludingNC = lowPaddingIncludingNC;
|
highPaddingIncludingNC = lowPaddingIncludingNC;
|
||||||
|
}
|
||||||
|
|
||||||
if (ceilMode) {
|
if (ceilMode) {
|
||||||
for (int64_t i = 0; i < dimensionality; ++i) {
|
for (int64_t i = 0; i < dimensionality; ++i) {
|
||||||
|
|
|
@ -1908,18 +1908,8 @@ ONNX_XFAIL_SET = {
|
||||||
"LinalgNormModule_basic",
|
"LinalgNormModule_basic",
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.AveragePool
|
# Failure - onnx_lowering: onnx.AveragePool
|
||||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
|
||||||
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
|
||||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
|
||||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||||
"AvgPool1dFloatModule_basic",
|
|
||||||
"AvgPool1dIntModule_basic",
|
|
||||||
"AvgPool1dStaticModule_basic",
|
|
||||||
"AvgPool2dCeilModeTrueModule_basic",
|
|
||||||
"AvgPool2dDivisorOverrideModule_basic",
|
"AvgPool2dDivisorOverrideModule_basic",
|
||||||
"AvgPool2dFloatModule_basic",
|
|
||||||
"AvgPool2dIntModule_basic",
|
|
||||||
"AvgPool2dStaticModule_basic",
|
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.Cast
|
# Failure - onnx_lowering: onnx.Cast
|
||||||
"BucketizeTensorOutInt32RightModule_basic",
|
"BucketizeTensorOutInt32RightModule_basic",
|
||||||
|
|
Loading…
Reference in New Issue