[onnx] Fix edge condition for `onnx.ReduceMax` (#3598)

For `onnx.ReduceMax`, the length-0 (empty reduction set) case was incorrect
due to a copy-paste error: the identity value used the maximum (+inf /
INT_MAX) instead of the minimum (-inf / INT_MIN).
pull/3607/head
Rob Suderman 2024-08-07 10:32:28 -07:00 committed by GitHub
parent 8d95fe9eeb
commit 18139994e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 8 deletions

View File

@ -1328,7 +1328,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
auto dataTy = cast<Torch::BaseTensorType>(data.getType());
Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();
// If any of the input dims are 0 we set to the upper limit:
// If any of the input dims are 0 we set to the lower limit:
if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) &&
(llvm::any_of(dataTy.getSizes(),
[](int64_t d) { return d == Torch::kUnknownSize; }) ||
@ -1336,7 +1336,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
auto dty = dataTy.getDtype();
Value scalar;
if (FloatType fpTy = dyn_cast<FloatType>(dty)) {
auto inf = APFloat::getInf(fpTy.getFloatSemantics());
auto inf =
APFloat::getInf(fpTy.getFloatSemantics(), /*Negative=*/true);
scalar = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(),
@ -1344,14 +1345,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
}
if (IntegerType intTy = dyn_cast<IntegerType>(dty)) {
auto mx =
auto minInt =
intTy.isSigned()
? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
: APInt::getMaxValue(intTy.getIntOrFloatBitWidth());
? APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
: APInt::getMinValue(intTy.getIntOrFloatBitWidth());
scalar = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), torchIntTy,
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
mx.getSExtValue()));
minInt.getSExtValue()));
}
llvm::SmallVector<Value> fillDims;

View File

@ -644,7 +644,7 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
// CHECK-LABEL: func.func @test_reduce_max_empty_set_fp
func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000
// CHECK-DAG: %[[INF:.+]] = torch.constant.float 0xFFF0000000000000
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
@ -660,7 +660,7 @@ func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg
// CHECK-LABEL: func.func @test_reduce_max_empty_set_int
func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[INF:.+]] = torch.constant.int 2147483647
// CHECK-DAG: %[[INF:.+]] = torch.constant.int -2147483648
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4