mirror of https://github.com/llvm/torch-mlir
[onnx] Fix edge condition for `onnx.ReduceMax` (#3598)
For length-0 on `onnx.ReduceMax` the length 0 case was incorrect due to a copy paste error.pull/3607/head
parent
8d95fe9eeb
commit
18139994e8
|
@ -1328,7 +1328,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
auto dataTy = cast<Torch::BaseTensorType>(data.getType());
|
||||
Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();
|
||||
|
||||
// If any of the input dims are 0 we set to the upper limit:
|
||||
// If any of the input dims are 0 we set to the lower limit:
|
||||
if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) &&
|
||||
(llvm::any_of(dataTy.getSizes(),
|
||||
[](int64_t d) { return d == Torch::kUnknownSize; }) ||
|
||||
|
@ -1336,7 +1336,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
auto dty = dataTy.getDtype();
|
||||
Value scalar;
|
||||
if (FloatType fpTy = dyn_cast<FloatType>(dty)) {
|
||||
auto inf = APFloat::getInf(fpTy.getFloatSemantics());
|
||||
auto inf =
|
||||
APFloat::getInf(fpTy.getFloatSemantics(), /*Negative=*/true);
|
||||
scalar = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
rewriter.getFloatAttr(rewriter.getF64Type(),
|
||||
|
@ -1344,14 +1345,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
}
|
||||
|
||||
if (IntegerType intTy = dyn_cast<IntegerType>(dty)) {
|
||||
auto mx =
|
||||
auto minInt =
|
||||
intTy.isSigned()
|
||||
? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
|
||||
: APInt::getMaxValue(intTy.getIntOrFloatBitWidth());
|
||||
? APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
|
||||
: APInt::getMinValue(intTy.getIntOrFloatBitWidth());
|
||||
scalar = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), torchIntTy,
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
mx.getSExtValue()));
|
||||
minInt.getSExtValue()));
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value> fillDims;
|
||||
|
|
|
@ -644,7 +644,7 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
|
|||
|
||||
// CHECK-LABEL: func.func @test_reduce_max_empty_set_fp
|
||||
func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000
|
||||
// CHECK-DAG: %[[INF:.+]] = torch.constant.float 0xFFF0000000000000
|
||||
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
|
||||
|
@ -660,7 +660,7 @@ func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg
|
|||
|
||||
// CHECK-LABEL: func.func @test_reduce_max_empty_set_int
|
||||
func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-DAG: %[[INF:.+]] = torch.constant.int 2147483647
|
||||
// CHECK-DAG: %[[INF:.+]] = torch.constant.int -2147483648
|
||||
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
|
||||
|
|
Loading…
Reference in New Issue