mirror of https://github.com/llvm/torch-mlir
[torch] Support `!countIncludePad` when unpadded for average pool (#2836)
We do not support average pool when `countIncludePad is set to false. However if the input is unpadded then the setting of the boolean is unneeded. Extended use by checking if padding is zero before rejecting the lowering.pull/2847/head
parent
0114a570e3
commit
34f6948533
|
@ -557,7 +557,10 @@ public:
|
|||
m_TorchConstantBool(&countIncludePad)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "count_include_pad must be a constant");
|
||||
if (!countIncludePad) {
|
||||
|
||||
// If the padding is zero then there is no padding to include.
|
||||
if (!countIncludePad &&
|
||||
!llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: count_include_pad is expected to be true");
|
||||
}
|
||||
|
|
|
@ -847,6 +847,28 @@ class AvgPool2dCeilModeTrueModule(torch.nn.Module):
|
|||
def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))
|
||||
|
||||
class AvgPool2dWithoutPadModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8],
|
||||
stride=[2, 2],
|
||||
padding=[0, 0],
|
||||
ceil_mode=False,
|
||||
count_include_pad=False,
|
||||
divisor_override=None)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.ap2d(x)
|
||||
|
||||
@register_test_case(module_factory=lambda: AvgPool2dWithoutPadModule())
|
||||
def AvgPool2dWithoutPadModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -1141,4 +1163,4 @@ class AdaptiveMaxPool2dStaticWithIndices(torch.nn.Module):
|
|||
module_factory=lambda: AdaptiveMaxPool2dStaticWithIndices())
|
||||
def AdaptiveMaxPool2dStaticWithIndices_basic(
|
||||
module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 512, 10, 16))
|
||||
module.forward(tu.rand(1, 512, 10, 16))
|
||||
|
|
Loading…
Reference in New Issue