[torch] Support `!countIncludePad` when unpadded for average pool (#2836)

We do not support average pool when `countIncludePad is set to false.
However if the input is unpadded then the setting of the boolean is
unneeded. Extended use by checking if padding is zero before rejecting
the lowering.
pull/2847/head
Rob Suderman 2024-01-31 15:09:36 -08:00 committed by GitHub
parent 0114a570e3
commit 34f6948533
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 2 deletions

View File

@ -557,7 +557,10 @@ public:
m_TorchConstantBool(&countIncludePad))) m_TorchConstantBool(&countIncludePad)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "count_include_pad must be a constant"); op, "count_include_pad must be a constant");
if (!countIncludePad) {
// If the padding is zero then there is no padding to include.
if (!countIncludePad &&
!llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: count_include_pad is expected to be true"); op, "unimplemented: count_include_pad is expected to be true");
} }

View File

@ -847,6 +847,28 @@ class AvgPool2dCeilModeTrueModule(torch.nn.Module):
def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils): def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0)) module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))
class AvgPool2dWithoutPadModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8],
stride=[2, 2],
padding=[0, 0],
ceil_mode=False,
count_include_pad=False,
divisor_override=None)
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return self.ap2d(x)
@register_test_case(module_factory=lambda: AvgPool2dWithoutPadModule())
def AvgPool2dWithoutPadModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))
# ============================================================================== # ==============================================================================
@ -1141,4 +1163,4 @@ class AdaptiveMaxPool2dStaticWithIndices(torch.nn.Module):
module_factory=lambda: AdaptiveMaxPool2dStaticWithIndices()) module_factory=lambda: AdaptiveMaxPool2dStaticWithIndices())
def AdaptiveMaxPool2dStaticWithIndices_basic( def AdaptiveMaxPool2dStaticWithIndices_basic(
module, tu: TestUtils): module, tu: TestUtils):
module.forward(tu.rand(1, 512, 10, 16)) module.forward(tu.rand(1, 512, 10, 16))