diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index d0ff6e973..dcb28129a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -441,9 +441,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( cstKernel.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } - for (int64_t i : padding) { + // Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…] + // Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all + // axes x. + int64_t paddingSizeHalf = padding.size() / 2; + for (int64_t i = 0; i < paddingSizeHalf; ++i) { + // Check if onnx padding attribute is symmetric. + if (padding[i] != padding[i + paddingSizeHalf]) + return rewriter.notifyMatchFailure( + binder.op, "onnx padding attribute is not symmetric"); cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); } for (int64_t i : strides) { cstStrides.push_back(rewriter.create( diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 36fa9dc56..d80f3d427 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -619,13 +619,6 @@ public: return rewriter.notifyMatchFailure( op, "count_include_pad must be a constant"); - // If the padding is zero then there is no padding to include. - if (!countIncludePad && - !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) { - return rewriter.notifyMatchFailure( - op, "unimplemented: count_include_pad is expected to be true"); - } - // `sumPool` contains the result of sumpool operation over the input. Value sumPool, paddedInput; SmallVector outTensorShape; @@ -635,9 +628,142 @@ public: paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput, sumPool))) return rewriter.notifyMatchFailure(op, "unable to compute sumpool"); - // } - Value divisor = kernelSizeIntValues[0]; + // Compute the average of sumPool. + Value outputTensor = rewriter.create( + loc, getAsOpFoldResult(outTensorShape), resultElementType); + SmallVector indexingMapsAvg( + 2, rewriter.getMultiDimIdentityMap(Dim + 2)); + SmallVector iteratorTypesAvg( + Dim + 2, utils::IteratorType::parallel); + Value avgPool; + Value divisor; + // Case1: AtenAvgPool1d/2dOp with countIncludePad=false support. + if constexpr (std::is_same()) { + auto selfType = cast(self.getType()); + const int64_t selfRank = selfType.getRank(); + int64_t wDim = toPositiveDim(-1, selfRank); + int64_t hDim = toPositiveDim(-2, selfRank); + Value inputHeight = getDimOp(rewriter, loc, self, hDim); + Value inputWidth = getDimOp(rewriter, loc, self, wDim); + RankedTensorType sumPoolType = cast(sumPool.getType()); + const int64_t rank = sumPoolType.getRank(); + int dimH = toPositiveDim(-2, rank); + int dimW = toPositiveDim(-1, rank); + avgPool = + rewriter + .create( + loc, outputTensor.getType(), sumPool, outputTensor, + /*indexingMaps=*/indexingMapsAvg, + /*iteratorTypes=*/iteratorTypesAvg, + [&](OpBuilder &b, Location loc, ValueRange args) { + // The algorithm for computing the divisor with + // count_include_pad is manily based on pytorch + // implementation. The following code is comment + // with pytorch code. + // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78 + Value indexOh = + b.create(loc, /*value=*/dimH); + Value oh = castIndexToInt64(b, loc, indexOh); + Value indexOw = + b.create(loc, /*value=*/dimW); + Value ow = castIndexToInt64(b, loc, indexOw); + + // int64_t ih0 = oh * dH - padH; + Value dH = rewriter.create( + loc, rewriter.getI64IntegerAttr(strideInts[0])); + Value padH = rewriter.create( + loc, rewriter.getI64IntegerAttr(paddingInts[0])); + Value ohDH = b.create(loc, oh, dH); + Value ih0 = b.create(loc, ohDH, padH); + // int64_t iw0 = ow * dW - padW; + Value dW = rewriter.create( + loc, rewriter.getI64IntegerAttr(strideInts[1])); + Value padW = rewriter.create( + loc, rewriter.getI64IntegerAttr(paddingInts[1])); + Value owDW = b.create(loc, ow, dW); + Value iw0 = b.create(loc, owDW, padW); + // int64_t ih1 = std::min(ih0 + kH, input_height + padH); + Value ih = castIndexToInt64(b, loc, inputHeight); + Value ih0KH = b.create( + loc, ih0, kernelSizeIntValues[0]); + Value ihPadH = b.create(loc, ih, padH); + Value ih1 = b.create(loc, ih0KH, ihPadH); + // int64_t iw1 = std::min(iw0 + kW, input_width + padW); + Value iw = castIndexToInt64(b, loc, inputWidth); + Value iw0KW = b.create( + loc, iw0, kernelSizeIntValues[1]); + Value iwPadW = b.create(loc, iw, padW); + Value iw1 = b.create(loc, iw0KW, iwPadW); + // int64_t pool_size = (ih1 - ih0) * (iw1 - iw0); + Value ih1Ih0 = b.create(loc, ih1, ih0); + Value iw1Iw0 = b.create(loc, iw1, iw0); + Value poolSize = + b.create(loc, ih1Ih0, iw1Iw0); + // ih0 = std::max(ih0, 0); + Value cstZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value ih0Clamped = + b.create(loc, ih0, cstZero); + // iw0 = std::max(iw0, 0); + Value iw0Clamped = + b.create(loc, iw0, cstZero); + // ih1 = std::min(ih1, input_height); + Value ih1Clamped = b.create(loc, ih1, ih); + // iw1 = std::min(iw1, input_width); + Value iw1Clamped = b.create(loc, iw1, iw); + // if (divisor_override.has_value()) { + // divisor = divisor_override.value(); + // } else { + // if(count_include_pad) { + // divisor = pool_size; + // } else { + // divisor = (ih1 - ih0) * (iw1 - iw0); + // } + // } + if (countIncludePad) { + divisor = convertScalarToDtype(b, loc, poolSize, + resultElementType); + } else { + Value ih1_ih0 = + b.create(loc, ih1Clamped, ih0Clamped); + Value iw1_iw0 = + b.create(loc, iw1Clamped, iw0Clamped); + divisor = b.create(loc, ih1_ih0, iw1_iw0); + } + // AtenAvgPool2/3dOp has an optional divisor_override + // attribute while AtenAvgPool1dOp does not. + if constexpr (std::is_same()) { + if (!isa( + op.getDivisorOverride().getType())) + divisor = adaptor.getDivisorOverride(); + } + + divisor = convertScalarToDtype(b, loc, divisor, + resultElementType); + Value avg; + if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + else if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + b.create(loc, avg); + }) + .getResult(0); + rewriter.replaceOpWithNewOp(op, resultType, avgPool); + return success(); + } + + // TODO: Add support for count_include_pad equal to `False` in + // AtenAvgPool1/3dOp. + if (!countIncludePad && + !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) { + return rewriter.notifyMatchFailure( + op, "unimplemented: count_include_pad is expected to be true for " + "AtenAvgPool3dOp"); + } + + // Case2: AtenAvgPool1/3dOp without count_include_pad equal to `False`. + divisor = kernelSizeIntValues[0]; for (uint32_t i = 1; i < kernelSizeIntValues.size(); i++) { divisor = rewriter.create(loc, divisor, kernelSizeIntValues[i]); @@ -648,29 +774,20 @@ public: : adaptor.getDivisorOverride(); } divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType); - - Value outputTensor = rewriter.create( - loc, getAsOpFoldResult(outTensorShape), resultElementType); - SmallVector indexingMapsAvg( - 2, rewriter.getMultiDimIdentityMap(Dim + 2)); - SmallVector iteratorTypesAvg( - Dim + 2, utils::IteratorType::parallel); - Value avgPool = - rewriter - .create( - loc, outputTensor.getType(), sumPool, outputTensor, - /*indexingMaps=*/indexingMapsAvg, - /*iteratorTypes=*/iteratorTypesAvg, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value avg; - if (isa(resultElementType)) - avg = b.create(loc, args[0], divisor); - else if (isa(resultElementType)) - avg = b.create(loc, args[0], divisor); - b.create(loc, avg); - }) - .getResult(0); - + avgPool = rewriter + .create( + loc, outputTensor.getType(), sumPool, outputTensor, + /*indexingMaps=*/indexingMapsAvg, + /*iteratorTypes=*/iteratorTypesAvg, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value avg; + if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + else if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + b.create(loc, avg); + }) + .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, avgPool); return success(); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 33dd2c082..40781bef3 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -888,6 +888,7 @@ STABLEHLO_PASS_SET = { "Aten_CastLongModule_basic", "AvgPool1dStaticModule_basic", "AvgPool2dStaticModule_basic", + "AvgPool2dCountIncludePadFalseStaticModule_basic", "AvgPool3dStaticModule_basic", "BaddbmmBroadcast1DInputModule_basic", "BaddbmmBroadcast2DInputModule_basic", @@ -1479,6 +1480,7 @@ STABLEHLO_CRASHING_SET = set() # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "AvgPool2dCountIncludePadFalseStaticModule_basic", "TensorSplitSections_GetItemModule_basic", "TensorSplitSections_ListUnpackModule_basic", "AtenLinear2D_basic", @@ -1950,6 +1952,7 @@ MAKE_FX_TOSA_PASS_SET = ( TOSA_PASS_SET | { ### Tests additionally passing in make_fx_tosa + "AvgPool2dCountIncludePadFalseStaticModule_basic", "AtenLinear1D_basic", "AtenLinearMatVec_basic", "AtenLinearVecMatBias_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index bbcfd15d9..1de40096c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -1017,6 +1017,35 @@ def AvgPool2dStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2, 10, 20, low=-1)) +class AvgPool2dCountIncludePadFalseStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[1, 1], + ceil_mode=False, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([32, 384, 25, 25], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCountIncludePadFalseStaticModule()) +def AvgPool2dCountIncludePadFalseStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(32, 384, 25, 25, low=-1)) + + class AvgPool2dDivisorOverrideModule(torch.nn.Module): def __init__(self): super().__init__()