[Linalg] fix lowering reduce max with -inf (#2097)

pull/2103/head
Yuanqiang Liu 2023-05-09 00:17:49 +08:00 committed by GitHub
parent 11a91b9d14
commit ef6dae6ae2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 10 deletions

View File

@ -170,9 +170,8 @@ public:
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getLargest(
elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true));
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true));
SmallVector<Value, 4> outTensorShape;
// `maxpool2d` contains the result of maxpool2d operation over the input.
Value maxPool2d, paddedInput;
@ -248,9 +247,8 @@ public:
// `maxpool2d` contains the result of maxpool2d operation over the input.
auto smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getLargest(
elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true));
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true));
Value maxPool2d, paddedInput;
SmallVector<Value, 4> outTensorShape;
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(

View File

@ -122,9 +122,9 @@ public:
loc,
rewriter.getFloatAttr(
inElementType,
APFloat::getLargest(
APFloat::getInf(
inElementType.cast<mlir::FloatType>().getFloatSemantics(),
true)));
/*Negative=*/true)));
} else {
fillValueMax = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(
@ -213,7 +213,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
return b.create<arith::ConstantOp>(
loc, b.getFloatAttr(
elementType,
APFloat::getLargest(
APFloat::getInf(
elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true)));
else if (elementType.isa<mlir::IntegerType>() &&

View File

@ -13,7 +13,7 @@ func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,
%false = torch.constant.bool false
// CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
// CHECK: %[[C2:.*]] = torch_c.to_i64 %int2
// CHECK: %[[NEUTRAL:.*]] = arith.constant -3.40282347E+38 : f32
// CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index