mirror of https://github.com/llvm/torch-mlir
[Linalg] fix lowering reduce max with -inf (#2097)
parent
11a91b9d14
commit
ef6dae6ae2
|
@ -170,9 +170,8 @@ public:
|
|||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
||||
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
||||
elementType,
|
||||
APFloat::getLargest(
|
||||
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/true));
|
||||
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/true));
|
||||
SmallVector<Value, 4> outTensorShape;
|
||||
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
||||
Value maxPool2d, paddedInput;
|
||||
|
@ -248,9 +247,8 @@ public:
|
|||
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
||||
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
||||
elementType,
|
||||
APFloat::getLargest(
|
||||
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/true));
|
||||
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/true));
|
||||
Value maxPool2d, paddedInput;
|
||||
SmallVector<Value, 4> outTensorShape;
|
||||
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
|
||||
|
|
|
@ -122,9 +122,9 @@ public:
|
|||
loc,
|
||||
rewriter.getFloatAttr(
|
||||
inElementType,
|
||||
APFloat::getLargest(
|
||||
APFloat::getInf(
|
||||
inElementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
true)));
|
||||
/*Negative=*/true)));
|
||||
} else {
|
||||
fillValueMax = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(
|
||||
|
@ -213,7 +213,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
|
|||
return b.create<arith::ConstantOp>(
|
||||
loc, b.getFloatAttr(
|
||||
elementType,
|
||||
APFloat::getLargest(
|
||||
APFloat::getInf(
|
||||
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*Negative=*/true)));
|
||||
else if (elementType.isa<mlir::IntegerType>() &&
|
||||
|
|
|
@ -13,7 +13,7 @@ func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,
|
|||
%false = torch.constant.bool false
|
||||
// CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
|
||||
// CHECK: %[[C2:.*]] = torch_c.to_i64 %int2
|
||||
// CHECK: %[[NEUTRAL:.*]] = arith.constant -3.40282347E+38 : f32
|
||||
// CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
|
||||
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
|
||||
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index
|
||||
|
|
Loading…
Reference in New Issue