mirror of https://github.com/llvm/torch-mlir
[Stablehlo] fix reduce max init_value with -inf (#2064)
* [Stablehlo] fix reduce max init_value with -inf * updatepull/2097/head snapshot-20230507.831
parent
0d0366c319
commit
0096ceae2f
|
@ -56,7 +56,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
|||
if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
|
||||
if (elementTy.isa<mlir::FloatType>()) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType, {APFloat::getLargest(
|
||||
constType, {APFloat::getInf(
|
||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/true)});
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
|
|
|
@ -53,7 +53,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
|||
if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
|
||||
if (elementTy.isa<mlir::FloatType>()) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType, {APFloat::getLargest(
|
||||
constType, {APFloat::getInf(
|
||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/true)});
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
|
||||
// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
|
||||
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
|
||||
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
|
||||
|
@ -45,7 +45,7 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
|
|||
// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
|
||||
// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
|
||||
// CHECK: ^bb0(%[[VAL_8:.*]]: tensor<f32>, %[[VAL_9:.*]]: tensor<f32>):
|
||||
// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor<f32>
|
||||
|
@ -80,7 +80,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
|
|||
// CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
|
||||
// CHECK: %[[T5:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[T6:.*]] = arith.index_cast %[[DIM]] : index to i64
|
||||
|
|
Loading…
Reference in New Issue