[onnx] Migrate `onnx.ReduceMax` to match `onnx.ReduceMin` (#2981)

This mostly copy-pastes the reduce minimum implementation to reduce max
to improve test coverage. We also improve the aten lowering for min/max
dim for unsigned types.
pull/2992/head
Rob Suderman 2024-03-06 16:48:21 -08:00 committed by GitHub
parent ea76dd12ba
commit a78659742a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 293 additions and 183 deletions

View File

@ -758,107 +758,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, resultType, operand, vAlpha, vScale, vInputScale);
return success();
});
patterns.onOp(
"ReduceMax", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
SmallVector<Value, 1> operands;
int64_t keepDims, noop_with_empty_axes;
if (binder.tensorOperandsList(operands) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
0))
return failure();
Value data = operands[0];
if (operands.size() == 1) {
if (noop_with_empty_axes == 0) {
MLIRContext *context = binder.op->getContext();
int rank =
data.getType().cast<Torch::ValueTensorType>().getSizes().size();
SmallVector<Value, 1> dims;
for (int i = 0; i < rank; i++) {
dims.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
Value dimsList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(context)), dims);
Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), keepDimsConstInt);
rewriter.replaceOpWithNewOp<Torch::AtenAmaxOp>(
binder.op, resultType, data, /*dim=*/dimsList,
/*keepdim=*/keepDimsBool);
} else {
rewriter.replaceOp(binder.op, data);
}
return success();
}
Value axes = operands[1];
SmallVector<Value> dimList;
Torch::BaseTensorType axesType =
axes.getType().cast<Torch::BaseTensorType>();
SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Type selectResultType = axesType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
auto sizes =
dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
int64_t adjustmentInt =
cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
Value adjustment = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
adjustmentInt));
for (int i = 0; i < sizes[0]; i++) {
// Go through the axes list and get each dim in the list
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, axes, zero, selectIndex);
Value dim = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
// deal with neg axis: if (axis < 0) axis += rank
Value isNegative =
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
isNegative);
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isNegative, adjustment);
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), dim, finalOffset);
dimList.push_back(finalDim);
}
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
dimList);
Value keepDimBool;
if (keepDims == 1) {
keepDimBool =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
} else {
keepDimBool =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
}
rewriter.replaceOpWithNewOp<Torch::AtenAmaxOp>(
binder.op, resultType, data, dimValueList, keepDimBool);
return success();
});
patterns.onOp(
"ReduceSum", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
@ -1102,6 +1001,159 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*dtype=*/noneVal);
return success();
});
patterns.onOp(
"ReduceMax", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
// AtenAmaxOp allows us to pass a list of dims
Torch::ValueTensorType resultType;
Value data;
Value axes;
int64_t keepDims;
int64_t noop_with_empty_axes;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
0))
return failure();
auto dataTy = cast<Torch::BaseTensorType>(data.getType());
Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();
// If any of the input dims are 0 we set to the upper limit:
if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) &&
(llvm::any_of(dataTy.getSizes(),
[](int64_t d) { return d == Torch::kUnknownSize; }) ||
keepDims)) {
auto dty = dataTy.getDtype();
Value scalar;
if (FloatType fpTy = dyn_cast<FloatType>(dty)) {
auto inf = APFloat::getInf(fpTy.getFloatSemantics());
scalar = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(),
inf.convertToDouble()));
}
if (IntegerType intTy = dyn_cast<IntegerType>(dty)) {
auto mx =
intTy.isSigned()
? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
: APInt::getMaxValue(intTy.getIntOrFloatBitWidth());
scalar = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), torchIntTy,
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
mx.getSExtValue()));
}
llvm::SmallVector<Value> fillDims;
for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) {
auto staticDim = resultType.getSizes()[i];
if (staticDim != Torch::kUnknownSize) {
fillDims.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), torchIntTy,
rewriter.getI64IntegerAttr(staticDim)));
continue;
}
Value iv = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
fillDims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), torchIntTy, data, iv));
}
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value fillDimsList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims);
rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
binder.op, resultType, fillDimsList, scalar, none, none, none,
none);
return success();
}
// Previous version of the operation had the axes as an attribute:
SmallVector<Value> axesList;
llvm::SmallVector<int64_t> axesAttr;
if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) {
for (int i = 0, s = axesAttr.size(); i < s; ++i) {
axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), torchIntTy,
rewriter.getI64IntegerAttr(axesAttr[i])));
}
}
// Extract the axes values from the axes operand:
if (!binder.tensorOperandAtIndex(axes, 1)) {
Torch::BaseTensorType axesType =
axes.getType().cast<Torch::BaseTensorType>();
SmallVector<int64_t> selectSizes{1};
Type selectResultType = axesType.getWithSizesAndDtype(
selectSizes, axesType.getOptionalDtype());
auto sizes = axesType.getSizes();
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
// Extract the value of each axes:
for (int i = 0; i < sizes[0]; i++) {
// Go through the axes list and get each dim in the list
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value extract = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, axes, zero, selectIndex);
Value dim = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
axesList.push_back(dim);
}
}
// Handle the noop case:
if (axesList.empty() && noop_with_empty_axes) {
rewriter.replaceOp(binder.op, data);
return success();
}
// Deal with case when no axes arg is passed but not a noop:
if (axesList.empty()) {
int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
.getSizes()
.size();
for (int i = 0; i < numDims; i++) {
Value curr = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
axesList.push_back(curr);
}
}
// Handle negative axis:
Value rankVal = rewriter.create<Torch::AtenDimOp>(binder.getLoc(),
torchIntTy, data);
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getI64IntegerAttr(0));
for (Value &axes : axesList) {
Value isNegative =
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axes, zero);
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
isNegative);
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isNegative, rankVal);
axes = rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), axes,
finalOffset);
}
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), Torch::ListType::get(torchIntTy), axesList);
Value keepDimBool =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
rewriter.replaceOpWithNewOp<Torch::AtenAmaxOp>(
binder.op, resultType, data, dimValueList, keepDimBool);
return success();
});
patterns.onOp(
"ReduceMin", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {

View File

@ -87,6 +87,7 @@ public:
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
Type inElementType = inputType.getElementType();
bool isUnsigned = false;
if (!inElementType.isa<mlir::FloatType>()) {
if (inElementType.isa<mlir::IntegerType>()) {
auto integerTy = op.getSelf()
@ -94,10 +95,7 @@ public:
.template cast<BaseTensorType>()
.getDtype()
.template dyn_cast<mlir::IntegerType>();
if (integerTy.isUnsigned())
return rewriter.notifyMatchFailure(
op, opName + " to linalg.* requires input element type "
"to be signed in case of integer");
isUnsigned = integerTy.isUnsigned();
} else {
return rewriter.notifyMatchFailure(
op, opName + " to linalg.* requires Float or Integer "
@ -130,12 +128,17 @@ public:
APFloat::getInf(
inElementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/isMax)));
} else {
} else if (!isUnsigned) {
auto width = inElementType.cast<mlir::IntegerType>().getWidth();
auto init = isMax ? APSInt::getSignedMinValue(width)
: APSInt::getSignedMaxValue(width);
fillValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(inElementType, init));
} else if (isUnsigned) {
auto width = inElementType.cast<mlir::IntegerType>().getWidth();
auto init = isMax ? APInt::getMinValue(width) : APInt::getMaxValue(width);
fillValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(inElementType, init));
}
Value filledTensorVal =
@ -193,13 +196,25 @@ public:
} else {
arith::CmpIPredicate predType;
if (isMax) {
predType = arith::CmpIPredicate::sgt;
resultVal = rewriter.create<arith::MaxSIOp>(nestedLoc, newValue,
oldValue);
predType = isUnsigned ? arith::CmpIPredicate::ugt
: arith::CmpIPredicate::sgt;
if (isUnsigned) {
resultVal = rewriter.create<arith::MaxUIOp>(nestedLoc, newValue,
oldValue);
} else {
resultVal = rewriter.create<arith::MaxSIOp>(nestedLoc, newValue,
oldValue);
}
} else {
predType = arith::CmpIPredicate::slt;
resultVal = rewriter.create<arith::MinSIOp>(nestedLoc, newValue,
oldValue);
predType = isUnsigned ? arith::CmpIPredicate::ult
: arith::CmpIPredicate::slt;
if (isUnsigned) {
resultVal = rewriter.create<arith::MinUIOp>(nestedLoc, newValue,
oldValue);
} else {
resultVal = rewriter.create<arith::MinSIOp>(nestedLoc, newValue,
oldValue);
}
}
predicate = rewriter.create<arith::CmpIOp>(nestedLoc, predType,
newValue, oldValue);

View File

@ -71,8 +71,8 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
}
Type resultType = tensorType.getWithSizesAndDtype(
sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
: llvm::ArrayRef(sizes),
!tensorType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
: llvm::ArrayRef(sizes),
tensorType.getOptionalDtype());
return resultType;
}

View File

@ -1515,9 +1515,6 @@ ONNX_XFAIL_SET = {
"BroadcastToModule_basic",
"ExpandModule_basic",
"MoveDimIntNegativeIndexModule_basic",
"ReduceAmaxKeepDim_basic",
"ReduceMaxKeepDimReturnBoth_basic",
"ReduceMaxNegativeDim_basic",
"ViewSizeFromOtherTensor_basic",
# Failure - onnx_export
@ -2122,18 +2119,8 @@ ONNX_XFAIL_SET = {
"TriuBroadcastModule_basic",
"TriuModule_basic",
# Failure - rankless return
"ReduceAmaxMultiDim_basic",
"ReduceAmaxOutOfOrderDim_basic",
"ReduceAmaxSingleDim_basic",
"ReduceMaxAllDims_basic",
"ReduceMaxAlongDimNegative_basic",
"ReduceMaxAlongDimSignedInt_basic",
# Failure - incorrect dtype
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMaxAlongDim_basic",
"ReduceMaxFloatModule_basic",
"ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic",
# Failure - torch.aten.view lower
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",

View File

@ -747,65 +747,121 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
// -----
// CHECK-LABEL: func.func @test_reduce_max_keepdims_example
func.func @test_reduce_max_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[RANK:.*]] = torch.constant.int 3
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[SELECT_DIM0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM0:.*]] = torch.aten.item %[[SELECT_DIM0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[LTZERO_0:.*]] = torch.aten.lt.int %[[ITEM0]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[ISNEG_0:.*]] = torch.aten.Int.bool %[[LTZERO_0]] : !torch.bool -> !torch.int
// CHECK: %[[ADJUSTMENT_0:.*]] = torch.aten.mul.int %[[ISNEG_0]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[FINAL_0:.*]] = torch.aten.add.int %[[ITEM0]], %[[ADJUSTMENT_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[SELECT_DIM1:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM1:.*]] = torch.aten.item %[[SELECT_DIM1]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[LTZERO_1:.*]] = torch.aten.lt.int %[[ITEM1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[ISNEG_1:.*]] = torch.aten.Int.bool %[[LTZERO_1]] : !torch.bool -> !torch.int
// CHECK: %[[ADJUSTMENT_1:.*]] = torch.aten.mul.int %[[ISNEG_1]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[FINAL_1:.*]] = torch.aten.add.int %[[ITEM1]], %[[ADJUSTMENT_1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[FINAL_0]], %[[FINAL_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[KEEPDIMS:.*]] = torch.constant.bool true
// CHECK: torch.aten.amax %arg0, %[[DIMS]], %[[KEEPDIMS]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,1,1],f32>
%0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,1,1],f32>
return %0 : !torch.vtensor<[3,1,1],f32>
}
// CHECK-LABEL: func.func @test_reduce_max_empty_set_fp
func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]]
// CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]]
// CHECK: return %[[FULL]]
%0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32>
return %0 : !torch.vtensor<[2,1,4],f32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_max_default_axes_keepdim_example
func.func @test_reduce_max_default_axes_keepdim_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[INT1_0:.*]] = torch.constant.int 1
// CHECK: %[[KEEPDIMS:.*]] = torch.aten.Bool.int %[[INT1_0]] : !torch.int -> !torch.bool
// CHECK: torch.aten.amax %arg0, %[[DIMS]], %[[KEEPDIMS]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,1],f32>
%0 = torch.operator "onnx.ReduceMax"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32>
return %0 : !torch.vtensor<[1,1,1],f32>
}
// CHECK-LABEL: func.func @test_reduce_max_empty_set_int
func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[INF:.+]] = torch.constant.int 2147483647
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]]
// CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]]
// CHECK: return %[[FULL]]
%0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32>
return %0 : !torch.vtensor<[2,1,4],si32>
}
// -----
// CHECK-LABEL: func.func @test_reduce_max_do_not_keepdims_example
func.func @test_reduce_max_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[RANK:.*]] = torch.constant.int 3
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[SELECT_DIM:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT_DIM]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[LTZERO:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[ISNEG:.*]] = torch.aten.Int.bool %[[LTZERO]] : !torch.bool -> !torch.int
// CHECK: %[[ADJUSTMENT:.*]] = torch.aten.mul.int %[[ISNEG]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[FINAL:.*]] = torch.aten.add.int %[[ITEM]], %[[ADJUSTMENT]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[FINAL]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.aten.amax %arg0, %[[DIMS]], %[[FALSE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,2],f32>
%0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
return %0 : !torch.vtensor<[3,2],f32>
}
// CHECK-LABEL: func.func @test_reduce_max_bool_inputs
func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[IDX:.+]] = torch.constant.int 0
// CHECK: %[[SZ:.+]] = torch.constant.int 0
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]]
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
// CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
// CHECK: %[[C0:.+]] = torch.constant.int 0
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,1],i1>
// CHECK: return %[[AMAX]] : !torch.vtensor<[4,1],i1>
%0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1>
return %0 : !torch.vtensor<[4,1],i1>
}
// -----
// CHECK-LABEL: func.func @test_reduce_max_bool_inputs_nokeepdims
func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[IDX:.+]] = torch.constant.int 0
// CHECK: %[[SZ:.+]] = torch.constant.int 0
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]]
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
// CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
// CHECK: %[[C0:.+]] = torch.constant.int 0
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4],i1>
// CHECK: return %[[AMAX]] : !torch.vtensor<[4],i1>
%0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1>
return %0 : !torch.vtensor<[4],i1>
}
// -----
// CHECK-LABEL: func.func @test_reduce_max_all_dims_default
func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[I0:.+]] = torch.constant.int 0
// CHECK: %[[I1:.+]] = torch.constant.int 1
// CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
// CHECK: %[[C0:.+]] = torch.constant.int 0
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]]
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[MAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[],i1>
// CHECK: return %[[MAX]] : !torch.vtensor<[],i1>
%0 = torch.operator "onnx.ReduceMax"(%arg0) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1>
return %0 : !torch.vtensor<[],i1>
}
// -----
func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4],i1>
// CHECK: return %[[AMAX]]
%0 = torch.operator "onnx.ReduceMax"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.axes=[1 : si64]} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1>
return %0 : !torch.vtensor<[4],i1>
}
// -----
@ -1064,8 +1120,8 @@ func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1
// -----
// CHECK-LABEL: func.func @test_reduce_all_dims_default
func.func @test_reduce_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-LABEL: func.func @test_reduce_min_all_dims_default
func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[I0:.+]] = torch.constant.int 0
// CHECK: %[[I1:.+]] = torch.constant.int 1
// CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int