From 7f9f99c6f8c84323d896b47fcd67c4bc668f6577 Mon Sep 17 00:00:00 2001 From: Hanumanth Date: Fri, 1 Nov 2024 08:25:59 -0400 Subject: [PATCH] Fix torchToTosa lowering for avgpool2d to handle unsupported parameters (#3822) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The existing TorchToTosa lowering logic for `torch.aten.avg_pool2d` doesn't handle some unsupported properties well, leading to a silent wrong answer(SWA) when we go through `torch-backend-to-tosa-backend-pipeline.` For instance, with the existing TOSA avgpool2d specification, we can not represent `count_include_pad` and `divisor_override,` so during TorchToTosa lowering, we silently ignore these properties which leads to SWA in some cases—the fix captured in this change errors for unsupported scenarios. For details on `count_include_pad` and `divisor_override,` please see the below link. https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html --------- Co-authored-by: Hanumanth Hanumantharayappa --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 22 ++++++++ test/Conversion/TorchToTosa/basic.mlir | 66 ++++++++++++++++------ 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index ce8351ea9..48c38b077 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -5466,6 +5466,28 @@ public: DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, Type &outputTy) const override { + + // Currently, we can not represent `count_include_pad` with the existing + // TOSA AvgPool2d specification. Without the below check, we produce silent + // wrong answers (SWA) when the `count_include_pad` value is `true.` + bool countIncludePad; + if (!matchPattern(op.getCountIncludePad(), + m_TorchConstantBool(&countIncludePad)) || + countIncludePad) { + return rewriter.notifyMatchFailure( + op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp " + "`count_include_pad` value should be `False`."); + } + + // Currently, we can not represent `divisor_override` with the existing TOSA + // AvgPool2d specification. Without the below check, we produce silent wrong + // answers (SWA) when the `divisor_override` value is other than `None.` + if (!isa(op.getDivisorOverride().getType())) { + return rewriter.notifyMatchFailure( + op, "Unsupported `divisor_override` value, for tosa AvgPool2dOp " + "`divisor_override` value should be `None`."); + } + SmallVector dilationArray{1, 1}; if (failed(getOutputTypeAndPoolingParameters( diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 80dcc0ac7..2cf2486e7 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -852,37 +852,35 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch // ----- // CHECK-LABEL: func.func @torch.aten.avg_pool2d$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,7,7],f32> -> tensor<1x512x7x7xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 7 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 // CHECK: %[[VAL_5:.*]] = torch.constant.bool false -// CHECK: %[[VAL_6:.*]] = torch.constant.bool true -// CHECK: %[[VAL_7:.*]] = torch.constant.none -// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_11]] : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32> -// CHECK: %[[VAL_13:.*]] = tosa.avg_pool2d %[[VAL_12]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32> -// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32> -// CHECK: %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32> -// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32> -// CHECK: return %[[VAL_17]] : !torch.vtensor<[1,512,1,1],f32> +// CHECK: %[[VAL_6:.*]] = torch.constant.none +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_10]] : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32> +// CHECK: %[[VAL_12:.*]] = tosa.avg_pool2d %[[VAL_11]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_12]], %[[VAL_13]] : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32> +// CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32> +// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32> +// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,1,1],f32> // CHECK: } func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) -> !torch.vtensor<[1,512,1,1],f32> { %int7 = torch.constant.int 7 %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 %false = torch.constant.bool false - %true = torch.constant.bool true %none = torch.constant.none %kernel = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list - %0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %true, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32> + %0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %false, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32> return %0 : !torch.vtensor<[1,512,1,1],f32> } @@ -2001,6 +1999,42 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso // ----- +func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %false= torch.constant.bool false + %count_include_pad = torch.constant.bool true + %divisor_override = torch.constant.none + + %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}} + %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32> + return %3 : !torch.vtensor<[1,192,35,35],f32> +} + +// ----- + +func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %false= torch.constant.bool false + %count_include_pad = torch.constant.bool false + %divisor_override = torch.constant.int 9 + + %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}} + %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,192,35,35],f32> + return %3 : !torch.vtensor<[1,192,35,35],f32> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> { // CHECK: %[[VAL_0:.*]] = torch.constant.int 0 // CHECK: %[[VAL_1:.*]] = torch.constant.bool false