Support default padding case for tosa::AvgPool in the presence of count_include_pad (#3868)

Essentially, as part of my earlier
[change](7f9f99c6f8)
, I didn't consider the `padding` value while erroring out for
unsupported `count_include_pad` during `torch-to-tosa` lowering for
AvgPool2d. The fix captured in this change addresses this. Please see
[issue](https://github.com/llvm/torch-mlir/issues/3862) for more details
on this.

Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>
main
Hanumanth 2024-11-12 16:48:20 -05:00 committed by GitHub
parent cd38ecf6c2
commit 30c519369e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 40 deletions

View File

@ -5549,6 +5549,26 @@ static LogicalResult getOutputTypeAndPoolingParameters(
std::is_same<AtenOpT, AtenAvgPool1dOp>())
paddingInts.push_back(0);
if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answer (SWA) when the `count_include_pad` value is `true.`
//
// Note: We need to check for `count_include_pad` only when the `padding`
// value is non-zero.
bool countIncludePad;
if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
(!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||
countIncludePad)) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool "
"`count_include_pad` value should be `False`.");
}
}
SmallVector<int64_t, 4> padArr = {paddingInts[0], paddingInts[0],
paddingInts[1], paddingInts[1]};
kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts);
@ -5677,18 +5697,6 @@ public:
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
Type &outputTy) const override {
// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answers (SWA) when the `count_include_pad` value is `true.`
bool countIncludePad;
if (!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||
countIncludePad) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
"`count_include_pad` value should be `False`.");
}
// Currently, we can not represent `divisor_override` with the existing TOSA
// AvgPool2d specification. Without the below check, we produce silent wrong
// answers (SWA) when the `divisor_override` value is other than `None.`
@ -5737,7 +5745,7 @@ public:
// Expected a rank 3 input tensor
if (selfTy.getRank() != 3)
return rewriter.notifyMatchFailure(
op, "Input tensor for MaxPool1d should have rank 3");
op, "Input tensor for AvgPool1d should have rank 3");
// Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
SmallVector<int64_t> rank4Shape(selfShape);
@ -5748,18 +5756,6 @@ public:
selfTy.getElementType()),
self, rewriter.getDenseI64ArrayAttr(rank4Shape));
// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answers (SWA) when the `count_include_pad` value is `true.`
bool countIncludePad;
if (!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||
countIncludePad) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
"`count_include_pad` value should be `False`.");
}
SmallVector<int64_t, 2> dilationArray{1, 1};
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
tosa::AvgPool2dOp>(

View File

@ -1736,6 +1736,12 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
@ -2316,6 +2322,7 @@ TOSA_PASS_SET = {
"ReshapeExpandModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"ResNet18StaticModule_basic",
"RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
@ -3869,26 +3876,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ViewSizeFromOtherTensor_basic",
"VisionTransformerModule_basic",
"ZerosLikeModule_falsePinMemory",
# count_include_pad and divisor_override check in TOSA AvgPool2d
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
"MobilenetV3Module_basic",
# Unexpected failures due to new PyTorch version update
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dDynamicNoBatch_basic",
"AdaptiveAvgPool2dDynamic_basic",
"CrossEntropyLossModule_basic",

View File

@ -2424,3 +2424,18 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
%0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32>
return %0 : !torch.vtensor<[2,12],f32>
}
// -----
func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false = torch.constant.bool false
%count_include_pad = torch.constant.bool true
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
return %3 : !torch.vtensor<[1,512,10],f32>
}