[torch] Add integer support for pooling operations (#3610)

If we pass an integer type to the pooling operation we incorrectly pad
with an integer value with causes downstream compilation failures.
pull/3617/head
Rob Suderman 2024-08-07 21:42:10 -07:00 committed by GitHub
parent 7f2a17e757
commit 4350672685
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 18 additions and 12 deletions

View File

@ -361,18 +361,29 @@ public:
Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
TypedAttr smallestValueAttr;
if (auto fpty = dyn_cast<mlir::FloatType>(elementType)) {
smallestValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(fpty.getFloatSemantics(), /*Negative=*/true));
} else if (auto intTy = dyn_cast<mlir::IntegerType>(elementType)) {
int64_t bw = intTy.getIntOrFloatBitWidth();
smallestValueAttr = rewriter.getIntegerAttr(
elementType, intTy.isUnsigned() ? APInt::getMinValue(bw)
: APInt::getSignedMinValue(bw));
}
if (!smallestValueAttr)
return rewriter.notifyMatchFailure(op, "invalid element type");
if constexpr (Dim == 1) {
SmallVector<Value, 4> outTensorShape;
Value maxPool1d, paddedInput;
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(
cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true));
if (failed(createPoolingOp<linalg::PoolingNcwMaxOp>(
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
/*dimensionality=*/1, kernelSizeIntValues, strideInts,
paddingInts, dilationInts, smallestFPValueAttr, outTensorShape,
paddingInts, dilationInts, smallestValueAttr, outTensorShape,
paddedInput, maxPool1d)))
return rewriter.notifyMatchFailure(op, "unable to compute maxpool1d");
Type newResultType = this->getTypeConverter()->convertType(op.getType());
@ -382,15 +393,10 @@ public:
SmallVector<Value, 4> outTensorShape;
// `maxpool2d` contains the result of maxpool2d operation over the input.
Value maxPool2d, paddedInput;
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(
cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true));
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
/*dimensionality=*/2, kernelSizeIntValues, strideInts,
paddingInts, dilationInts, smallestFPValueAttr, outTensorShape,
paddingInts, dilationInts, smallestValueAttr, outTensorShape,
paddedInput, maxPool2d)))
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
Type newResultType = this->getTypeConverter()->convertType(op.getType());