mirror of https://github.com/llvm/torch-mlir
Fix torchToTosa lowering for avgpool2d to handle unsupported parameters (#3822)
The existing TorchToTosa lowering logic for `torch.aten.avg_pool2d` doesn't handle some unsupported properties well, leading to a silent wrong answer(SWA) when we go through `torch-backend-to-tosa-backend-pipeline.` For instance, with the existing TOSA avgpool2d specification, we can not represent `count_include_pad` and `divisor_override,` so during TorchToTosa lowering, we silently ignore these properties which leads to SWA in some cases—the fix captured in this change errors for unsupported scenarios. For details on `count_include_pad` and `divisor_override,` please see the below link. https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html --------- Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>pull/3848/head
parent
032a636c35
commit
7f9f99c6f8
|
@ -5466,6 +5466,28 @@ public:
|
||||||
DenseI64ArrayAttr &kernel,
|
DenseI64ArrayAttr &kernel,
|
||||||
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
|
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
|
||||||
Type &outputTy) const override {
|
Type &outputTy) const override {
|
||||||
|
|
||||||
|
// Currently, we can not represent `count_include_pad` with the existing
|
||||||
|
// TOSA AvgPool2d specification. Without the below check, we produce silent
|
||||||
|
// wrong answers (SWA) when the `count_include_pad` value is `true.`
|
||||||
|
bool countIncludePad;
|
||||||
|
if (!matchPattern(op.getCountIncludePad(),
|
||||||
|
m_TorchConstantBool(&countIncludePad)) ||
|
||||||
|
countIncludePad) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
|
||||||
|
"`count_include_pad` value should be `False`.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Currently, we can not represent `divisor_override` with the existing TOSA
|
||||||
|
// AvgPool2d specification. Without the below check, we produce silent wrong
|
||||||
|
// answers (SWA) when the `divisor_override` value is other than `None.`
|
||||||
|
if (!isa<Torch::NoneType>(op.getDivisorOverride().getType())) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Unsupported `divisor_override` value, for tosa AvgPool2dOp "
|
||||||
|
"`divisor_override` value should be `None`.");
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<int64_t, 2> dilationArray{1, 1};
|
SmallVector<int64_t, 2> dilationArray{1, 1};
|
||||||
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
|
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
|
||||||
tosa::AvgPool2dOp>(
|
tosa::AvgPool2dOp>(
|
||||||
|
|
|
@ -858,31 +858,29 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
|
||||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
|
||||||
// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
|
// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
|
||||||
// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
|
// CHECK: %[[VAL_6:.*]] = torch.constant.none
|
||||||
// CHECK: %[[VAL_7:.*]] = torch.constant.none
|
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||||
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_10]] : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32>
|
||||||
// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_11]] : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32>
|
// CHECK: %[[VAL_12:.*]] = tosa.avg_pool2d %[[VAL_11]] {acc_type = f32, kernel = array<i64: 7, 7>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
|
||||||
// CHECK: %[[VAL_13:.*]] = tosa.avg_pool2d %[[VAL_12]] {acc_type = f32, kernel = array<i64: 7, 7>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
|
// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||||
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_12]], %[[VAL_13]] : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32>
|
||||||
// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32>
|
// CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32>
|
||||||
// CHECK: %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32>
|
// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32>
|
||||||
// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32>
|
// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,1,1],f32>
|
||||||
// CHECK: return %[[VAL_17]] : !torch.vtensor<[1,512,1,1],f32>
|
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) -> !torch.vtensor<[1,512,1,1],f32> {
|
func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) -> !torch.vtensor<[1,512,1,1],f32> {
|
||||||
%int7 = torch.constant.int 7
|
%int7 = torch.constant.int 7
|
||||||
%int1 = torch.constant.int 1
|
%int1 = torch.constant.int 1
|
||||||
%int0 = torch.constant.int 0
|
%int0 = torch.constant.int 0
|
||||||
%false = torch.constant.bool false
|
%false = torch.constant.bool false
|
||||||
%true = torch.constant.bool true
|
|
||||||
%none = torch.constant.none
|
%none = torch.constant.none
|
||||||
%kernel = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
|
%kernel = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
%stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
%stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
%padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
%padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
%0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %true, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
|
%0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %false, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
|
||||||
return %0 : !torch.vtensor<[1,512,1,1],f32>
|
return %0 : !torch.vtensor<[1,512,1,1],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2001,6 +1999,42 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%false= torch.constant.bool false
|
||||||
|
%count_include_pad = torch.constant.bool true
|
||||||
|
%divisor_override = torch.constant.none
|
||||||
|
|
||||||
|
%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
|
||||||
|
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
|
||||||
|
return %3 : !torch.vtensor<[1,192,35,35],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%false= torch.constant.bool false
|
||||||
|
%count_include_pad = torch.constant.bool false
|
||||||
|
%divisor_override = torch.constant.int 9
|
||||||
|
|
||||||
|
%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
|
||||||
|
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,192,35,35],f32>
|
||||||
|
return %3 : !torch.vtensor<[1,192,35,35],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> {
|
// CHECK-LABEL: func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> {
|
||||||
// CHECK: %[[VAL_0:.*]] = torch.constant.int 0
|
// CHECK: %[[VAL_0:.*]] = torch.constant.int 0
|
||||||
// CHECK: %[[VAL_1:.*]] = torch.constant.bool false
|
// CHECK: %[[VAL_1:.*]] = torch.constant.bool false
|
||||||
|
|
Loading…
Reference in New Issue