mirror of https://github.com/llvm/torch-mlir
Support default padding case for tosa::AvgPool in the presence of count_include_pad (#3868)
Essentially, as part of my earlier
[change](7f9f99c6f8
)
, I didn't consider the `padding` value while erroring out for
unsupported `count_include_pad` during `torch-to-tosa` lowering for
AvgPool2d. The fix captured in this change addresses this. Please see
[issue](https://github.com/llvm/torch-mlir/issues/3862) for more details
on this.
Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>
main
parent
cd38ecf6c2
commit
30c519369e
|
@ -5549,6 +5549,26 @@ static LogicalResult getOutputTypeAndPoolingParameters(
|
|||
std::is_same<AtenOpT, AtenAvgPool1dOp>())
|
||||
paddingInts.push_back(0);
|
||||
|
||||
if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
|
||||
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
|
||||
// Currently, we can not represent `count_include_pad` with the existing
|
||||
// TOSA AvgPool2d specification. Without the below check, we produce silent
|
||||
// wrong answer (SWA) when the `count_include_pad` value is `true.`
|
||||
//
|
||||
// Note: We need to check for `count_include_pad` only when the `padding`
|
||||
// value is non-zero.
|
||||
bool countIncludePad;
|
||||
if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
|
||||
(!matchPattern(op.getCountIncludePad(),
|
||||
m_TorchConstantBool(&countIncludePad)) ||
|
||||
|
||||
countIncludePad)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unsupported `count_include_pad` value, for tosa AvgPool "
|
||||
"`count_include_pad` value should be `False`.");
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4> padArr = {paddingInts[0], paddingInts[0],
|
||||
paddingInts[1], paddingInts[1]};
|
||||
kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts);
|
||||
|
@ -5677,18 +5697,6 @@ public:
|
|||
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
|
||||
Type &outputTy) const override {
|
||||
|
||||
// Currently, we can not represent `count_include_pad` with the existing
|
||||
// TOSA AvgPool2d specification. Without the below check, we produce silent
|
||||
// wrong answers (SWA) when the `count_include_pad` value is `true.`
|
||||
bool countIncludePad;
|
||||
if (!matchPattern(op.getCountIncludePad(),
|
||||
m_TorchConstantBool(&countIncludePad)) ||
|
||||
countIncludePad) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
|
||||
"`count_include_pad` value should be `False`.");
|
||||
}
|
||||
|
||||
// Currently, we can not represent `divisor_override` with the existing TOSA
|
||||
// AvgPool2d specification. Without the below check, we produce silent wrong
|
||||
// answers (SWA) when the `divisor_override` value is other than `None.`
|
||||
|
@ -5737,7 +5745,7 @@ public:
|
|||
// Expected a rank 3 input tensor
|
||||
if (selfTy.getRank() != 3)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Input tensor for MaxPool1d should have rank 3");
|
||||
op, "Input tensor for AvgPool1d should have rank 3");
|
||||
|
||||
// Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
|
||||
SmallVector<int64_t> rank4Shape(selfShape);
|
||||
|
@ -5748,18 +5756,6 @@ public:
|
|||
selfTy.getElementType()),
|
||||
self, rewriter.getDenseI64ArrayAttr(rank4Shape));
|
||||
|
||||
// Currently, we can not represent `count_include_pad` with the existing
|
||||
// TOSA AvgPool2d specification. Without the below check, we produce silent
|
||||
// wrong answers (SWA) when the `count_include_pad` value is `true.`
|
||||
bool countIncludePad;
|
||||
if (!matchPattern(op.getCountIncludePad(),
|
||||
m_TorchConstantBool(&countIncludePad)) ||
|
||||
countIncludePad) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
|
||||
"`count_include_pad` value should be `False`.");
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 2> dilationArray{1, 1};
|
||||
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
|
||||
tosa::AvgPool2dOp>(
|
||||
|
|
|
@ -1736,6 +1736,12 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
|||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
|
||||
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
|
||||
"ElementwiseCosIntModule_basic",
|
||||
"ElementwiseReciprocalIntModule_basic",
|
||||
|
@ -2316,6 +2322,7 @@ TOSA_PASS_SET = {
|
|||
"ReshapeExpandModule_basic",
|
||||
"ReturnThreeTensorFloat32_basic",
|
||||
"ReturnTwoTensorF32I64_basic",
|
||||
"ResNet18StaticModule_basic",
|
||||
"RsubFloatModule_basic",
|
||||
"RsubFloatModule_noalpha_basic",
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
|
@ -3869,26 +3876,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ViewSizeFromOtherTensor_basic",
|
||||
"VisionTransformerModule_basic",
|
||||
"ZerosLikeModule_falsePinMemory",
|
||||
# count_include_pad and divisor_override check in TOSA AvgPool2d
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
|
||||
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
|
||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
|
||||
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||
"ResNet18Module_basic",
|
||||
"ResNet18StaticModule_basic",
|
||||
"MobilenetV3Module_basic",
|
||||
# Unexpected failures due to new PyTorch version update
|
||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
|
||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
|
||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dDynamicNoBatch_basic",
|
||||
"AdaptiveAvgPool2dDynamic_basic",
|
||||
"CrossEntropyLossModule_basic",
|
||||
|
|
|
@ -2424,3 +2424,18 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
|
|||
%0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32>
|
||||
return %0 : !torch.vtensor<[2,12],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int3 = torch.constant.int 3
|
||||
%false = torch.constant.bool false
|
||||
%count_include_pad = torch.constant.bool true
|
||||
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
|
||||
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
|
||||
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
|
||||
return %3 : !torch.vtensor<[1,512,10],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue