[Torch] Add support for static uneven divisible AdaptiveAvgPool2d (#3566)

The static uneven divisible AdaptiveAvgPool2d means that although the
input size is not an integer multiple of ouput size, but the kernel and
stride size can also be fixed (not dynamic). The derivation logic of
kernel and stride size is consistent with
torch/_decomp/decomposations.py:adaptive_avg_pool2d as described in the
following:

1. Stride Size
Firstly , derive the start index in each reduce operation according to
the output size (`n`), `start_index = ([0, 1, ..., n - 1] * input_size)
// output_size`. For each index `k`, if `k * (input_size % output_size)
< output_size`, then the current and previous stride keeps the same as
`input_size // output_size`. So suppose `(n-1) * (input_size %
output_size) < output_size`, the stride in the whole AdaptiveAvgPool2d
process keeps static, as `input_size // output_size`.

2. Kernel Size
torch/_decomp/decomposations.py:adaptive_avg_pool2d calculates a static
kernel size when the input/output sizes satisfy either of the two
conditions, `input_size % output_size == 0` or `output_size %
(input_size % output_size) == 0`. Here if `input_size % output_size ==
0`, then the kernel size equals `input_size // output_size`, otherwise
`input_size // output_size + 1.`
pull/3516/head
yyp0 2024-08-01 11:37:53 +08:00 committed by GitHub
parent 6f7a5db801
commit 22cd4441e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 106 additions and 44 deletions

View File

@ -7729,6 +7729,7 @@ def Torch_Aten_AdaptiveAvgPool2dOp : Torch_Op<"aten._adaptive_avg_pool2d", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_Aten_AdaptiveAvgPool2dBackwardOp : Torch_Op<"aten._adaptive_avg_pool2d_backward", [

View File

@ -4857,6 +4857,20 @@ void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
});
}
//===----------------------------------------------------------------------===//
// Aten_AdaptiveAvgPool2dOp
//===----------------------------------------------------------------------===//
void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](Aten_AdaptiveAvgPool2dOp op, PatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<AtenAdaptiveAvgPool2dOp>(
op, op.getType(), op.getSelf(), op.getOutputSize());
return success();
});
}
//===----------------------------------------------------------------------===//
// AtenLinalgCrossOp
//===----------------------------------------------------------------------===//

View File

@ -7038,32 +7038,80 @@ class DecomposeAtenAdaptiveAvgPool2dOp
getListConstructElements(outputShape, outputShapeSizesTorchInt);
// TODO: Add support for cases other than:
// inH % outH != 0 or inW % outW != 0
// inH % outH != 0 or inW % outW != 0 where
// the stride/kernel size is not fixed.
// The following logic of stride/kernel size derivation is consistent
// with torch/_decomp/decomposations.py:adaptive_avg_pool2d.
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
Value constantNone = rewriter.create<Torch::ConstantNoneOp>(loc);
SmallVector<Value, 2> kernelSize;
SmallVector<Value, 2> strideSize;
SmallVector<Value, 2> kernelSize;
for (unsigned i = 0; i < inputHW.size(); i++) {
Value remainder = rewriter.create<AtenRemainderIntOp>(
loc, inputHW[i], outputShapeSizesTorchInt[i]);
Value cond = rewriter.create<AtenEqIntOp>(loc, remainder, constantZero);
rewriter.create<RuntimeAssertOp>(loc, cond,
"unimplemented: only support cases "
"input size is an integer multiple of "
"output size");
Value stride = rewriter.create<AtenFloordivIntOp>(
// Filter cases with fixed stride size.
Value cond1 = rewriter.create<Torch::AtenGtIntOp>(
loc, outputShapeSizesTorchInt[i],
rewriter.create<Torch::AtenMulIntOp>(
loc, remainder,
rewriter.create<Torch::AtenSubIntOp>(
loc, outputShapeSizesTorchInt[i], constantOne)));
rewriter.create<RuntimeAssertOp>(
loc, cond1,
"unimplemented: only support cases with fixed stride size.");
// Filter cases with fixed kernel size.
// cond2: whether input_size % output_size == 0.
Value cond2 =
rewriter.create<Torch::AtenEqIntOp>(loc, remainder, constantZero);
// cond3: whether output_size % (input_size % output_size) == 0.
// To avoid potential crash (eg. tosa) happens,choose to mod 1 (add
// offset) when remainder equals 0, which has no side effect on
// effectiveness.
Value offset = rewriter.create<Torch::AtenIntBoolOp>(
loc, rewriter.create<Torch::Aten__Not__Op>(
loc, rewriter.create<Torch::AtenBoolIntOp>(loc, remainder)));
Value remainder_not_zero =
rewriter.create<Torch::AtenAddIntOp>(loc, remainder, offset);
Value cond3 = rewriter.create<Torch::AtenEqIntOp>(
loc,
rewriter.create<Torch::AtenRemainderIntOp>(
loc, outputShapeSizesTorchInt[i], remainder_not_zero),
constantZero);
Value cond = rewriter.create<Torch::Aten__Or__BoolOp>(loc, cond2, cond3);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"unimplemented: only support cases with fixed kernel size.");
Value stride = rewriter.create<Torch::AtenFloordivIntOp>(
loc, inputHW[i], outputShapeSizesTorchInt[i]);
Value kernelSizeValue = stride;
kernelSize.push_back(kernelSizeValue);
strideSize.emplace_back(stride);
Value kernel = rewriter.create<Torch::AtenFloordivIntOp>(
loc, inputHW[i], outputShapeSizesTorchInt[i]);
// When remainder equals 0, it is no need for kernel to add 1
// and just keep the same as stride, otherwise it is necessary
// to add 1 (torch/_decomp/decomposations.py:adaptive_avg_pool2d).
Value boolMod = rewriter.create<Torch::AtenBoolIntOp>(loc, remainder);
Value intMod = rewriter.create<Torch::AtenIntBoolOp>(loc, boolMod);
kernel = rewriter.create<Torch::AtenAddIntOp>(loc, kernel, intMod);
kernelSize.emplace_back(kernel);
}
Value kernelSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
Value strideList = kernelSizeList;
Value strideList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)), strideSize);
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantZero, constantZero});

View File

@ -853,6 +853,7 @@ STABLEHLO_PASS_SET = {
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AddIntModule_basic",
"AliasModule_basic",
"TrueFalseOrBoolOpModule_basic",
@ -1537,6 +1538,7 @@ TOSA_PASS_SET = {
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AddCDivModule_basic",
"AddCDiv_Module_basic",
"AddCMulModule_basic",
@ -2062,6 +2064,7 @@ MAKE_FX_TOSA_PASS_SET = (
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
}
LTC_CRASHING_SET = {
@ -2265,6 +2268,7 @@ ONNX_XFAIL_SET = {
"AdaptiveAvgPool2dDynamic_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool3dDynamicNoBatch_basic",
"AdaptiveAvgPool3dDynamic_basic",
"AdaptiveMaxPool1dDynamicNoBatch_basic",

View File

@ -662,7 +662,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
)
emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)")
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
emit(
"aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)",
has_canonicalizer=True,
)
emit("aten::_adaptive_avg_pool2d_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)")
emit("aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)")

View File

@ -108,6 +108,29 @@ def AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic(
module.forward(tu.rand(1, 512, 15, 14))
class AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.aap2d = torch.nn.AdaptiveAvgPool2d((2, 2))
@export
@annotate_args(
[
None,
([1, 3, 7, 7], torch.float32, True),
]
)
def forward(self, x):
return self.aap2d(x)
@register_test_case(
module_factory=lambda: AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule()
)
def AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 3, 7, 7))
class AdaptiveAvgPool2dUnitOutputSizeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -26,37 +26,6 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
}
// -----
// CHECK-LABEL: func @torch.aten.adaptive_avg_pool2d$output_size_divisible_by_input(
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK-DAG: %[[CST0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[CST2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[CST3:.*]] = torch.constant.int 3
// CHECK-DAG: %[[CST7:.*]] = torch.constant.int 7
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[REMAINER1:.*]] = torch.aten.remainder.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[COND1:.*]] = torch.aten.eq.int %[[REMAINER1]], %[[CST0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[COND1]], "unimplemented: only support cases input size is an integer multiple of output size"
// CHECK: %[[STRIDE1:.*]] = torch.aten.floordiv.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[REMAINER2:.*]] = torch.aten.remainder.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[COND2:.*]] = torch.aten.eq.int %[[REMAINER2]], %[[CST0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[COND2]], "unimplemented: only support cases input size is an integer multiple of output size"
// CHECK: %[[STRIDE2:.*]] = torch.aten.floordiv.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[STRIDE1]], %[[STRIDE2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[KERNEL_SIZE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.adaptive_avg_pool2d$output_size_divisible_by_input(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int7 = torch.constant.int 7
%output_size = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
%0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.type_as$basic(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false