mirror of https://github.com/llvm/torch-mlir
[Linalg] fix lowering reduce max with -inf (#2097)
parent
11a91b9d14
commit
ef6dae6ae2
|
@ -170,8 +170,7 @@ public:
|
||||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
||||||
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
||||||
elementType,
|
elementType,
|
||||||
APFloat::getLargest(
|
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||||
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
|
||||||
/*Negative=*/true));
|
/*Negative=*/true));
|
||||||
SmallVector<Value, 4> outTensorShape;
|
SmallVector<Value, 4> outTensorShape;
|
||||||
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
||||||
|
@ -248,8 +247,7 @@ public:
|
||||||
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
||||||
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
auto smallestFPValueAttr = rewriter.getFloatAttr(
|
||||||
elementType,
|
elementType,
|
||||||
APFloat::getLargest(
|
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||||
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
|
||||||
/*Negative=*/true));
|
/*Negative=*/true));
|
||||||
Value maxPool2d, paddedInput;
|
Value maxPool2d, paddedInput;
|
||||||
SmallVector<Value, 4> outTensorShape;
|
SmallVector<Value, 4> outTensorShape;
|
||||||
|
|
|
@ -122,9 +122,9 @@ public:
|
||||||
loc,
|
loc,
|
||||||
rewriter.getFloatAttr(
|
rewriter.getFloatAttr(
|
||||||
inElementType,
|
inElementType,
|
||||||
APFloat::getLargest(
|
APFloat::getInf(
|
||||||
inElementType.cast<mlir::FloatType>().getFloatSemantics(),
|
inElementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||||
true)));
|
/*Negative=*/true)));
|
||||||
} else {
|
} else {
|
||||||
fillValueMax = rewriter.create<arith::ConstantOp>(
|
fillValueMax = rewriter.create<arith::ConstantOp>(
|
||||||
loc, rewriter.getIntegerAttr(
|
loc, rewriter.getIntegerAttr(
|
||||||
|
@ -213,7 +213,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
|
||||||
return b.create<arith::ConstantOp>(
|
return b.create<arith::ConstantOp>(
|
||||||
loc, b.getFloatAttr(
|
loc, b.getFloatAttr(
|
||||||
elementType,
|
elementType,
|
||||||
APFloat::getLargest(
|
APFloat::getInf(
|
||||||
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||||
/*Negative=*/true)));
|
/*Negative=*/true)));
|
||||||
else if (elementType.isa<mlir::IntegerType>() &&
|
else if (elementType.isa<mlir::IntegerType>() &&
|
||||||
|
|
|
@ -13,7 +13,7 @@ func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,
|
||||||
%false = torch.constant.bool false
|
%false = torch.constant.bool false
|
||||||
// CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
|
// CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
|
||||||
// CHECK: %[[C2:.*]] = torch_c.to_i64 %int2
|
// CHECK: %[[C2:.*]] = torch_c.to_i64 %int2
|
||||||
// CHECK: %[[NEUTRAL:.*]] = arith.constant -3.40282347E+38 : f32
|
// CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
|
||||||
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
|
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
|
||||||
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index
|
// CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index
|
||||||
|
|
Loading…
Reference in New Issue