[Stablehlo] fix reduce max init_value with -inf (#2064)

* [Stablehlo] fix reduce max init_value with -inf

* update
pull/2097/head snapshot-20230507.831
Yuanqiang Liu 2023-05-07 03:05:51 +08:00 committed by GitHub
parent 0d0366c319
commit 0096ceae2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 5 additions and 5 deletions

View File

@ -56,7 +56,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getLargest(
constType, {APFloat::getInf(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,

View File

@ -53,7 +53,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getLargest(
constType, {APFloat::getInf(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,

View File

@ -13,7 +13,7 @@
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
@ -45,7 +45,7 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
@ -80,7 +80,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: %[[T5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[T6:.*]] = arith.index_cast %[[DIM]] : index to i64