From 43506726853b35ae9c253aa1d1c61b76ad9b4c13 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 7 Aug 2024 21:42:10 -0700 Subject: [PATCH] [torch] Add integer support for pooling operations (#3610) If we pass an integer type to the pooling operation we incorrectly pad with an integer value with causes downstream compilation failures. --- lib/Conversion/TorchToLinalg/Pooling.cpp | 30 ++++++++++++++---------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index ae1717bc2..bb19d403e 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -361,18 +361,29 @@ public: Type elementType = cast(self.getType()).getElementType(); + TypedAttr smallestValueAttr; + + if (auto fpty = dyn_cast(elementType)) { + smallestValueAttr = rewriter.getFloatAttr( + elementType, + APFloat::getInf(fpty.getFloatSemantics(), /*Negative=*/true)); + } else if (auto intTy = dyn_cast(elementType)) { + int64_t bw = intTy.getIntOrFloatBitWidth(); + smallestValueAttr = rewriter.getIntegerAttr( + elementType, intTy.isUnsigned() ? APInt::getMinValue(bw) + : APInt::getSignedMinValue(bw)); + } + + if (!smallestValueAttr) + return rewriter.notifyMatchFailure(op, "invalid element type"); + if constexpr (Dim == 1) { SmallVector outTensorShape; Value maxPool1d, paddedInput; - TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( - elementType, - APFloat::getInf( - cast(elementType).getFloatSemantics(), - /*Negative=*/true)); if (failed(createPoolingOp( op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, /*dimensionality=*/1, kernelSizeIntValues, strideInts, - paddingInts, dilationInts, smallestFPValueAttr, outTensorShape, + paddingInts, dilationInts, smallestValueAttr, outTensorShape, paddedInput, maxPool1d))) return rewriter.notifyMatchFailure(op, "unable to compute maxpool1d"); Type newResultType = this->getTypeConverter()->convertType(op.getType()); @@ -382,15 +393,10 @@ public: SmallVector outTensorShape; // `maxpool2d` contains the result of maxpool2d operation over the input. Value maxPool2d, paddedInput; - TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( - elementType, - APFloat::getInf( - cast(elementType).getFloatSemantics(), - /*Negative=*/true)); if (failed(createPoolingOp( op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, /*dimensionality=*/2, kernelSizeIntValues, strideInts, - paddingInts, dilationInts, smallestFPValueAttr, outTensorShape, + paddingInts, dilationInts, smallestValueAttr, outTensorShape, paddedInput, maxPool2d))) return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d"); Type newResultType = this->getTypeConverter()->convertType(op.getType());