g++ build fix (#2778)

Introduced in 704cfdaf08 of @wu-s-john 

g++ compiler error: 

Pooling.cpp:177:13: error: explicit specialization in non-namespace
scope ‘class

Design looks good, g++ is just freaking out for no good reason.
Un-nesting the template classes fixes the error.

We don't have g++ CI. This hopefully happens infrequently enough that we
can just fix manually. My service to those folks who really like
building with g++... :)
pull/2776/head snapshot-20240120.1089
James Newling 2024-01-19 19:12:29 -08:00 committed by GitHub
parent 2f4924015d
commit 50ac3b1912
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 21 additions and 15 deletions

View File

@ -72,7 +72,8 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
return success();
}
static Value computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter,
static Value
computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter,
Value self, int64_t dimensionality, bool ceilMode,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts,
@ -167,21 +168,26 @@ static LogicalResult createPoolingOp(
}
namespace {
template <typename T> struct DimensionTraits {};
template <> struct DimensionTraits<AtenMaxPool2dOp> {
static constexpr int64_t Dim = 2;
// unused const variable warning suppression:
static_assert(Dim == Dim);
};
template <> struct DimensionTraits<AtenMaxPool3dOp> {
static constexpr int64_t Dim = 3;
// unused const variable warning suppression:
static_assert(Dim == Dim);
};
template <typename OpTy>
class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
private:
template <typename T> struct DimensionTraits;
template <> struct DimensionTraits<AtenMaxPool2dOp> {
static const int64_t Dim = 2;
};
template <> struct DimensionTraits<AtenMaxPool3dOp> {
static const int64_t Dim = 3;
};
static const int64_t Dim = DimensionTraits<OpTy>::Dim;
LogicalResult createPoolingMax3D(AtenMaxPool3dOp &op,
@ -327,9 +333,9 @@ public:
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
return success();
} else {
return createPoolingMax3D(op, adaptor, rewriter,
kernelSizeIntValues, strideInts, paddingInts,
dilationInts, ceilMode);
return createPoolingMax3D(op, adaptor, rewriter, kernelSizeIntValues,
strideInts, paddingInts, dilationInts,
ceilMode);
}
}
};