mirror of https://github.com/llvm/torch-mlir
[Torch] Add support for static uneven divisible AdaptiveAvgPool2d (#3566)
The static uneven divisible AdaptiveAvgPool2d means that although the input size is not an integer multiple of ouput size, but the kernel and stride size can also be fixed (not dynamic). The derivation logic of kernel and stride size is consistent with torch/_decomp/decomposations.py:adaptive_avg_pool2d as described in the following: 1. Stride Size Firstly , derive the start index in each reduce operation according to the output size (`n`), `start_index = ([0, 1, ..., n - 1] * input_size) // output_size`. For each index `k`, if `k * (input_size % output_size) < output_size`, then the current and previous stride keeps the same as `input_size // output_size`. So suppose `(n-1) * (input_size % output_size) < output_size`, the stride in the whole AdaptiveAvgPool2d process keeps static, as `input_size // output_size`. 2. Kernel Size torch/_decomp/decomposations.py:adaptive_avg_pool2d calculates a static kernel size when the input/output sizes satisfy either of the two conditions, `input_size % output_size == 0` or `output_size % (input_size % output_size) == 0`. Here if `input_size % output_size == 0`, then the kernel size equals `input_size // output_size`, otherwise `input_size // output_size + 1.`pull/3516/head
parent
6f7a5db801
commit
22cd4441e7
|
@ -7729,6 +7729,7 @@ def Torch_Aten_AdaptiveAvgPool2dOp : Torch_Op<"aten._adaptive_avg_pool2d", [
|
|||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_Aten_AdaptiveAvgPool2dBackwardOp : Torch_Op<"aten._adaptive_avg_pool2d_backward", [
|
||||
|
|
|
@ -4857,6 +4857,20 @@ void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Aten_AdaptiveAvgPool2dOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.add(+[](Aten_AdaptiveAvgPool2dOp op, PatternRewriter &rewriter) {
|
||||
rewriter.replaceOpWithNewOp<AtenAdaptiveAvgPool2dOp>(
|
||||
op, op.getType(), op.getSelf(), op.getOutputSize());
|
||||
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenLinalgCrossOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -7038,32 +7038,80 @@ class DecomposeAtenAdaptiveAvgPool2dOp
|
|||
getListConstructElements(outputShape, outputShapeSizesTorchInt);
|
||||
|
||||
// TODO: Add support for cases other than:
|
||||
// inH % outH != 0 or inW % outW != 0
|
||||
|
||||
// inH % outH != 0 or inW % outW != 0 where
|
||||
// the stride/kernel size is not fixed.
|
||||
// The following logic of stride/kernel size derivation is consistent
|
||||
// with torch/_decomp/decomposations.py:adaptive_avg_pool2d.
|
||||
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||
Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
|
||||
Value constantNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
SmallVector<Value, 2> kernelSize;
|
||||
|
||||
SmallVector<Value, 2> strideSize;
|
||||
SmallVector<Value, 2> kernelSize;
|
||||
for (unsigned i = 0; i < inputHW.size(); i++) {
|
||||
Value remainder = rewriter.create<AtenRemainderIntOp>(
|
||||
loc, inputHW[i], outputShapeSizesTorchInt[i]);
|
||||
Value cond = rewriter.create<AtenEqIntOp>(loc, remainder, constantZero);
|
||||
rewriter.create<RuntimeAssertOp>(loc, cond,
|
||||
"unimplemented: only support cases "
|
||||
"input size is an integer multiple of "
|
||||
"output size");
|
||||
Value stride = rewriter.create<AtenFloordivIntOp>(
|
||||
|
||||
// Filter cases with fixed stride size.
|
||||
Value cond1 = rewriter.create<Torch::AtenGtIntOp>(
|
||||
loc, outputShapeSizesTorchInt[i],
|
||||
rewriter.create<Torch::AtenMulIntOp>(
|
||||
loc, remainder,
|
||||
rewriter.create<Torch::AtenSubIntOp>(
|
||||
loc, outputShapeSizesTorchInt[i], constantOne)));
|
||||
rewriter.create<RuntimeAssertOp>(
|
||||
loc, cond1,
|
||||
"unimplemented: only support cases with fixed stride size.");
|
||||
|
||||
// Filter cases with fixed kernel size.
|
||||
// cond2: whether input_size % output_size == 0.
|
||||
Value cond2 =
|
||||
rewriter.create<Torch::AtenEqIntOp>(loc, remainder, constantZero);
|
||||
// cond3: whether output_size % (input_size % output_size) == 0.
|
||||
// To avoid potential crash (eg. tosa) happens,choose to mod 1 (add
|
||||
// offset) when remainder equals 0, which has no side effect on
|
||||
// effectiveness.
|
||||
Value offset = rewriter.create<Torch::AtenIntBoolOp>(
|
||||
loc, rewriter.create<Torch::Aten__Not__Op>(
|
||||
loc, rewriter.create<Torch::AtenBoolIntOp>(loc, remainder)));
|
||||
Value remainder_not_zero =
|
||||
rewriter.create<Torch::AtenAddIntOp>(loc, remainder, offset);
|
||||
Value cond3 = rewriter.create<Torch::AtenEqIntOp>(
|
||||
loc,
|
||||
rewriter.create<Torch::AtenRemainderIntOp>(
|
||||
loc, outputShapeSizesTorchInt[i], remainder_not_zero),
|
||||
constantZero);
|
||||
Value cond = rewriter.create<Torch::Aten__Or__BoolOp>(loc, cond2, cond3);
|
||||
|
||||
rewriter.create<RuntimeAssertOp>(
|
||||
loc, cond,
|
||||
"unimplemented: only support cases with fixed kernel size.");
|
||||
|
||||
Value stride = rewriter.create<Torch::AtenFloordivIntOp>(
|
||||
loc, inputHW[i], outputShapeSizesTorchInt[i]);
|
||||
Value kernelSizeValue = stride;
|
||||
kernelSize.push_back(kernelSizeValue);
|
||||
strideSize.emplace_back(stride);
|
||||
|
||||
Value kernel = rewriter.create<Torch::AtenFloordivIntOp>(
|
||||
loc, inputHW[i], outputShapeSizesTorchInt[i]);
|
||||
|
||||
// When remainder equals 0, it is no need for kernel to add 1
|
||||
// and just keep the same as stride, otherwise it is necessary
|
||||
// to add 1 (torch/_decomp/decomposations.py:adaptive_avg_pool2d).
|
||||
Value boolMod = rewriter.create<Torch::AtenBoolIntOp>(loc, remainder);
|
||||
Value intMod = rewriter.create<Torch::AtenIntBoolOp>(loc, boolMod);
|
||||
|
||||
kernel = rewriter.create<Torch::AtenAddIntOp>(loc, kernel, intMod);
|
||||
kernelSize.emplace_back(kernel);
|
||||
}
|
||||
|
||||
Value kernelSizeList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
|
||||
Value strideList = kernelSizeList;
|
||||
Value strideList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(context)), strideSize);
|
||||
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
||||
ValueRange{constantZero, constantZero});
|
||||
|
|
|
@ -853,6 +853,7 @@ STABLEHLO_PASS_SET = {
|
|||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
|
||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||
"AddIntModule_basic",
|
||||
"AliasModule_basic",
|
||||
"TrueFalseOrBoolOpModule_basic",
|
||||
|
@ -1537,6 +1538,7 @@ TOSA_PASS_SET = {
|
|||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
|
||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||
"AddCDivModule_basic",
|
||||
"AddCDiv_Module_basic",
|
||||
"AddCMulModule_basic",
|
||||
|
@ -2062,6 +2064,7 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
"ViewNoChange1dModule_basic",
|
||||
"ViewNoChange2dModule_basic",
|
||||
"ViewNoChange3dModule_basic",
|
||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||
}
|
||||
|
||||
LTC_CRASHING_SET = {
|
||||
|
@ -2265,6 +2268,7 @@ ONNX_XFAIL_SET = {
|
|||
"AdaptiveAvgPool2dDynamic_basic",
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
|
||||
"AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
|
||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool3dDynamicNoBatch_basic",
|
||||
"AdaptiveAvgPool3dDynamic_basic",
|
||||
"AdaptiveMaxPool1dDynamicNoBatch_basic",
|
||||
|
|
|
@ -662,7 +662,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
)
|
||||
emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
||||
emit(
|
||||
"aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)",
|
||||
has_canonicalizer=True,
|
||||
)
|
||||
emit("aten::_adaptive_avg_pool2d_backward : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)")
|
||||
|
|
|
@ -108,6 +108,29 @@ def AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic(
|
|||
module.forward(tu.rand(1, 512, 15, 14))
|
||||
|
||||
|
||||
class AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.aap2d = torch.nn.AdaptiveAvgPool2d((2, 2))
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 3, 7, 7], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.aap2d(x)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule()
|
||||
)
|
||||
def AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 3, 7, 7))
|
||||
|
||||
|
||||
class AdaptiveAvgPool2dUnitOutputSizeStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -26,37 +26,6 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
|
|||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.adaptive_avg_pool2d$output_size_divisible_by_input(
|
||||
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
// CHECK-DAG: %[[CST0:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[CST2:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[CST3:.*]] = torch.constant.int 3
|
||||
// CHECK-DAG: %[[CST7:.*]] = torch.constant.int 7
|
||||
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[REMAINER1:.*]] = torch.aten.remainder.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[COND1:.*]] = torch.aten.eq.int %[[REMAINER1]], %[[CST0]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[COND1]], "unimplemented: only support cases input size is an integer multiple of output size"
|
||||
// CHECK: %[[STRIDE1:.*]] = torch.aten.floordiv.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[REMAINER2:.*]] = torch.aten.remainder.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[COND2:.*]] = torch.aten.eq.int %[[REMAINER2]], %[[CST0]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[COND2]], "unimplemented: only support cases input size is an integer multiple of output size"
|
||||
// CHECK: %[[STRIDE2:.*]] = torch.aten.floordiv.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[STRIDE1]], %[[STRIDE2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[KERNEL_SIZE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
|
||||
func.func @torch.aten.adaptive_avg_pool2d$output_size_divisible_by_input(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
%int7 = torch.constant.int 7
|
||||
%output_size = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?,?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.type_as$basic(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
|
||||
|
|
Loading…
Reference in New Issue