mirror of https://github.com/llvm/torch-mlir
[torch] Add integer support for pooling operations (#3610)
If we pass an integer type to the pooling operation we incorrectly pad with an integer value with causes downstream compilation failures.pull/3617/head
parent
7f2a17e757
commit
4350672685
|
@ -361,18 +361,29 @@ public:
|
|||
|
||||
Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
|
||||
|
||||
TypedAttr smallestValueAttr;
|
||||
|
||||
if (auto fpty = dyn_cast<mlir::FloatType>(elementType)) {
|
||||
smallestValueAttr = rewriter.getFloatAttr(
|
||||
elementType,
|
||||
APFloat::getInf(fpty.getFloatSemantics(), /*Negative=*/true));
|
||||
} else if (auto intTy = dyn_cast<mlir::IntegerType>(elementType)) {
|
||||
int64_t bw = intTy.getIntOrFloatBitWidth();
|
||||
smallestValueAttr = rewriter.getIntegerAttr(
|
||||
elementType, intTy.isUnsigned() ? APInt::getMinValue(bw)
|
||||
: APInt::getSignedMinValue(bw));
|
||||
}
|
||||
|
||||
if (!smallestValueAttr)
|
||||
return rewriter.notifyMatchFailure(op, "invalid element type");
|
||||
|
||||
if constexpr (Dim == 1) {
|
||||
SmallVector<Value, 4> outTensorShape;
|
||||
Value maxPool1d, paddedInput;
|
||||
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
||||
elementType,
|
||||
APFloat::getInf(
|
||||
cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
||||
/*Negative=*/true));
|
||||
if (failed(createPoolingOp<linalg::PoolingNcwMaxOp>(
|
||||
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
|
||||
/*dimensionality=*/1, kernelSizeIntValues, strideInts,
|
||||
paddingInts, dilationInts, smallestFPValueAttr, outTensorShape,
|
||||
paddingInts, dilationInts, smallestValueAttr, outTensorShape,
|
||||
paddedInput, maxPool1d)))
|
||||
return rewriter.notifyMatchFailure(op, "unable to compute maxpool1d");
|
||||
Type newResultType = this->getTypeConverter()->convertType(op.getType());
|
||||
|
@ -382,15 +393,10 @@ public:
|
|||
SmallVector<Value, 4> outTensorShape;
|
||||
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
||||
Value maxPool2d, paddedInput;
|
||||
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
||||
elementType,
|
||||
APFloat::getInf(
|
||||
cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
||||
/*Negative=*/true));
|
||||
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
|
||||
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
|
||||
/*dimensionality=*/2, kernelSizeIntValues, strideInts,
|
||||
paddingInts, dilationInts, smallestFPValueAttr, outTensorShape,
|
||||
paddingInts, dilationInts, smallestValueAttr, outTensorShape,
|
||||
paddedInput, maxPool2d)))
|
||||
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
|
||||
Type newResultType = this->getTypeConverter()->convertType(op.getType());
|
||||
|
|
Loading…
Reference in New Issue