mirror of https://github.com/llvm/torch-mlir
[onnx] Migrate `onnx.ReduceMax` to match `onnx.ReduceMin` (#2981)

This mostly copy-pastes the reduce-minimum implementation over to reduce-max to improve test coverage. We also improve the aten lowering of min/max.dim for unsigned integer types.

Branch: pull/2992/head
parent ea76dd12ba
commit a78659742a
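For orientation: both the opset-1 pattern removed below and the opset-13 pattern added in its place lower `onnx.ReduceMax` to `torch.aten.amax`. A minimal before/after sketch of the common case, assuming a rank-2 float input and one reduction axis (the SSA names are illustrative, not taken from the patch):

    // Before conversion: axes arrive as a tensor operand (opset 13+).
    %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],f32>

    // After conversion: each axis is extracted, normalized if negative, and
    // collected into a !torch.list<int> that feeds a single torch.aten.amax.
    %dim = torch.constant.int 1
    %dims = torch.prim.ListConstruct %dim : (!torch.int) -> !torch.list<int>
    %false = torch.constant.bool false
    %1 = torch.aten.amax %arg0, %dims, %false : !torch.vtensor<[4,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4],f32>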
@@ -758,107 +758,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.op, resultType, operand, vAlpha, vScale, vInputScale);
         return success();
       });
-  patterns.onOp(
-      "ReduceMax", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
-        SmallVector<Value, 1> operands;
-        int64_t keepDims, noop_with_empty_axes;
-
-        if (binder.tensorOperandsList(operands) ||
-            binder.tensorResultType(resultType) ||
-            binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
-            binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
-                                  0))
-          return failure();
-        Value data = operands[0];
-
-        if (operands.size() == 1) {
-          if (noop_with_empty_axes == 0) {
-            MLIRContext *context = binder.op->getContext();
-            int rank =
-                data.getType().cast<Torch::ValueTensorType>().getSizes().size();
-            SmallVector<Value, 1> dims;
-            for (int i = 0; i < rank; i++) {
-              dims.push_back(rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getI64IntegerAttr(i)));
-            }
-            Value dimsList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
-                Torch::ListType::get(Torch::IntType::get(context)), dims);
-            Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getType<Torch::IntType>(),
-                rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
-            Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
-                binder.getLoc(), keepDimsConstInt);
-            rewriter.replaceOpWithNewOp<Torch::AtenAmaxOp>(
-                binder.op, resultType, data, /*dim=*/dimsList,
-                /*keepdim=*/keepDimsBool);
-          } else {
-            rewriter.replaceOp(binder.op, data);
-          }
-          return success();
-        }
-
-        Value axes = operands[1];
-
-        SmallVector<Value> dimList;
-
-        Torch::BaseTensorType axesType =
-            axes.getType().cast<Torch::BaseTensorType>();
-        SmallVector<int64_t> selectSizes;
-        selectSizes.push_back(1);
-        Type selectResultType = axesType.getWithSizesAndDtype(
-            llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
-        auto sizes =
-            dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
-
-        Value zero = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
-        int64_t adjustmentInt =
-            cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
-        Value adjustment = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                    adjustmentInt));
-        for (int i = 0; i < sizes[0]; i++) {
-          // Go through the axes list and get each dim in the list
-          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(),
-              rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
-          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
-              binder.getLoc(), selectResultType, axes, zero, selectIndex);
-          Value dim = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
-          // deal with neg axis: if (axis < 0) axis += rank
-          Value isNegative =
-              rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
-          isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
-                                                             isNegative);
-          Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
-              binder.getLoc(), isNegative, adjustment);
-          Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
-              binder.getLoc(), dim, finalOffset);
-          dimList.push_back(finalDim);
-        }
-
-        Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
-            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
-            dimList);
-
-        Value keepDimBool;
-        if (keepDims == 1) {
-          keepDimBool =
-              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
-        } else {
-          keepDimBool =
-              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
-        }
-        rewriter.replaceOpWithNewOp<Torch::AtenAmaxOp>(
-            binder.op, resultType, data, dimValueList, keepDimBool);
-        return success();
-      });
   patterns.onOp(
       "ReduceSum", 13,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -1102,6 +1001,159 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             /*dtype=*/noneVal);
         return success();
       });
+  patterns.onOp(
+      "ReduceMax", 13,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        // AtenAmaxOp allows us to pass a list of dims
+        Torch::ValueTensorType resultType;
+        Value data;
+        Value axes;
+        int64_t keepDims;
+        int64_t noop_with_empty_axes;
+        if (binder.tensorOperandAtIndex(data, 0) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
+            binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
+                                  0))
+          return failure();
+
+        auto dataTy = cast<Torch::BaseTensorType>(data.getType());
+        Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();
+
+        // If any of the input dims are 0 we set to the upper limit:
+        if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) &&
+            (llvm::any_of(dataTy.getSizes(),
+                          [](int64_t d) { return d == Torch::kUnknownSize; }) ||
+             keepDims)) {
+          auto dty = dataTy.getDtype();
+          Value scalar;
+          if (FloatType fpTy = dyn_cast<FloatType>(dty)) {
+            auto inf = APFloat::getInf(fpTy.getFloatSemantics());
+            scalar = rewriter.create<Torch::ConstantFloatOp>(
+                binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+                rewriter.getFloatAttr(rewriter.getF64Type(),
+                                      inf.convertToDouble()));
+          }
+
+          if (IntegerType intTy = dyn_cast<IntegerType>(dty)) {
+            auto mx =
+                intTy.isSigned()
+                    ? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+                    : APInt::getMaxValue(intTy.getIntOrFloatBitWidth());
+            scalar = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), torchIntTy,
+                rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                        mx.getSExtValue()));
+          }
+
+          llvm::SmallVector<Value> fillDims;
+          for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) {
+            auto staticDim = resultType.getSizes()[i];
+            if (staticDim != Torch::kUnknownSize) {
+              fillDims.push_back(rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), torchIntTy,
+                  rewriter.getI64IntegerAttr(staticDim)));
+              continue;
+            }
+
+            Value iv = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
+            fillDims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
+                binder.getLoc(), torchIntTy, data, iv));
+          }
+
+          Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+          Value fillDimsList = rewriter.create<Torch::PrimListConstructOp>(
+              binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims);
+          rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
+              binder.op, resultType, fillDimsList, scalar, none, none, none,
+              none);
+          return success();
+        }
+
+        // Previous version of the operation had the axes as an attribute:
+        SmallVector<Value> axesList;
+        llvm::SmallVector<int64_t> axesAttr;
+        if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) {
+          for (int i = 0, s = axesAttr.size(); i < s; ++i) {
+            axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), torchIntTy,
+                rewriter.getI64IntegerAttr(axesAttr[i])));
+          }
+        }
+
+        // Extract the axes values from the axes operand:
+        if (!binder.tensorOperandAtIndex(axes, 1)) {
+          Torch::BaseTensorType axesType =
+              axes.getType().cast<Torch::BaseTensorType>();
+          SmallVector<int64_t> selectSizes{1};
+          Type selectResultType = axesType.getWithSizesAndDtype(
+              selectSizes, axesType.getOptionalDtype());
+          auto sizes = axesType.getSizes();
+
+          Value zero = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+
+          // Extract the value of each axes:
+          for (int i = 0; i < sizes[0]; i++) {
+            // Go through the axes list and get each dim in the list
+            Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
+            Value extract = rewriter.create<Torch::AtenSelectIntOp>(
+                binder.getLoc(), selectResultType, axes, zero, selectIndex);
+            Value dim = rewriter.create<Torch::AtenItemOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
+            axesList.push_back(dim);
+          }
+        }
+
+        // Handle the noop case:
+        if (axesList.empty() && noop_with_empty_axes) {
+          rewriter.replaceOp(binder.op, data);
+          return success();
+        }
+
+        // Deal with case when no axes arg is passed but not a noop:
+        if (axesList.empty()) {
+          int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
+                                .getSizes()
+                                .size();
+          for (int i = 0; i < numDims; i++) {
+            Value curr = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
+            axesList.push_back(curr);
+          }
+        }
+
+        // Handle negative axis:
+        Value rankVal = rewriter.create<Torch::AtenDimOp>(binder.getLoc(),
+                                                          torchIntTy, data);
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getI64IntegerAttr(0));
+        for (Value &axes : axesList) {
+          Value isNegative =
+              rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axes, zero);
+          isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
+                                                             isNegative);
+          Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
+              binder.getLoc(), isNegative, rankVal);
+          axes = rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), axes,
+                                                      finalOffset);
+        }
+
+        Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(), Torch::ListType::get(torchIntTy), axesList);
+        Value keepDimBool =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
+        rewriter.replaceOpWithNewOp<Torch::AtenAmaxOp>(
+            binder.op, resultType, data, dimValueList, keepDimBool);
+        return success();
+      });
+
   patterns.onOp(
       "ReduceMin", 13,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
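A note on the zero-extent branch above: when some input dimension is statically 0, the pattern emits no amax at all; it folds the op to `torch.aten.full` filled with the dtype's upper limit (+inf for floats, the maximum representable value for integers). A condensed sketch of the IR this produces, mirroring the `test_reduce_max_empty_set_fp` test later in this diff:

    %int2 = torch.constant.int 2
    %int1 = torch.constant.int 1
    %int4 = torch.constant.int 4
    %inf = torch.constant.float 0x7FF0000000000000
    %none = torch.constant.none
    %sizes = torch.prim.ListConstruct %int2, %int1, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    // Operands: size, fill_value, dtype, layout, device, pin_memory.
    %full = torch.aten.full %sizes, %inf, %none, %none, %none, %none : !torch.list<int>, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,1,4],f32>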
@@ -87,6 +87,7 @@ public:
       return rewriter.notifyMatchFailure(op, "dim is not a valid dim");

     Type inElementType = inputType.getElementType();
+    bool isUnsigned = false;
     if (!inElementType.isa<mlir::FloatType>()) {
       if (inElementType.isa<mlir::IntegerType>()) {
         auto integerTy = op.getSelf()
@@ -94,10 +95,7 @@ public:
                              .template cast<BaseTensorType>()
                              .getDtype()
                              .template dyn_cast<mlir::IntegerType>();
-        if (integerTy.isUnsigned())
-          return rewriter.notifyMatchFailure(
-              op, opName + " to linalg.* requires input element type "
-                           "to be signed in case of integer");
+        isUnsigned = integerTy.isUnsigned();
       } else {
         return rewriter.notifyMatchFailure(
             op, opName + " to linalg.* requires Float or Integer "
@@ -130,12 +128,17 @@ public:
           APFloat::getInf(
               inElementType.cast<mlir::FloatType>().getFloatSemantics(),
               /*Negative=*/isMax)));
-    } else {
+    } else if (!isUnsigned) {
       auto width = inElementType.cast<mlir::IntegerType>().getWidth();
       auto init = isMax ? APSInt::getSignedMinValue(width)
                         : APSInt::getSignedMaxValue(width);
       fillValue = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getIntegerAttr(inElementType, init));
+    } else if (isUnsigned) {
+      auto width = inElementType.cast<mlir::IntegerType>().getWidth();
+      auto init = isMax ? APInt::getMinValue(width) : APInt::getMaxValue(width);
+      fillValue = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(inElementType, init));
     }

     Value filledTensorVal =
@@ -193,14 +196,26 @@ public:
       } else {
         arith::CmpIPredicate predType;
         if (isMax) {
-          predType = arith::CmpIPredicate::sgt;
-          resultVal = rewriter.create<arith::MaxSIOp>(nestedLoc, newValue,
-                                                      oldValue);
-        } else {
-          predType = arith::CmpIPredicate::slt;
-          resultVal = rewriter.create<arith::MinSIOp>(nestedLoc, newValue,
-                                                      oldValue);
-        }
+          predType = isUnsigned ? arith::CmpIPredicate::ugt
+                                : arith::CmpIPredicate::sgt;
+          if (isUnsigned) {
+            resultVal = rewriter.create<arith::MaxUIOp>(nestedLoc, newValue,
+                                                        oldValue);
+          } else {
+            resultVal = rewriter.create<arith::MaxSIOp>(nestedLoc, newValue,
+                                                        oldValue);
+          }
+        } else {
+          predType = isUnsigned ? arith::CmpIPredicate::ult
+                                : arith::CmpIPredicate::slt;
+          if (isUnsigned) {
+            resultVal = rewriter.create<arith::MinUIOp>(nestedLoc, newValue,
+                                                        oldValue);
+          } else {
+            resultVal = rewriter.create<arith::MinSIOp>(nestedLoc, newValue,
+                                                        oldValue);
+          }
+        }
         predicate = rewriter.create<arith::CmpIOp>(nestedLoc, predType,
                                                    newValue, oldValue);
       }
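The effect of the new `isUnsigned` branches, sketched as the arith ops they select inside the reduction body (a standalone illustration, not code from the patch; the comparison predicates switch the same way, `sgt`/`slt` becoming `ugt`/`ult`):

    func.func @combine(%new: i32, %old: i32) -> (i32, i32) {
      // Signed element types keep the existing signed combinator:
      %max_s = arith.maxsi %new, %old : i32
      // Unsigned element types now select the unsigned variant instead:
      %max_u = arith.maxui %new, %old : i32
      return %max_s, %max_u : i32, i32
    }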
@@ -71,7 +71,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
   }

   Type resultType = tensorType.getWithSizesAndDtype(
-      sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
+      !tensorType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
                         : llvm::ArrayRef(sizes),
       tensorType.getOptionalDtype());
   return resultType;
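This one-line change matters for rank-0 results: `sizes.size() == 0` is also true when the reduction legitimately produces a rank-0 tensor, so the helper returned a rankless type, which lines up with the "# Failure - rankless return" xfail group renamed below. `!tensorType.hasSizes()` only falls back to the unranked form when the input type itself carries no size information. In torch-mlir type notation the distinction looks like this (an illustration, not text from the patch):

    !torch.vtensor<[],f32>   // rank-0: sizes are known (an empty list)
    !torch.vtensor           // unranked: no size information at all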
@@ -1515,9 +1515,6 @@ ONNX_XFAIL_SET = {
     "BroadcastToModule_basic",
     "ExpandModule_basic",
     "MoveDimIntNegativeIndexModule_basic",
-    "ReduceAmaxKeepDim_basic",
-    "ReduceMaxKeepDimReturnBoth_basic",
-    "ReduceMaxNegativeDim_basic",
     "ViewSizeFromOtherTensor_basic",

     # Failure - onnx_export
@@ -2122,18 +2119,8 @@ ONNX_XFAIL_SET = {
     "TriuBroadcastModule_basic",
     "TriuModule_basic",

-    # Failure - rankless return
-    "ReduceAmaxMultiDim_basic",
-    "ReduceAmaxOutOfOrderDim_basic",
-    "ReduceAmaxSingleDim_basic",
-    "ReduceMaxAllDims_basic",
-    "ReduceMaxAlongDimNegative_basic",
-    "ReduceMaxAlongDimSignedInt_basic",
+    # Failure - incorrect dtype
     "ReduceMaxAlongDimUnsignedInt_basic",
-    "ReduceMaxAlongDim_basic",
-    "ReduceMaxFloatModule_basic",
-    "ReduceMaxSignedIntModule_basic",
-    "ReduceMaxUnsignedIntModule_basic",

     # Failure - torch.aten.view lower
     "IndexTensorDyanmicInputContiguousWithNoneModule_basic",
@@ -747,65 +747,121 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>

 // -----

-// CHECK-LABEL: func.func @test_reduce_max_keepdims_example
-func.func @test_reduce_max_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.*]] = torch.constant.int 0
-  // CHECK: %[[RANK:.*]] = torch.constant.int 3
-  // CHECK: %[[INT0_0:.*]] = torch.constant.int 0
-  // CHECK: %[[SELECT_DIM0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[ITEM0:.*]] = torch.aten.item %[[SELECT_DIM0]] : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: %[[LTZERO_0:.*]] = torch.aten.lt.int %[[ITEM0]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: %[[ISNEG_0:.*]] = torch.aten.Int.bool %[[LTZERO_0]] : !torch.bool -> !torch.int
-  // CHECK: %[[ADJUSTMENT_0:.*]] = torch.aten.mul.int %[[ISNEG_0]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[FINAL_0:.*]] = torch.aten.add.int %[[ITEM0]], %[[ADJUSTMENT_0]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[INT1:.*]] = torch.constant.int 1
-  // CHECK: %[[SELECT_DIM1:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[ITEM1:.*]] = torch.aten.item %[[SELECT_DIM1]] : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: %[[LTZERO_1:.*]] = torch.aten.lt.int %[[ITEM1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: %[[ISNEG_1:.*]] = torch.aten.Int.bool %[[LTZERO_1]] : !torch.bool -> !torch.int
-  // CHECK: %[[ADJUSTMENT_1:.*]] = torch.aten.mul.int %[[ISNEG_1]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[FINAL_1:.*]] = torch.aten.add.int %[[ITEM1]], %[[ADJUSTMENT_1]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[FINAL_0]], %[[FINAL_1]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[KEEPDIMS:.*]] = torch.constant.bool true
-  // CHECK: torch.aten.amax %arg0, %[[DIMS]], %[[KEEPDIMS]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,1,1],f32>
-  %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,1,1],f32>
-  return %0 : !torch.vtensor<[3,1,1],f32>
-}
+// CHECK-LABEL: func.func @test_reduce_max_empty_set_fp
+func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000
+  // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]]
+  // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]]
+  // CHECK: return %[[FULL]]
+  %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32>
+  return %0 : !torch.vtensor<[2,1,4],f32>
+}

 // -----

-// CHECK-LABEL: func.func @test_reduce_max_default_axes_keepdim_example
-func.func @test_reduce_max_default_axes_keepdim_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.*]] = torch.constant.int 0
-  // CHECK: %[[INT1:.*]] = torch.constant.int 1
-  // CHECK: %[[INT2:.*]] = torch.constant.int 2
-  // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
-  // CHECK: %[[KEEPDIMS:.*]] = torch.aten.Bool.int %[[INT1_0]] : !torch.int -> !torch.bool
-  // CHECK: torch.aten.amax %arg0, %[[DIMS]], %[[KEEPDIMS]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,1],f32>
-  %0 = torch.operator "onnx.ReduceMax"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32>
-  return %0 : !torch.vtensor<[1,1,1],f32>
-}
+// CHECK-LABEL: func.func @test_reduce_max_empty_set_int
+func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[INF:.+]] = torch.constant.int 2147483647
+  // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]]
+  // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]]
+  // CHECK: return %[[FULL]]
+  %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32>
+  return %0 : !torch.vtensor<[2,1,4],si32>
+}

 // -----

-// CHECK-LABEL: func.func @test_reduce_max_do_not_keepdims_example
-func.func @test_reduce_max_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.*]] = torch.constant.int 0
-  // CHECK: %[[RANK:.*]] = torch.constant.int 3
-  // CHECK: %[[INT0_0:.*]] = torch.constant.int 0
-  // CHECK: %[[SELECT_DIM:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT_DIM]] : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: %[[LTZERO:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: %[[ISNEG:.*]] = torch.aten.Int.bool %[[LTZERO]] : !torch.bool -> !torch.int
-  // CHECK: %[[ADJUSTMENT:.*]] = torch.aten.mul.int %[[ISNEG]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[FINAL:.*]] = torch.aten.add.int %[[ITEM]], %[[ADJUSTMENT]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[FINAL]] : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
-  // CHECK: torch.aten.amax %arg0, %[[DIMS]], %[[FALSE]] : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,2],f32>
-  %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
-  return %0 : !torch.vtensor<[3,2],f32>
-}
+// CHECK-LABEL: func.func @test_reduce_max_bool_inputs
+func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[IDX:.+]] = torch.constant.int 0
+  // CHECK: %[[SZ:.+]] = torch.constant.int 0
+  // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]]
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
+  // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
+  // CHECK: %[[C0:.+]] = torch.constant.int 0
+  // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,1],i1>
+  // CHECK: return %[[AMAX]] : !torch.vtensor<[4,1],i1>
+  %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1>
+  return %0 : !torch.vtensor<[4,1],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_max_bool_inputs_nokeepdims
+func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[IDX:.+]] = torch.constant.int 0
+  // CHECK: %[[SZ:.+]] = torch.constant.int 0
+  // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]]
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
+  // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
+  // CHECK: %[[C0:.+]] = torch.constant.int 0
+  // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4],i1>
+  // CHECK: return %[[AMAX]] : !torch.vtensor<[4],i1>
+  %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1>
+  return %0 : !torch.vtensor<[4],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_max_all_dims_default
+func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[I0:.+]] = torch.constant.int 0
+  // CHECK: %[[I1:.+]] = torch.constant.int 1
+  // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
+  // CHECK: %[[C0:.+]] = torch.constant.int 0
+  // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]]
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[MAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[],i1>
+  // CHECK: return %[[MAX]] : !torch.vtensor<[],i1>
+  %0 = torch.operator "onnx.ReduceMax"(%arg0) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1>
+  return %0 : !torch.vtensor<[],i1>
+}
+
+// -----
+
+func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4],i1>
+  // CHECK: return %[[AMAX]]
+  %0 = torch.operator "onnx.ReduceMax"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.axes=[1 : si64]} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1>
+  return %0 : !torch.vtensor<[4],i1>
+}

 // -----

@@ -1064,8 +1120,8 @@ func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>

 // -----

-// CHECK-LABEL: func.func @test_reduce_all_dims_default
-func.func @test_reduce_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-LABEL: func.func @test_reduce_min_all_dims_default
+func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[I0:.+]] = torch.constant.int 0
   // CHECK: %[[I1:.+]] = torch.constant.int 1
   // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int