From 22cd4441e7f87683e4be869cb3b333b18e645830 Mon Sep 17 00:00:00 2001
From: yyp0
Date: Thu, 1 Aug 2024 11:37:53 +0800
Subject: [PATCH] [Torch] Add support for static uneven divisible AdaptiveAvgPool2d (#3566)

The static uneven divisible AdaptiveAvgPool2d case means that even though
the input size is not an integer multiple of the output size, the kernel
and stride sizes can still be fixed (not dynamic). The kernel and stride
sizes are derived in the same way as
torch/_decomp/decompositions.py:adaptive_avg_pool2d, as follows:

1. Stride Size

First, derive the start index of each reduce operation from the output
size (`n`): `start_index = ([0, 1, ..., n - 1] * input_size) // output_size`.
For each index `k`, if `k * (input_size % output_size) < output_size`, the
stride between the current and the previous start index stays equal to
`input_size // output_size`. So if `(n - 1) * (input_size % output_size) <
output_size`, the stride is static over the whole AdaptiveAvgPool2d
process, namely `input_size // output_size`.

2. Kernel Size

torch/_decomp/decompositions.py:adaptive_avg_pool2d computes a static
kernel size when the input/output sizes satisfy either of two conditions:
`input_size % output_size == 0` or
`output_size % (input_size % output_size) == 0`. If
`input_size % output_size == 0`, the kernel size equals
`input_size // output_size`; otherwise it is `input_size // output_size + 1`.
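To make the derivation concrete, here is a minimal Python sketch of the
same kernel/stride computation (illustration only, not part of the patch;
the helper name `fixed_avg_pool_params` is made up):

    def fixed_avg_pool_params(input_size: int, output_size: int):
        """Return (kernel, stride) when both are static, else None."""
        remainder = input_size % output_size
        # Stride is static only if (n - 1) * (input_size % n) < n.
        if remainder * (output_size - 1) >= output_size:
            return None
        # Kernel is static only if the remainder is 0 or divides output_size.
        if remainder != 0 and output_size % remainder != 0:
            return None
        stride = input_size // output_size
        kernel = stride + (1 if remainder != 0 else 0)
        return kernel, stride

    # Matches the new e2e test (7 -> 2): kernel 4, stride 3, i.e. the
    # adaptive windows [0:4] and [3:7].
    assert fixed_avg_pool_params(7, 2) == (4, 3)
    # 10 -> 4 has varying strides (start indices 0, 2, 5, 7), so it is rejected.
    assert fixed_avg_pool_params(10, 4) is None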
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  1 +
 lib/Dialect/Torch/IR/TorchOps.cpp             | 14 ++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 72 +++++++++++++++----
 projects/pt1/e2e_testing/xfail_sets.py        |  4 ++
 .../build_tools/torch_ods_gen.py              |  5 +-
 .../torch_mlir_e2e_test/test_suite/pooling.py | 23 ++++++
 test/Dialect/Torch/decompose-complex-ops.mlir | 31 --------
 7 files changed, 106 insertions(+), 44 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index aa2566711..53ac25077 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7729,6 +7729,7 @@ def Torch_Aten_AdaptiveAvgPool2dOp : Torch_Op<"aten._adaptive_avg_pool2d", [
     printDefaultTorchOp(printer, *this, 2, 1);
   }
 }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_Aten_AdaptiveAvgPool2dBackwardOp : Torch_Op<"aten._adaptive_avg_pool2d_backward", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index ca46ca62f..fb028e046 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4857,6 +4857,20 @@ void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns(
   });
 }
 
+//===----------------------------------------------------------------------===//
+// Aten_AdaptiveAvgPool2dOp
+//===----------------------------------------------------------------------===//
+
+void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add(+[](Aten_AdaptiveAvgPool2dOp op, PatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<AtenAdaptiveAvgPool2dOp>(
+        op, op.getType(), op.getSelf(), op.getOutputSize());
+
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenLinalgCrossOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index f073d1405..abb84dff4 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -7038,32 +7038,80 @@ class DecomposeAtenAdaptiveAvgPool2dOp
     getListConstructElements(outputShape, outputShapeSizesTorchInt);
 
     // TODO: Add support for cases other than:
-    // inH % outH != 0 or inW % outW != 0
-
+    // inH % outH != 0 or inW % outW != 0 where
+    // the stride/kernel size is not fixed.
+    // The following derivation of the stride/kernel size is consistent
+    // with torch/_decomp/decompositions.py:adaptive_avg_pool2d.
     Value constantZero = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(0));
+    Value constantOne = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
     Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
     Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
     Value constantNone = rewriter.create<Torch::ConstantNoneOp>(loc);
-    SmallVector<Value, 2> kernelSize;
+    SmallVector<Value, 2> strideSize;
+    SmallVector<Value, 2> kernelSize;
     for (unsigned i = 0; i < inputHW.size(); i++) {
       Value remainder = rewriter.create<AtenRemainderIntOp>(
           loc, inputHW[i], outputShapeSizesTorchInt[i]);
-      Value cond = rewriter.create<AtenEqIntOp>(loc, remainder, constantZero);
-      rewriter.create<RuntimeAssertOp>(loc, cond,
-                                       "unimplemented: only support cases "
-                                       "input size is an integer multiple of "
-                                       "output size");
-      Value stride = rewriter.create<AtenFloordivIntOp>(
+
+      // Filter cases with fixed stride size.
+      Value cond1 = rewriter.create<AtenGtIntOp>(
+          loc, outputShapeSizesTorchInt[i],
+          rewriter.create<AtenMulIntOp>(
+              loc, remainder,
+              rewriter.create<AtenSubIntOp>(
+                  loc, outputShapeSizesTorchInt[i], constantOne)));
+      rewriter.create<RuntimeAssertOp>(
+          loc, cond1,
+          "unimplemented: only support cases with fixed stride size.");
+
+      // Filter cases with fixed kernel size.
+      // cond2: whether input_size % output_size == 0.
+      Value cond2 =
+          rewriter.create<AtenEqIntOp>(loc, remainder, constantZero);
+      // cond3: whether output_size % (input_size % output_size) == 0.
+      // To avoid a potential crash (e.g. on the tosa backend), mod by 1
+      // (add an offset) when the remainder equals 0; this has no effect
+      // on the result, since cond2 already covers that case.
+      Value offset = rewriter.create<AtenIntBoolOp>(
+          loc, rewriter.create<Aten__Not__Op>(
+                   loc, rewriter.create<AtenBoolIntOp>(loc, remainder)));
+      Value remainder_not_zero =
+          rewriter.create<AtenAddIntOp>(loc, remainder, offset);
+      Value cond3 = rewriter.create<AtenEqIntOp>(
+          loc,
+          rewriter.create<AtenRemainderIntOp>(
+              loc, outputShapeSizesTorchInt[i], remainder_not_zero),
+          constantZero);
+      Value cond = rewriter.create<Aten__Or__BoolOp>(loc, cond2, cond3);
+
+      rewriter.create<RuntimeAssertOp>(
+          loc, cond,
+          "unimplemented: only support cases with fixed kernel size.");
+
+      Value stride = rewriter.create<AtenFloordivIntOp>(
           loc, inputHW[i], outputShapeSizesTorchInt[i]);
-      Value kernelSizeValue = stride;
-      kernelSize.push_back(kernelSizeValue);
+      strideSize.emplace_back(stride);
+
+      Value kernel = rewriter.create<AtenFloordivIntOp>(
+          loc, inputHW[i], outputShapeSizesTorchInt[i]);
+
+      // When the remainder equals 0, the kernel does not need an extra 1
+      // and stays equal to the stride; otherwise 1 must be added
+      // (torch/_decomp/decompositions.py:adaptive_avg_pool2d).
+      Value boolMod = rewriter.create<AtenBoolIntOp>(loc, remainder);
+      Value intMod = rewriter.create<AtenIntBoolOp>(loc, boolMod);
+
+      kernel = rewriter.create<AtenAddIntOp>(loc, kernel, intMod);
+      kernelSize.emplace_back(kernel);
     }
     Value kernelSizeList = rewriter.create<PrimListConstructOp>(
         loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
-    Value strideList = kernelSizeList;
+    Value strideList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)), strideSize);
     Value paddingSizeList = rewriter.create<PrimListConstructOp>(
         loc, Torch::ListType::get(Torch::IntType::get(context)),
         ValueRange{constantZero, constantZero});
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index c54a9023b..a24840b29 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -853,6 +853,7 @@ STABLEHLO_PASS_SET = {
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
+    "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
     "AddIntModule_basic",
     "AliasModule_basic",
     "TrueFalseOrBoolOpModule_basic",
@@ -1537,6 +1538,7 @@ TOSA_PASS_SET = {
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
+    "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
     "AddCDivModule_basic",
     "AddCDiv_Module_basic",
     "AddCMulModule_basic",
@@ -2062,6 +2064,7 @@ MAKE_FX_TOSA_PASS_SET = (
     "ViewNoChange1dModule_basic",
     "ViewNoChange2dModule_basic",
     "ViewNoChange3dModule_basic",
+    "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
 }
 
 LTC_CRASHING_SET = {
@@ -2265,6 +2268,7 @@ ONNX_XFAIL_SET = {
     "AdaptiveAvgPool2dDynamic_basic",
     "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
     "AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
+    "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
     "AdaptiveAvgPool3dDynamicNoBatch_basic",
     "AdaptiveAvgPool3dDynamic_basic",
     "AdaptiveMaxPool1dDynamicNoBatch_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 30758f457..7007de718 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -662,7 +662,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     )
     emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)")
     emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
-    emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
+    emit(
+        "aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)",
+        has_canonicalizer=True,
+    )
     emit("aten::_adaptive_avg_pool2d_backward : (Tensor, Tensor) -> (Tensor)")
     emit("aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)")
     emit("aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index ae26a7cef..6d36c6909 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -108,6 +108,29 @@ def AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic(
     module.forward(tu.rand(1, 512, 15, 14))
 
 
+class AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.aap2d = torch.nn.AdaptiveAvgPool2d((2, 2))
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 3, 7, 7], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.aap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule()
+)
+def AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 3, 7, 7))
+
+
 class AdaptiveAvgPool2dUnitOutputSizeStaticModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 9b95ddc07..3ed9fcbfa 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -26,37 +26,6 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
 }
 
 // -----
-// CHECK-LABEL: func @torch.aten.adaptive_avg_pool2d$output_size_divisible_by_input(
-// CHECK-SAME:    %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
-// CHECK-DAG:     %[[CST0:.*]] = torch.constant.int 0
-// CHECK-DAG:     %[[CST2:.*]] = torch.constant.int 2
-// CHECK-DAG:     %[[CST3:.*]] = torch.constant.int 3
-// CHECK-DAG:     %[[CST7:.*]] = torch.constant.int 7
-// CHECK-DAG:     %[[FALSE:.*]] = torch.constant.bool false
-// CHECK-DAG:     %[[TRUE:.*]] = torch.constant.bool true
-// CHECK-DAG:     %[[NONE:.*]] = torch.constant.none
-// CHECK:         %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
-// CHECK:         %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
-// CHECK:         %[[REMAINER1:.*]] = torch.aten.remainder.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
-// CHECK:         %[[COND1:.*]] = torch.aten.eq.int %[[REMAINER1]], %[[CST0]] : !torch.int, !torch.int -> !torch.bool
-// CHECK:         torch.runtime.assert %[[COND1]], "unimplemented: only support cases input size is an integer multiple of output size"
-// CHECK:         %[[STRIDE1:.*]] = torch.aten.floordiv.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
-// CHECK:         %[[REMAINER2:.*]] = torch.aten.remainder.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
-// CHECK:         %[[COND2:.*]] = torch.aten.eq.int %[[REMAINER2]], %[[CST0]] : !torch.int, !torch.int -> !torch.bool
-// CHECK:         torch.runtime.assert %[[COND2]], "unimplemented: only support cases input size is an integer multiple of output size"
-// CHECK:         %[[STRIDE2:.*]] = torch.aten.floordiv.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.int
-// CHECK:         %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[STRIDE1]], %[[STRIDE2]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK:         %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK:         %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[KERNEL_SIZE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
-func.func @torch.aten.adaptive_avg_pool2d$output_size_divisible_by_input(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
-  %int7 = torch.constant.int 7
-  %output_size = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
-  %0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
-  return %0 : !torch.vtensor<[?,?,?,?],f32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.type_as$basic(
 // CHECK-SAME:      %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
 // CHECK-DAG:     %[[FALSE:.*]] = torch.constant.bool false