[Linalg] fix lowering reduce max with -inf (#2097)

pull/2103/head
Yuanqiang Liu 2023-05-09 00:17:49 +08:00 committed by GitHub
parent 11a91b9d14
commit ef6dae6ae2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 10 deletions

View File

@ -170,8 +170,7 @@ public:
Type elementType = self.getType().cast<RankedTensorType>().getElementType(); Type elementType = self.getType().cast<RankedTensorType>().getElementType();
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType, elementType,
APFloat::getLargest( APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true)); /*Negative=*/true));
SmallVector<Value, 4> outTensorShape; SmallVector<Value, 4> outTensorShape;
// `maxpool2d` contains the result of maxpool2d operation over the input. // `maxpool2d` contains the result of maxpool2d operation over the input.
@ -248,8 +247,7 @@ public:
// `maxpool2d` contains the result of maxpool2d operation over the input. // `maxpool2d` contains the result of maxpool2d operation over the input.
auto smallestFPValueAttr = rewriter.getFloatAttr( auto smallestFPValueAttr = rewriter.getFloatAttr(
elementType, elementType,
APFloat::getLargest( APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true)); /*Negative=*/true));
Value maxPool2d, paddedInput; Value maxPool2d, paddedInput;
SmallVector<Value, 4> outTensorShape; SmallVector<Value, 4> outTensorShape;

View File

@ -122,9 +122,9 @@ public:
loc, loc,
rewriter.getFloatAttr( rewriter.getFloatAttr(
inElementType, inElementType,
APFloat::getLargest( APFloat::getInf(
inElementType.cast<mlir::FloatType>().getFloatSemantics(), inElementType.cast<mlir::FloatType>().getFloatSemantics(),
true))); /*Negative=*/true)));
} else { } else {
fillValueMax = rewriter.create<arith::ConstantOp>( fillValueMax = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr( loc, rewriter.getIntegerAttr(
@ -213,7 +213,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
return b.create<arith::ConstantOp>( return b.create<arith::ConstantOp>(
loc, b.getFloatAttr( loc, b.getFloatAttr(
elementType, elementType,
APFloat::getLargest( APFloat::getInf(
elementType.cast<mlir::FloatType>().getFloatSemantics(), elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true))); /*Negative=*/true)));
else if (elementType.isa<mlir::IntegerType>() && else if (elementType.isa<mlir::IntegerType>() &&

View File

@ -13,7 +13,7 @@ func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,
%false = torch.constant.bool false %false = torch.constant.bool false
// CHECK: %[[C1:.*]] = torch_c.to_i64 %int1 // CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
// CHECK: %[[C2:.*]] = torch_c.to_i64 %int2 // CHECK: %[[C2:.*]] = torch_c.to_i64 %int2
// CHECK: %[[NEUTRAL:.*]] = arith.constant -3.40282347E+38 : f32 // CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6] // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index // CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index