mirror of https://github.com/llvm/torch-mlir

[onnx] Fix ReduceMean lowering to torch (#2956)

The Torch lowering only supported the most recent version of the op. Refactored
the lowering so it can more easily handle default values and optional
operands/attributes.

Branch: pull/2960/head
parent d541779f37
commit 4a7a7d76f8
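The refactor below replaces per-opset special cases with one path that gathers the reduction axes from whichever source is present: the legacy axes attribute (older opsets), the optional axes operand, or, by default, every dimension. A minimal standalone sketch of that fallback chain in plain C++ (the std::optional stand-ins and the gatherAxes helper are illustrative assumptions, not the torch-mlir binder API):

    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <vector>

    // Hypothetical stand-ins for the binder lookups; in the real lowering
    // these come from the ONNX attribute/operand binder, not std::optional.
    std::vector<int64_t> gatherAxes(std::optional<std::vector<int64_t>> axesAttr,
                                    std::optional<std::vector<int64_t>> axesOperand,
                                    int64_t rank, bool noopWithEmptyAxes) {
      std::vector<int64_t> axes;
      if (axesAttr)            // Older opsets: axes given as an attribute.
        axes = *axesAttr;
      else if (axesOperand)    // Newer opsets: axes given as an optional operand.
        axes = *axesOperand;
      if (axes.empty() && noopWithEmptyAxes)
        return {};             // Caller treats this as the identity no-op.
      if (axes.empty())        // Default: reduce over every dimension.
        for (int64_t i = 0; i < rank; ++i)
          axes.push_back(i);
      return axes;
    }

    int main() {
      for (int64_t a : gatherAxes(std::nullopt, std::nullopt, 3, false))
        std::cout << a << ' ';  // prints: 0 1 2
      std::cout << '\n';
    }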
@@ -1104,129 +1104,145 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         Value axes;
         int64_t keepDims;
         int64_t noop_with_empty_axes;
-        // Deal with case when no axes arg is passed
-        if (binder.op->getNumOperands() == 1) {
-          if (binder.tensorOperand(data) ||
-              binder.tensorResultType(resultType) ||
-              binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
-              binder.s64IntegerAttr(noop_with_empty_axes,
-                                    "noop_with_empty_axes", 0))
-            return failure();
-          if (noop_with_empty_axes == 0) {
-            Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getType<Torch::IntType>(),
-                rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
-            Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
-                binder.getLoc(), keepDimsConstInt);
-            int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
-                                  .getSizes()
-                                  .size();
-            SmallVector<Value> axesList;
-            for (int i = 0; i < numDims; i++) {
-              Value curr = rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getType<Torch::IntType>(),
-                  rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
-              axesList.push_back(curr);
-            }
-            Value axesValueList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
-                Torch::ListType::get(
-                    Torch::IntType::get(binder.op->getContext())),
-                axesList);
-            rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
-                binder.op, resultType, data, axesValueList, keepDimsBool);
-          } else {
-            rewriter.replaceOp(binder.op, data);
-          }
-          return success();
-        }
-        if (binder.tensorOperands(data, axes) ||
+        if (binder.tensorOperandAtIndex(data, 0) ||
             binder.tensorResultType(resultType) ||
             binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
             binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
                                   0))
           return failure();
-        Torch::BaseTensorType axesType =
-            axes.getType().cast<Torch::BaseTensorType>();
-        SmallVector<Value> dimList;
-        SmallVector<int64_t> selectSizes;
-        selectSizes.push_back(1);
-        Type selectResultType = axesType.getWithSizesAndDtype(
-            llvm::ArrayRef(selectSizes), axesType.getOptionalDtype());
-        auto sizes =
-            dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes();
-        // deal with case when axes is empty
-        if (sizes.size() == 1 && sizes[0] == 0) {
-          if (noop_with_empty_axes == 0) {
-            // create dims list with all dims [0, data.getSizes().size())
-            Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getType<Torch::IntType>(),
-                rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
-            Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
-                binder.getLoc(), keepDimsConstInt);
-            int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
-                                  .getSizes()
-                                  .size();
-            for (int i = 0; i < numDims; i++) {
-              Value curr = rewriter.create<Torch::ConstantIntOp>(
-                  binder.getLoc(), rewriter.getType<Torch::IntType>(),
-                  rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
-              dimList.push_back(curr);
-            }
-            Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
-                Torch::ListType::get(
-                    Torch::IntType::get(binder.op->getContext())),
-                dimList);
-            rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
-                binder.op, resultType, data, dimValueList, keepDimsBool);
-          } else {
-            rewriter.replaceOp(binder.op, data);
+        auto dataTy = cast<Torch::BaseTensorType>(data.getType());
+        Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();
+
+        // If any of the input dims are 0 we set to the upper limit:
+        if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) &&
+            (llvm::any_of(dataTy.getSizes(),
+                          [](int64_t d) { return d == Torch::kUnknownSize; }) ||
+             keepDims)) {
+          auto dty = dataTy.getDtype();
+          Value scalar;
+          if (FloatType fpTy = dyn_cast<FloatType>(dty)) {
+            auto inf = APFloat::getInf(fpTy.getFloatSemantics());
+            scalar = rewriter.create<Torch::ConstantFloatOp>(
+                binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+                rewriter.getFloatAttr(rewriter.getF64Type(),
+                                      inf.convertToDouble()));
           }
+
+          if (IntegerType intTy = dyn_cast<IntegerType>(dty)) {
+            auto mx =
+                intTy.isSigned()
+                    ? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+                    : APInt::getMaxValue(intTy.getIntOrFloatBitWidth());
+            scalar = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), torchIntTy,
+                rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                        mx.getSExtValue()));
+          }
+
+          llvm::SmallVector<Value> fillDims;
+          for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) {
+            auto staticDim = resultType.getSizes()[i];
+            if (staticDim != Torch::kUnknownSize) {
+              fillDims.push_back(rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), torchIntTy,
+                  rewriter.getI64IntegerAttr(staticDim)));
+              continue;
+            }
+
+            Value iv = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
+            fillDims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
+                binder.getLoc(), torchIntTy, data, iv));
+          }
+
+          Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+          Value fillDimsList = rewriter.create<Torch::PrimListConstructOp>(
+              binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims);
+          rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
+              binder.op, resultType, fillDimsList, scalar, none, none, none,
+              none);
           return success();
         }
+
+        // Previous version of the operation had the axes as an attribute:
+        SmallVector<Value> axesList;
+        llvm::SmallVector<int64_t> axesAttr;
+        if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) {
+          for (int i = 0, s = axesAttr.size(); i < s; ++i) {
+            axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), torchIntTy,
+                rewriter.getI64IntegerAttr(axesAttr[i])));
+          }
+        }
+
+        // Extract the axes values from the axes operand:
+        if (!binder.tensorOperandAtIndex(axes, 1)) {
+          Torch::BaseTensorType axesType =
+              axes.getType().cast<Torch::BaseTensorType>();
+          SmallVector<int64_t> selectSizes{1};
+          Type selectResultType = axesType.getWithSizesAndDtype(
+              selectSizes, axesType.getOptionalDtype());
+          auto sizes = axesType.getSizes();
+
+          Value zero = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+
+          // Extract the value of each axes:
+          for (int i = 0; i < sizes[0]; i++) {
+            // Go through the axes list and get each dim in the list
+            Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
+            Value extract = rewriter.create<Torch::AtenSelectIntOp>(
+                binder.getLoc(), selectResultType, axes, zero, selectIndex);
+            Value dim = rewriter.create<Torch::AtenItemOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
+            axesList.push_back(dim);
+          }
+        }
+
+        // Handle the noop case:
+        if (axesList.empty() && noop_with_empty_axes) {
+          rewriter.replaceOp(binder.op, data);
+          return success();
+        }
+
+        // Deal with case when no axes arg is passed but not a noop:
+        if (axesList.empty()) {
+          int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
+                                .getSizes()
+                                .size();
+          for (int i = 0; i < numDims; i++) {
+            Value curr = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
+            axesList.push_back(curr);
+          }
+        }
+
+        // Handle negative axis:
+        Value rankVal = rewriter.create<Torch::AtenDimOp>(binder.getLoc(),
+                                                          torchIntTy, data);
         Value zero = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
-        int64_t adjustmentInt =
-            cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
-        Value adjustment = rewriter.create<Torch::ConstantIntOp>(
-            binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                    adjustmentInt));
-        // convert axes (tensor) into torch int list while dealing with neg axis
-        for (int i = 0; i < sizes[0]; i++) {
-          // Go through the axes list and get each dim in the list
-          Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(),
-              rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
-          Value extract = rewriter.create<Torch::AtenSelectIntOp>(
-              binder.getLoc(), selectResultType, axes, zero, selectIndex);
-          Value dim = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
-          // deal with neg axis: if (axis < 0) axis += rank
+            rewriter.getI64IntegerAttr(0));
+        for (Value &axes : axesList) {
           Value isNegative =
-              rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero);
+              rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axes, zero);
           isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
                                                              isNegative);
           Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
-              binder.getLoc(), isNegative, adjustment);
-          Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
-              binder.getLoc(), dim, finalOffset);
-          dimList.push_back(finalDim);
+              binder.getLoc(), isNegative, rankVal);
+          axes = rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), axes,
+                                                      finalOffset);
         }
+
         Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
-            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
-            dimList);
-        Value keepDimBool;
-        if (keepDims == 1) {
-          keepDimBool =
-              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
-        } else {
-          keepDimBool =
-              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
-        }
+            binder.getLoc(), Torch::ListType::get(torchIntTy), axesList);
+        Value keepDimBool =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
         rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
             binder.op, resultType, data, dimValueList, keepDimBool);
         return success();
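In the rewritten hunk above, each axis value is normalized in place with the branch-free sequence lt.int / Int.bool / mul.int / add.int, which computes axis += rank * (axis < 0) without data-dependent control flow in the emitted IR. A plain-C++ sketch of the same arithmetic (my own worked example, not code from the patch):

    #include <cassert>
    #include <cstdint>

    // Branch-free negative-axis normalization, mirroring the emitted
    // AtenLtIntOp / AtenIntBoolOp / AtenMulIntOp / AtenAddIntOp chain.
    int64_t normalizeAxis(int64_t axis, int64_t rank) {
      int64_t isNegative = axis < 0 ? 1 : 0; // lt.int + Int.bool
      int64_t offset = isNegative * rank;    // mul.int
      return axis + offset;                  // add.int
    }

    int main() {
      assert(normalizeAxis(-1, 4) == 3);
      assert(normalizeAxis(2, 4) == 2);
    }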
@@ -60,18 +60,15 @@ public:
     Location loc = op.getLoc();
     Value input = adaptor.getSelf();
-    RankedTensorType valResultType =
-        getTypeConverter()
-            ->convertType(op.getResult(0).getType())
-            .template cast<RankedTensorType>();
-
-    RankedTensorType idxResultType =
-        this->getTypeConverter()
-            ->convertType(op.getResult(1).getType())
-            .template cast<RankedTensorType>();
+    auto typec = this->getTypeConverter();
+    auto valResultType =
+        cast<RankedTensorType>(typec->convertType(op.getResult(0).getType()));
+    auto idxResultType =
+        cast<RankedTensorType>(typec->convertType(op.getResult(1).getType()));
     RankedTensorType inputType =
         input.getType().template cast<RankedTensorType>();
-    Type idxElementType = idxResultType.getElementType();
+    Type idxElementType =
+        getElementTypeOrSelf(typec->convertType(idxResultType));
     if (!idxElementType.isa<IntegerType>())
       return rewriter.notifyMatchFailure(
           op, opName + " to linalg.* requires integer-like result type");
@@ -109,14 +106,12 @@ public:
     }

     // Constant op to account for the reduction along dim.
-    auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
     SmallVector<Value> resultShape;
     for (int64_t i = 0; i < inputType.getRank(); i++) {
       if (dim != i) {
         auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
         resultShape.push_back(currentDimSize);
-      } else if (keepDim)
-        resultShape.push_back(c1);
+      }
     }
     // First fill the output buffer for the index.
     Value filledTensorIdx =
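With keepDim handling removed from this shape computation (the rank-preserving shape is restored by a later expand_shape rewrite in this commit), the result shape simply skips the reduced dimension. A standalone sketch of that computation, assuming plain vectors of dim sizes rather than MLIR values:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Result shape of a reduction that always drops the reduced dimension;
    // keepDim is restored later by re-expanding the shape.
    std::vector<int64_t> reducedShape(const std::vector<int64_t> &in,
                                      int64_t dim) {
      std::vector<int64_t> out;
      for (int64_t i = 0; i < (int64_t)in.size(); ++i)
        if (i != dim)
          out.push_back(in[i]);
      return out;
    }

    int main() {
      assert((reducedShape({2, 3, 4}, 1) == std::vector<int64_t>{2, 4}));
    }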
@@ -146,27 +141,23 @@ public:
     Value filledTensorVal =
         rewriter.create<linalg::FillOp>(loc, fillValue, initTensorVal).result();

+    SmallVector<utils::IteratorType> iteratorTypes(
+        inputType.getRank(), utils::IteratorType::parallel);
+    iteratorTypes[dim] = utils::IteratorType::reduction;
+
     // Create the affine expressions that will be used to
     // iterate over the input and output tensors.
     // Here we also set the type of iterator: parallel or reduction.

     SmallVector<AffineExpr> exprs;
-    SmallVector<utils::IteratorType> iteratorTypes;
     SmallVector<AffineExpr> resultExprs;
     for (auto size :
          llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) {
       exprs.push_back(rewriter.getAffineDimExpr(size.index()));
-      if (unsigned(dim) == size.index()) {
-        iteratorTypes.push_back(utils::IteratorType::reduction);
-        // If `keepDim`, create affine map to the first element
-        // in the current dimension.
-        if (keepDim)
-          resultExprs.push_back(rewriter.getAffineConstantExpr(0));
-      } else {
-        iteratorTypes.push_back(utils::IteratorType::parallel);
+      if (unsigned(dim) != size.index())
         resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
-      }
     }

     auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs},
                                              rewriter.getContext());
     auto linalgOp = rewriter.create<linalg::GenericOp>(
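The hunk above replaces the per-dimension branch with an up-front iterator-type vector: every dimension defaults to parallel, the reduced one is flipped to reduction, and the output indexing map simply omits the reduced dimension. A small sketch of that bookkeeping with plain enums and ints (not the MLIR types themselves):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    enum class Iter { Parallel, Reduction };

    // Mirror of the new structure: default all dims to parallel, flip `dim`,
    // and keep an output index for every non-reduced input dimension.
    void buildMaps(int64_t rank, int64_t dim, std::vector<Iter> &iters,
                   std::vector<int64_t> &resultDims) {
      iters.assign(rank, Iter::Parallel);
      iters[dim] = Iter::Reduction;
      for (int64_t i = 0; i < rank; ++i)
        if (i != dim)
          resultDims.push_back(i);
    }

    int main() {
      std::vector<Iter> iters;
      std::vector<int64_t> resultDims;
      buildMaps(/*rank=*/3, /*dim=*/1, iters, resultDims);
      assert(iters[1] == Iter::Reduction);
      assert((resultDims == std::vector<int64_t>{0, 2}));
    }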
@@ -219,12 +210,58 @@ public:
               nestedLoc, ValueRange({resultVal, resultIndex}));
         });

-    // This cast is required to fix the shape in the case of keepDim=True
-    Value valuesCast = rewriter.create<tensor::CastOp>(loc, valResultType,
-                                                       linalgOp.getResult(0));
-    Value idxCast = rewriter.create<tensor::CastOp>(loc, idxResultType,
-                                                    linalgOp.getResult(1));
-    rewriter.replaceOp(op, {valuesCast, idxCast});
+    if (!keepDim) {
+      Value rVal = rewriter.create<tensor::CastOp>(loc, valResultType,
+                                                   linalgOp.getResult(0));
+      Value rIdx = rewriter.create<tensor::CastOp>(loc, idxResultType,
+                                                   linalgOp.getResult(1));
+      llvm::SmallVector<Value> res{rVal, rIdx};
+      rewriter.replaceOp(op, res);
+      return success();
+    }
+
+    llvm::SmallVector<int64_t> valShape(valResultType.getShape());
+    llvm::SmallVector<int64_t> idxShape(idxResultType.getShape());
+    for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
+      valShape[i] = valShape[i + 1];
+      idxShape[i] = idxShape[i + 1];
+    }
+
+    valShape.resize(valShape.size() - 1);
+    idxShape.resize(idxShape.size() - 1);
+
+    Value rVal = rewriter.create<tensor::CastOp>(
+        loc, valResultType.clone(valShape), linalgOp.getResult(0));
+    Value rIdx = rewriter.create<tensor::CastOp>(
+        loc, idxResultType.clone(idxShape), linalgOp.getResult(1));
+
+    SmallVector<ReassociationIndices> reassociation(valShape.size());
+    if (reassociation.size() > 0) {
+      for (int i = 0; i < dim; ++i)
+        reassociation[i].push_back(i);
+      reassociation[std::max<int64_t>(0, dim - 1)].push_back(dim);
+      for (int i = dim, s = reassociation.size(); i < s; ++i)
+        reassociation[i].push_back(i + 1);
+    }
+
+    valShape.push_back(0);
+    idxShape.push_back(0);
+    for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
+      valShape[i + 1] = valShape[i];
+      idxShape[i + 1] = idxShape[i];
+    }
+
+    valShape[dim] = 1;
+    idxShape[dim] = 1;
+
+    Value unsqueezeVal = rewriter.create<tensor::ExpandShapeOp>(
+        loc, valResultType, rVal, reassociation);
+
+    Value unsqueezeIdx = rewriter.create<tensor::ExpandShapeOp>(
+        loc, idxResultType, rIdx, reassociation);
+
+    llvm::SmallVector<Value> unsqueezes = {unsqueezeVal, unsqueezeIdx};
+    rewriter.replaceOp(op, unsqueezes);
     return success();
   }
 };
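For keepDim=true the hunk above first computes the collapsed shape by shifting every trailing dim left over the reduced slot, then rebuilds the rank-preserving result with a unit dim and an ExpandShapeOp whose reassociation groups the unit dim with its left neighbor. A self-contained sketch of the shape and reassociation arithmetic, assuming plain vectors:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Collapsed shape: drop `dim` by shifting the tail left and resizing.
    std::vector<int64_t> collapse(std::vector<int64_t> shape, int64_t dim) {
      for (size_t i = dim; i + 1 < shape.size(); ++i)
        shape[i] = shape[i + 1];
      shape.resize(shape.size() - 1);
      return shape;
    }

    // Reassociation for the expand back to full rank: each collapsed dim maps
    // to itself, and the reinserted unit dim joins its left neighbor's group.
    std::vector<std::vector<int64_t>> reassociation(size_t collapsedRank,
                                                    int64_t dim) {
      std::vector<std::vector<int64_t>> groups(collapsedRank);
      for (int64_t i = 0; i < dim; ++i)
        groups[i].push_back(i);
      groups[std::max<int64_t>(0, dim - 1)].push_back(dim);
      for (int64_t i = dim; i < (int64_t)collapsedRank; ++i)
        groups[i].push_back(i + 1);
      return groups;
    }

    int main() {
      assert((collapse({2, 3, 4}, 1) == std::vector<int64_t>{2, 4}));
      auto g = reassociation(2, 1); // rank-2 result, unit dim at index 1
      assert((g[0] == std::vector<int64_t>{0, 1}) &&
             (g[1] == std::vector<int64_t>{2}));
    }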
@@ -1316,6 +1316,57 @@ public:
 };
 } // namespace

+namespace {
+class DecomposeAtenAMinMaxOp : public OpRewritePattern<Torch::AtenAminOp> {
+public:
+  using OpRewritePattern<Torch::AtenAminOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(Torch::AtenAminOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallVector<int64_t> dimList;
+    if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
+      return rewriter.notifyMatchFailure(op, "dims not foldable constants");
+    }
+
+    bool keepdim;
+    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) {
+      return rewriter.notifyMatchFailure(op, "keepdims not foldable constants");
+    }
+
+    auto loc = op.getLoc();
+    std::sort(dimList.begin(), dimList.end(), std::greater<int64_t>());
+
+    Value reduction = op.getSelf();
+    auto resultTy = cast<Torch::ValueTensorType>(op.getType());
+    auto reductionTy = cast<Torch::ValueTensorType>(reduction.getType());
+    llvm::SmallVector<int64_t> reductionShape(reductionTy.getSizes());
+
+    for (auto dim : dimList) {
+      auto dimValue = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(dim));
+      reductionShape[dim] = 1;
+      if (!keepdim) {
+        for (int i = dim, s = reductionShape.size() - 1; i < s; ++i)
+          reductionShape[i] = reductionShape[i + 1];
+        reductionShape.resize(reductionShape.size() - 1);
+      }
+
+      reductionTy = rewriter.getType<Torch::ValueTensorType>(
+          reductionShape, resultTy.getOptionalDtype());
+      auto idxTy = rewriter.getType<Torch::ValueTensorType>(
+          reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true));
+      llvm::SmallVector<Type, 2> types{reductionTy, idxTy};
+      reduction = rewriter
+                      .create<Torch::AtenMinDimOp>(loc, types, reduction,
+                                                   dimValue, op.getKeepdim())
+                      .getResult(0);
+    }
+
+    rewriter.replaceOp(op, reduction);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into
 // `AtenMinDimOp`
 namespace {
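The new DecomposeAtenAMinMaxOp pattern above rewrites a multi-dim amin into a chain of single-dim min.dim ops, visiting the dims in descending order; with keepdim=false each reduction removes a dim, and descending order keeps the not-yet-reduced indices valid. A sketch of that invariant on shape vectors alone:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <vector>

    // Apply one keepdim=false reduction to a shape: erase the dim.
    std::vector<int64_t> reduceDim(std::vector<int64_t> shape, int64_t dim) {
      shape.erase(shape.begin() + dim);
      return shape;
    }

    int main() {
      std::vector<int64_t> dims{0, 2};
      // Descending order: reducing dim 2 first leaves dim 0 where it was.
      std::sort(dims.begin(), dims.end(), std::greater<int64_t>());
      std::vector<int64_t> shape{5, 6, 7};
      for (int64_t d : dims)
        shape = reduceDim(shape, d);
      assert((shape == std::vector<int64_t>{6}));
    }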
@@ -6867,6 +6918,7 @@ public:
     addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenAMinMaxOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
@@ -77,6 +77,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
   pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
   pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
   pm.addNestedPass<func::FuncOp>(createConvertTorchToTensorPass());
@@ -1472,6 +1472,62 @@ LTC_XFAIL_SET = {
 }

 ONNX_XFAIL_SET = {
+    # Failure - cast error
+    "MeanDimNoneDimModule_basic",
+    "MeanDtypeModule_basic",
+    "MeanDynamicSizesModule_basic",
+    "MeanModule_basic",
+    "MseLossMeanReductionModule_basic",
+    "PermuteNegativeIndexModule_basic",
+    "StdBiasedModule_basic",
+    "VarBiasedModule_basic",
+    "VarMeanBiasedModule_basic",
+
+    # Failure - constant int lowering
+    "SplitTensorGetItem_Module_basic",
+    "SplitTensorLastSmallerModule_basic",
+    "SplitTensorListUnpackModule_basic",
+    "SplitTensorNegativeDimModule_basic",
+    "SplitWithSizesListUnpackModule_basic",
+    "UnbindIntGetItem_Module_basic",
+    "UnbindIntListUnpack_Module_basic",
+
+    # Failure - incorrect numerics
+    "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
+    "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
+    "ElementwiseAtan2TensorIntModule_basic",
+    "ElementwiseLog10IntModule_basic",
+    "ElementwiseLog2IntModule_basic",
+    "ElementwiseSeluModule_basic",
+    "FlipModuleStaticShape_basic",
+    "FlipNegativeIndexModule_basic",
+    "HardsigmoidModule_basic",
+    "HardsigmoidRandomModule_basic",
+    "IndexSelectDynamicInputSizeModule_basic",
+    "IndexSelectWholeDimensionModule_basic",
+    "IndexSelectWholeTensorModule_basic",
+    "IndexTensorStaticModule_basic",
+    "IndexTensorStaticNonContiguousWithNoneModule_basic",
+    "PixelShuffleModuleStaticRank4Float32_basic",
+    "ResNet18Module_basic",
+    "SliceCopyEndGreaterThanDimSize_Module_basic",
+    "SliceCopyNegative_Module_basic",
+    "SliceCopyNonZeroDim_Module_basic",
+    "SliceCopy_Module_basic",
+    "TupleModule_basic",
+
+    # Failure - incorrect shape
+    "ArangeStartOutDtypeModule_basic",
+    "ArangeStartOutViewModule_basic",
+    "BroadcastDynamicDimModule_basic",
+    "BroadcastToModule_basic",
+    "ExpandModule_basic",
+    "MoveDimIntNegativeIndexModule_basic",
+    "ReduceAmaxKeepDim_basic",
+    "ReduceMaxKeepDimReturnBoth_basic",
+    "ReduceMaxNegativeDim_basic",
+    "ViewSizeFromOtherTensor_basic",
+
     # Failure - onnx_export
     "AdaptiveAvgPool1dGeneralDynamic_basic",
     "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
@@ -1594,6 +1650,7 @@ ONNX_XFAIL_SET = {
     "EmptyStridedSizeIntStrideModule_basic",
     "EqIntModule_basic",
     "ExponentialModule_basic",
+    "FloatImplicitModule_basic",
     "GeFloatIntModule_basic",
     "GeFloatModule_basic",
     "GeIntModule_basic",
@@ -1613,6 +1670,7 @@ ONNX_XFAIL_SET = {
     "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
     "IntFloatModule_basic",
+    "IntImplicitModule_basic",
     "IouOfModule_basic",
     "IsFloatingPointFloat_True",
     "IsFloatingPointInt_False",
@@ -1818,13 +1876,8 @@ ONNX_XFAIL_SET = {
     "_ConvolutionDeprecated2DCudnnModule_basic",
     "_ConvolutionDeprecated2DDeterministicModule_basic",
     "_SoftmaxModule_basic",

     # Failure - onnx_import
-    "BucketizeTensorFloatModule_basic",
-    "BucketizeTensorModule_basic",
-    "BucketizeTensorOutInt32RightModule_basic",
-    "BucketizeTensorStaticFloatModule_basic",
-    "BucketizeTensorStaticModule_basic",
     "DiagonalModule_basic",
     "DiagonalModule_nonsquare",
     "DiagonalModule_transposed",
@@ -1832,31 +1885,6 @@ ONNX_XFAIL_SET = {
     "DiagonalModule_with_dims_and_offset",
     "DiagonalModule_with_negative_dims",
     "DiagonalModule_with_offset",
-    "ElementwiseClampMaxModule_basic",
-    "ElementwiseClampMinModule_basic",
-    "ElementwiseClampMinTensorFloatModule_basic",
-    "ElementwiseClampMinTensorIntModule_basic",
-    "ElementwiseClampModule_basic",
-    "ElementwiseClampTensorFloatModule_basic",
-    "ElementwiseClampTensorInt8Module_basic",
-    "ElementwiseClampTensorIntModule_basic",
-    "HBC_basic",
-    "IndexPut1DFloatAccumulateModule_basic",
-    "IndexPut1DIntAccumulateModule_basic",
-    "IndexPut2DFloatAccumulateModule_basic",
-    "IndexPut2DIntAccumulateModule_basic",
-    "IndexPut3DFloatAccumulateModule_basic",
-    "IndexPut3DIntAccumulateModule_basic",
-    "IndexPutHackedTwin1DFloatAccumulateModule_basic",
-    "IndexPutHackedTwin1DIntAccumulateModule_basic",
-    "IndexPutHackedTwin2DFloatAccumulateModule_basic",
-    "IndexPutHackedTwin2DIntAccumulateModule_basic",
-    "IndexPutHackedTwin3DFloatAccumulateModule_basic",
-    "IndexPutHackedTwin3DIntAccumulateModule_basic",
-    "NormalizeModule_basic",
-    "PadWithNoneValModule_basic",
-    "QuantizedMLP_basic",
-    "RandModule_basic",
     "ScatterReduceFloatMaxModuleIncludeSelf",
     "ScatterReduceFloatMinModuleIncludeSelf",
     "ScatterReduceFloatProdModuleIncludeSelf",
@@ -1867,21 +1895,11 @@ ONNX_XFAIL_SET = {
     "ScatterReduceIntSumModuleIncludeSelf",
     "TileBigDimsSizeModule_basic",
     "TileSmallDimsSizeModule_basic",
-    "UpSampleNearest2dDynamicSize_basic",
-    "UpSampleNearest2dStaticSize_basic",

-    # Failure - onnx_lowering
+    # Failure - onnx_lowering: onnx.AveragePool
     "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool1dStaticEvenMultiple_basic",
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
-    "AtenMmFloatTypes_basic",
-    "AtenMmIntTypes_basic",
-    "AtenTrilModule_basic",
-    "AtenTrilWithNegDiagonalModule_basic",
-    "AtenTrilWithPosDiagonalModule_basic",
-    "AtenTriuModule_basic",
-    "AtenTriuWithNegDiagonalModule_basic",
-    "AtenTriuWithPosDiagonalModule_basic",
     "AvgPool1dFloatModule_basic",
     "AvgPool1dIntModule_basic",
     "AvgPool1dStaticModule_basic",
@@ -1890,78 +1908,73 @@ ONNX_XFAIL_SET = {
     "AvgPool2dFloatModule_basic",
     "AvgPool2dIntModule_basic",
     "AvgPool2dStaticModule_basic",
-    "BernoulliFloatModule_basic",
-    "BernoulliModule_basic",
-    "BernoulliPModule_basic",
-    "BernoulliTensorModule_basic",
-    "ConstantPad2dStaticModule_basic",
-    "ConstantPadNdModule_basic",
-    "ConstantPadNdPartialStaticModule_basic",
-    "ConstantPadNdStaticModule_basic",
-    "CrossEntropyLossModule_basic",
-    "CrossEntropyLossNoReductionModule_basic",
-    "DropoutTrainModule_basic",
-    "DropoutTrainStaticShapeModule_basic",
+
+    # Failure - onnx_lowering: onnx.Cast
+    "BucketizeTensorOutInt32RightModule_basic",
+    "ElementwiseToDtypeI64ToI8Module_basic",
+    "ElementwiseToDtypeI64ToUI8Module_basic",
+    "HBC_basic",
+    "QuantizedMLP_basic",
+    "TypeConversionI1ToI32Module_basic",
+    "TypeConversionI64ToI32Module_basic",
+
+    # Failure - onnx_lowering: onnx.Clip
+    "ElementwiseClampMaxModule_basic",
+    "ElementwiseClampMinModule_basic",
+    "ElementwiseClampMinTensorFloatModule_basic",
+    "ElementwiseClampMinTensorIntModule_basic",
+    "ElementwiseClampModule_basic",
+    "ElementwiseClampTensorFloatModule_basic",
+    "ElementwiseClampTensorInt8Module_basic",
+    "ElementwiseClampTensorIntModule_basic",
+    "NormalizeModule_basic",
+
+    # Failure - onnx_lowering: onnx.Einsum
     "EinsumStaticContractRhsModule_basic",
     "EinsumStaticFourDimensionModule_basic",
     "EinsumStaticModule_basic",
-    "ElementwiseMishModule_basic",
-    "ElementwiseRemainderScalarModule_Bool_basic",
-    "ElementwiseRemainderScalarModule_Int_basic",
-    "ElementwiseToDtypeI64ToI8Module_basic",
-    "ElementwiseToDtypeI64ToUI8Module_basic",
-    "GroupNormModule_basic",
-    "GroupNormNoWeightAndBiasModule_basic",
-    "HardswishModule_basic",
-    "HardswishRandomModule_basic",
-    "IndexPut1DFloatNonAccumulateModule_basic",
-    "IndexPut1DIntNonAccumulateModule_basic",
-    "IndexPut2DFloatNonAccumulateModule_basic",
-    "IndexPut2DIntNonAccumulateModule_basic",
-    "IndexPut3DFloatNonAccumulateModule_basic",
-    "IndexPut3DIntNonAccumulateModule_basic",
-    "IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
-    "IndexPutHackedTwin1DIntNonAccumulateModule_basic",
-    "IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
-    "IndexPutHackedTwin2DIntNonAccumulateModule_basic",
-    "IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
-    "IndexPutHackedTwin3DIntNonAccumulateModule_basic",
-    "LogSoftmaxIntModule_basic",
-    "MaxPool2dWithIndicesAllNegativeValuesModule_basic",
-    "MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
-    "MaxPool2dWithIndicesStaticModule_basic",
+
+    # Failure - onnx_lowering: onnx.Gemm
+    "AtenMmFloatTypes_basic",
+    "AtenMmIntTypes_basic",
     "MmDagModule_basic",
     "MmModule_basic",
     "MmModule_chained",
     "MmTanhModule_basic",
+
+    # Failure - onnx_lowering: onnx.HardSwish
+    "HardswishModule_basic",
+    "HardswishRandomModule_basic",
     "MobilenetV3Module_basic",
-    "MseLossSumReductionWithDifferentElemTypeModule_basic",
-    "NativeDropoutTrainModule_basic",
-    "NativeDropoutTrainStaticShapeModule_basic",
+
+    # Failure - onnx_lowering: onnx.LogSoftmax
+    "LogSoftmaxIntModule_basic",
+    "_LogSoftmaxModuleStable_basic",
+    "_LogSoftmaxModule_basic",
+
+    # Failure - onnx_lowering: onnx.MaxPool
+    "MaxPool2dWithIndicesAllNegativeValuesModule_basic",
+    "MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
+    "MaxPool2dWithIndicesStaticModule_basic",
+
+    # Failure - onnx_lowering: onnx.Mod
+    "ElementwiseRemainderScalarModule_Bool_basic",
+    "ElementwiseRemainderScalarModule_Int_basic",
+    "UnflattenIntNegativeOneDimStaticModule_basic",
+    "UnflattenIntNegativeOneSizeStaticModule_basic",
+    "UnflattenIntStaticModule_basic",
+    "UnflattenStaticModule_basic",
+
+    # Failure - onnx_lowering: onnx.OneHot
     "OneHotModule_basic",
+
+    # Failure - onnx_lowering: onnx.Pad
+    "ConstantPad2dStaticModule_basic",
+    "ConstantPadNdModule_basic",
+    "ConstantPadNdPartialStaticModule_basic",
+    "ConstantPadNdStaticModule_basic",
     "PadModule_basic",
-    "RandIntLowDtypeModule_basic",
-    "RandIntLowModule_basic",
-    "RandLikeDtypeModule_basic",
-    "RandLikeModule_basic",
-    "RandnDtypeDeviceModule_basic",
-    "RandnGeneratorF64Module_basic",
-    "RandnGeneratorModule_basic",
-    "RandnLikeDtypeModule_basic",
-    "RandnLikeModule_basic",
-    "RandnModule_basic",
-    "ReduceL1NormModule_basic",
-    "ReduceL1NormWithDTypeModule_basic",
-    "ReduceL2NormModule_basic",
-    "ReduceL3NormAllDimsModule_basic",
-    "ReduceL3NormKeepDimModule_basic",
-    "ReduceProdDimIntFloatModule_basic",
-    "ReduceSumDtypeFloatModule_basic",
-    "ReduceSumDtypeIntModule_basic",
-    "ReduceSumElementTypeBoolModule_basic",
-    "ReduceSumFloatModule_basic",
-    "ReduceSumSignedIntModule_basic",
-    "ReduceSumUnsignedIntModule_basic",
+    "PadWithNoneValModule_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
     "ReflectionPad1dModule3dInput_Left",
@@ -1976,19 +1989,43 @@ ONNX_XFAIL_SET = {
     "ReplicationPad2dModule_left0",
     "ReplicationPad2dModule_right0",
     "ReplicationPad2dModule_top0",
-    "ScatterSrcModule_basic",
-    "ScatterSrcStaticModule_basic",
-    "ScatterValueFloatModule_basic",
-    "ScatterValueIntModule_basic",
-    "SoftplusModule_basic",
-    "SortTensorDescending_basic",
-    "SortTensorInteger_basic",
-    "SortTensorNegativeDimension_basic",
-    "SortTensorSpecificDimension_basic",
-    "SortTensor_basic",
-    "SqueezeModule_allUnitDim",
-    "SqueezeModule_broadcast",
-    "SqueezeModule_static",
+
+    # Failure - onnx_lowering: onnx.RandomNormal
+    "RandnDtypeDeviceModule_basic",
+    "RandnGeneratorF64Module_basic",
+    "RandnGeneratorModule_basic",
+    "RandnModule_basic",
+
+    # Failure - onnx_lowering: onnx.RandomNormalLike
+    "RandnLikeDtypeModule_basic",
+    "RandnLikeModule_basic",
+
+    # Failure - onnx_lowering: onnx.RandomUniform
+    "RandIntLowDtypeModule_basic",
+    "RandIntLowModule_basic",
+
+    # Failure - onnx_lowering: onnx.RandomUniformLike
+    "BernoulliFloatModule_basic",
+    "BernoulliPModule_basic",
+    "BernoulliTensorModule_basic",
+    "RandLikeDtypeModule_basic",
+    "RandLikeModule_basic",
+    "RandModule_basic",
+
+    # Failure - onnx_lowering: onnx.ReduceL1
+    "ReduceL1NormModule_basic",
+    "ReduceL1NormWithDTypeModule_basic",
+
+    # Failure - onnx_lowering: onnx.ReduceL2
+    "ReduceL2NormModule_basic",
+
+    # Failure - onnx_lowering: onnx.ReduceProd
+    "BernoulliModule_basic",
+    "DropoutTrainModule_basic",
+    "DropoutTrainStaticShapeModule_basic",
+    "NativeDropoutTrainModule_basic",
+    "NativeDropoutTrainStaticShapeModule_basic",
+    "ReduceProdDimIntFloatModule_basic",
     "StdCorrectionAllDimReduceModule_basic",
     "StdCorrectionKeepDimModule_basic",
     "StdCorrectionLargeInputModule_basic",
@@ -1999,14 +2036,6 @@ ONNX_XFAIL_SET = {
     "StdDimKeepDimTrueModule_basic",
     "StdDimNoneDimModule_basic",
     "StdUnbiasedModule_basic",
-    "TriuBroadcastModule_basic",
-    "TriuModule_basic",
-    "TypeConversionI1ToI32Module_basic",
-    "TypeConversionI64ToI32Module_basic",
-    "UnflattenIntNegativeOneDimStaticModule_basic",
-    "UnflattenIntNegativeOneSizeStaticModule_basic",
-    "UnflattenIntStaticModule_basic",
-    "UnflattenStaticModule_basic",
     "VarCorrectionAllDimReduceModule_basic",
     "VarCorrectionKeepDimModule_basic",
    "VarCorrectionLargeInputModule_basic",
@@ -2025,58 +2054,85 @@ ONNX_XFAIL_SET = {
     "VarMeanDimModule_basic",
     "VarMeanUnbiasedModule_basic",
     "VarUnbiasedModule_basic",
-    "_LogSoftmaxModuleStable_basic",
-    "_LogSoftmaxModule_basic",
-
-    # Failure - cast_error
-    "MeanDimNoneDimModule_basic",
-    "MeanDtypeModule_basic",
-    "MeanDynamicSizesModule_basic",
-    "MeanModule_basic",
-    "MseLossMeanReductionModule_basic",
-    "StdBiasedModule_basic",
-    "VarBiasedModule_basic",
-    "VarMeanBiasedModule_basic",
-
-    # Failure - constant_int
-    "ReduceMinAlongDimNegative_basic",
-    "ReduceMinAlongDimSignedInt_basic",
-    "ReduceMinAlongDim_basic",
-    "ReduceMinFloatModule_basic",
-    "ReduceMinKeepDimReturnBoth_basic",
-    "ReduceMinSignedIntModule_basic",
-    "ReduceMinUnsignedIntModule_basic",
-    "SplitTensorGetItem_Module_basic",
-    "SplitTensorLastSmallerModule_basic",
-    "SplitTensorListUnpackModule_basic",
-    "SplitTensorNegativeDimModule_basic",
-    "SplitWithSizesListUnpackModule_basic",
-    "UnbindIntGetItem_Module_basic",
-    "UnbindIntListUnpack_Module_basic",
-
-    # Failure - operand_type
-    "ElementwiseAcosIntModule_basic",
-    "ElementwiseAsinIntModule_basic",
-    "ElementwiseAtanTensorIntModule_basic",
-    "ElementwiseCosIntModule_basic",
-    "ElementwiseErfIntModule_basic",
-    "ElementwiseExpIntModule_basic",
-    "ElementwiseLog10IntModule_basic",
-    "ElementwiseLog2IntModule_basic",
-    "ElementwiseLogIntModule_basic",
-    "ElementwiseSinIntModule_basic",
-    "ElementwiseTanIntModule_basic",
-    "ElementwiseUnaryIntModule_basic",
-
-    # Failure - expand_multidim
-    "IndexTensorHackedTwinModule3dInput_basic",
-    "IndexTensorHackedTwinModule_basic",
-    "IndexTensorModule3dInput_basic",
-    "IndexTensorModule_basic",
-    "IndexTensorMultiInputContiguousOneDimDynamic_basic",
-    "IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
-
-    # Failure - rankless_return
+
+    # Failure - onnx_lowering: onnx.ReduceSum
+    "MseLossSumReductionWithDifferentElemTypeModule_basic",
+    "ReduceL3NormAllDimsModule_basic",
+    "ReduceL3NormKeepDimModule_basic",
+    "ReduceSumDtypeFloatModule_basic",
+    "ReduceSumDtypeIntModule_basic",
+    "ReduceSumElementTypeBoolModule_basic",
+    "ReduceSumFloatModule_basic",
+    "ReduceSumSignedIntModule_basic",
+    "ReduceSumUnsignedIntModule_basic",
+
+    # Failure - onnx_lowering: onnx.Resize
+    "UpSampleNearest2dDynamicSize_basic",
+    "UpSampleNearest2dStaticSize_basic",
+
+    # Failure - onnx_lowering: onnx.ScatterElements
+    "ScatterSrcModule_basic",
+    "ScatterSrcStaticModule_basic",
+    "ScatterValueFloatModule_basic",
+    "ScatterValueIntModule_basic",
+
+    # Failure - onnx_lowering: onnx.ScatterND
+    "IndexPut1DFloatAccumulateModule_basic",
+    "IndexPut1DFloatNonAccumulateModule_basic",
+    "IndexPut1DIntAccumulateModule_basic",
+    "IndexPut1DIntNonAccumulateModule_basic",
+    "IndexPut2DFloatAccumulateModule_basic",
+    "IndexPut2DFloatNonAccumulateModule_basic",
+    "IndexPut2DIntAccumulateModule_basic",
+    "IndexPut2DIntNonAccumulateModule_basic",
+    "IndexPut3DFloatAccumulateModule_basic",
+    "IndexPut3DFloatNonAccumulateModule_basic",
+    "IndexPut3DIntAccumulateModule_basic",
+    "IndexPut3DIntNonAccumulateModule_basic",
+    "IndexPutHackedTwin1DFloatAccumulateModule_basic",
+    "IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
+    "IndexPutHackedTwin1DIntAccumulateModule_basic",
+    "IndexPutHackedTwin1DIntNonAccumulateModule_basic",
+    "IndexPutHackedTwin2DFloatAccumulateModule_basic",
+    "IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
+    "IndexPutHackedTwin2DIntAccumulateModule_basic",
+    "IndexPutHackedTwin2DIntNonAccumulateModule_basic",
+    "IndexPutHackedTwin3DFloatAccumulateModule_basic",
+    "IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
+    "IndexPutHackedTwin3DIntAccumulateModule_basic",
+    "IndexPutHackedTwin3DIntNonAccumulateModule_basic",
+
+    # Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss
+    "CrossEntropyLossModule_basic",
+    "CrossEntropyLossNoReductionModule_basic",
+
+    # Failure - onnx_lowering: onnx.Softplus
+    "ElementwiseMishModule_basic",
+    "SoftplusModule_basic",
+
+    # Failure - onnx_lowering: onnx.Squeeze
+    "SqueezeModule_allUnitDim",
+    "SqueezeModule_broadcast",
+    "SqueezeModule_static",
+
+    # Failure - onnx_lowering: onnx.TopK
+    "SortTensorDescending_basic",
+    "SortTensorInteger_basic",
+    "SortTensorNegativeDimension_basic",
+    "SortTensorSpecificDimension_basic",
+    "SortTensor_basic",
+
+    # Failure - onnx_lowering: onnx.Trilu
+    "AtenTrilModule_basic",
+    "AtenTrilWithNegDiagonalModule_basic",
+    "AtenTrilWithPosDiagonalModule_basic",
+    "AtenTriuModule_basic",
+    "AtenTriuWithNegDiagonalModule_basic",
+    "AtenTriuWithPosDiagonalModule_basic",
+    "TriuBroadcastModule_basic",
+    "TriuModule_basic",
+
+    # Failure - rankless return
     "ReduceAmaxMultiDim_basic",
     "ReduceAmaxOutOfOrderDim_basic",
     "ReduceAmaxSingleDim_basic",
@@ -2088,8 +2144,8 @@ ONNX_XFAIL_SET = {
     "ReduceMaxFloatModule_basic",
     "ReduceMaxSignedIntModule_basic",
     "ReduceMaxUnsignedIntModule_basic",

-    # Failure - view_lowering
+    # Failure - torch.aten.view lower
     "AddSizeIntModule_basic",
     "ElementwiseFlattenBroadcastModule_basic",
     "FlattenRank0Module_basic",
@@ -2097,13 +2153,11 @@ ONNX_XFAIL_SET = {
     "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
     "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
     "IndexTensorMultiInputContiguousCenter_basic",
-    "IndexTensorMultiInputNonContiguousDynamic_basic",
     "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
     "IndexTensorMultiInputNonContiguous_basic",
     "IndexTensorMultiInputOneDim_basic",
     "IndexTensorMultiInputThreeIndexers_basic",
     "IndexTensorMultiInput_basic",
-    "IndexTensorSelectDimModule_basic",
     "IndexTensorStaticContiguousWithNoneModule_basic",
     "RepeatModule_basic",
     "SelectIntModule_basic",
@@ -2116,63 +2170,50 @@ ONNX_XFAIL_SET = {
     "ViewSizeDimLedAndFollowedByExpandedOnesModule_basic",
     "ViewSizeDimLedByCollapsedOnesModule_basic",
     "ViewSizeDimLedByExpandedOnesModule_basic",

-    # Failure - numerical
-    "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
-    "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
-    "ElementwiseSeluModule_basic",
-    "EmbeddingModule1DIndices_basic",
-    "FlipNegativeIndexModule_basic",
-    "HardsigmoidModule_basic",
-    "HardsigmoidRandomModule_basic",
-    "IndexSelectDynamicIndexSizeModule_basic",
-    "IndexSelectDynamicInputSizeModule_basic",
-    "IndexSelectDynamicModulebasic",
-    "IndexSelectWholeDimensionModule_basic",
-    "IndexSelectWholeTensorModule_basic",
-    "IndexTensorStaticModule_basic",
-    "IndexTensorStaticNonContiguousWithNoneModule_basic",
-    "PixelShuffleModuleStaticRank4Float32_basic",
-    "ResNet18Module_basic",
-    "SliceCopyEndGreaterThanDimSize_Module_basic",
-    "SliceCopyNegative_Module_basic",
-    "SliceCopyNonZeroDim_Module_basic",
-    "SliceCopy_Module_basic",
-    "TupleModule_basic",
-
-    # Failure - shape
-    "ArangeStartOutDtypeModule_basic",
-    "ArangeStartOutViewModule_basic",
-    "BroadcastDynamicDimModule_basic",
-    "BroadcastToModule_basic",
-    "EmbeddingModuleF16_basic",
-    "EmbeddingModuleI32_basic",
-    "EmbeddingModuleI64_basic",
-    "ExpandModule_basic",
-    "MoveDimIntNegativeIndexModule_basic",
-    "PermuteNegativeIndexModule_basic",
-    "ReduceAmaxKeepDim_basic",
-    "ReduceMaxKeepDimReturnBoth_basic",
-    "ReduceMaxNegativeDim_basic",
-    "ViewSizeFromOtherTensor_basic",
-
-    # Failure - onnx traces differently
-    "ElementwiseSigmoidIntModule_basic",
-
     # Failure - unknown
+    "BucketizeTensorFloatModule_basic",
+    "BucketizeTensorModule_basic",
+    "BucketizeTensorStaticFloatModule_basic",
+    "BucketizeTensorStaticModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
     "CopyWithDifferentDTypesAndSizesModule_basic",
     "CopyWithDifferentDTypesModule_basic",
     "CosineSimilarityStaticBroadcastModule_basic",
     "CumsumInputDtypeInt32Module_basic",
-    "ElementwiseAtan2TensorIntModule_basic",
+    "ElementwiseAcosIntModule_basic",
+    "ElementwiseAsinIntModule_basic",
+    "ElementwiseAtanTensorIntModule_basic",
+    "ElementwiseCosIntModule_basic",
     "ElementwiseDivRoundingModeTruncModule_basic",
+    "ElementwiseErfIntModule_basic",
+    "ElementwiseExpIntModule_basic",
+    "ElementwiseLogIntModule_basic",
     "ElementwisePreluModule_basic",
+    "ElementwiseSigmoidIntModule_basic",
+    "ElementwiseSinIntModule_basic",
+    "ElementwiseTanIntModule_basic",
+    "ElementwiseUnaryIntModule_basic",
     "ElementwiseUnsqueezeNegDimsModule_basic",
     "ElementwiseWhereScalarModule_basic",
+    "EmbeddingModule1DIndices_basic",
+    "EmbeddingModuleF16_basic",
+    "EmbeddingModuleI32_basic",
+    "EmbeddingModuleI64_basic",
     "FlattenDynamicModule_basic",
-    "FlipModuleStaticShape_basic",
     "GluStaticModule_basic",
+    "GroupNormModule_basic",
+    "GroupNormNoWeightAndBiasModule_basic",
+    "IndexSelectDynamicIndexSizeModule_basic",
+    "IndexSelectDynamicModulebasic",
+    "IndexTensorHackedTwinModule3dInput_basic",
+    "IndexTensorHackedTwinModule_basic",
+    "IndexTensorModule3dInput_basic",
+    "IndexTensorModule_basic",
+    "IndexTensorMultiInputContiguousOneDimDynamic_basic",
+    "IndexTensorMultiInputNonContiguousDynamic_basic",
+    "IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
+    "IndexTensorSelectDimModule_basic",
     "MaskedFillTensorFloatValueModule_basic",
     "ReduceAllDimEmpty_basic",
     "ReduceAllDimFloat_basic",
@@ -2180,8 +2221,6 @@ ONNX_XFAIL_SET = {
     "ReduceMinAlongDimUnsignedInt_basic",
     "TensorsStackNegativeDimModule_basic",
     "TensorsStackPromoteDTypeModule_basic",
-    "FloatImplicitModule_basic",
-    "IntImplicitModule_basic",
 }

 ONNX_CRASHING_SET = { }
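The new test_reduce_min_empty_set_fp check below exercises the empty-set path added in the first hunk: when a reduced dimension has extent 0, ReduceMin has no elements to fold, so the lowering materializes a tensor filled with the identity of min, which is +infinity for floats and the maximum representable value for integers (matching the APFloat::getInf / APInt::getMaxValue choices above). A sketch of picking that fill value with standard C++ numeric limits:

    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <limits>

    // Identity element of `min` over an empty set: every value compares <= it.
    template <typename T> T minIdentity() {
      if (std::numeric_limits<T>::has_infinity)
        return std::numeric_limits<T>::infinity(); // floats: +inf
      return std::numeric_limits<T>::max();        // integers: max value
    }

    int main() {
      assert(minIdentity<int32_t>() == std::numeric_limits<int32_t>::max());
      std::cout << minIdentity<float>() << '\n'; // prints: inf
    }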
@@ -926,107 +926,121 @@ func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor
 
 // -----
 
-// CHECK-LABEL: func.func @test_reduce_min_bool_inputs
-func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT2:.+]] = torch.constant.int 2
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
-  // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
-  // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,1],i1>
-  %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1>
-  return %0 : !torch.vtensor<[4,1],i1>
-}
-
-// CHECK-LABEL: func.func @test_reduce_min_default_axes_keepdims_example
-func.func @test_reduce_min_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT1:.+]] = torch.constant.int 1
-  // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
-  // CHECK: %[[INT2:.+]] = torch.constant.int 2
-  // CHECK: torch.prim.ListConstruct %int0, %int1_0, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: torch.aten.amin %arg0, %1, %0 : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,1],f32>
-  %0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32>
-  return %0 : !torch.vtensor<[1,1,1],f32>
-}
-
-// CHECK-LABEL: func.func @test_reduce_min_do_not_keepdims_example
-func.func @test_reduce_min_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT3:.+]] = torch.constant.int 3
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
-  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
-  // CHECK: torch.aten.amin %arg0, %6, %false : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,2],f32>
-  %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
-  return %0 : !torch.vtensor<[3,2],f32>
-}
-
-// CHECK-LABEL: func.func @test_reduce_min_empty_set
-func.func @test_reduce_min_empty_set(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT3:.+]] = torch.constant.int 3
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
-  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
-  // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[2,0,4],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,1,4],f32>
+// CHECK-LABEL: func.func @test_reduce_min_empty_set_fp
+func.func @test_reduce_min_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000
+  // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]]
+  // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]]
+  // CHECK: return %[[FULL]]
   %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32>
   return %0 : !torch.vtensor<[2,1,4],f32>
 }
 
-// CHECK-LABEL: func.func @test_reduce_min_keepdims_example
-func.func @test_reduce_min_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT3:.+]] = torch.constant.int 3
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
-  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
-  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
-  // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,1,2],f32>
-  %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
-  return %0 : !torch.vtensor<[3,1,2],f32>
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_min_empty_set_int
+func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[INF:.+]] = torch.constant.int 2147483647
+  // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]]
+  // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]]
+  // CHECK: return %[[FULL]]
+  %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32>
+  return %0 : !torch.vtensor<[2,1,4],si32>
 }
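The two empty-set tests above pin down the ONNX rule that reducing over an empty axis yields the identity element of the reduction: rather than emitting an amin over zero elements, the lowering produces a torch.aten.full filled with +infinity for f32 ReduceMin (0x7FF0000000000000 is the hex encoding of +inf) and with INT_MAX, 2147483647, for si32. A NumPy check of the same semantics (NumPy is used here purely for illustration; it is not part of the commit):

import numpy as np

# min over an empty set is the identity of min, i.e. +infinity for floats.
x = np.zeros((2, 0, 4), dtype=np.float32)
out = np.minimum.reduce(x, axis=1, keepdims=True, initial=np.inf)
assert out.shape == (2, 1, 4) and np.all(np.isinf(out))

# For si32 the identity is the largest representable value, INT_MAX.
xi = np.zeros((2, 0, 4), dtype=np.int32)
outi = np.minimum.reduce(xi, axis=1, keepdims=True,
                         initial=np.iinfo(np.int32).max)
assert np.all(outi == 2147483647)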
 
-// CHECK-LABEL: func.func @test_reduce_min_negative_axes_keepdims_example
-func.func @test_reduce_min_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[INT0:.+]] = torch.constant.int 0
-  // CHECK: %[[INT3:.+]] = torch.constant.int 3
-  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
-  // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
-  // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
-  // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_min_bool_inputs
+func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[IDX:.+]] = torch.constant.int 0
+  // CHECK: %[[SZ:.+]] = torch.constant.int 0
+  // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]]
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
+  // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
+  // CHECK: %[[C0:.+]] = torch.constant.int 0
+  // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list<int>
   // CHECK: %[[TRUE:.+]] = torch.constant.bool true
-  // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,1,2],f32>
-  %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
-  return %0 : !torch.vtensor<[3,1,2],f32>
+  // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,1],i1>
+  // CHECK: return %[[AMIN]] : !torch.vtensor<[4,1],i1>
+  %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1>
+  return %0 : !torch.vtensor<[4,1],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_min_bool_inputs_nokeepdims
+func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[IDX:.+]] = torch.constant.int 0
+  // CHECK: %[[SZ:.+]] = torch.constant.int 0
+  // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]]
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
+  // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
+  // CHECK: %[[C0:.+]] = torch.constant.int 0
+  // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4],i1>
+  // CHECK: return %[[AMIN]] : !torch.vtensor<[4],i1>
+  %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1>
+  return %0 : !torch.vtensor<[4],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_all_dims_default
+func.func @test_reduce_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[I0:.+]] = torch.constant.int 0
+  // CHECK: %[[I1:.+]] = torch.constant.int 1
+  // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
+  // CHECK: %[[C0:.+]] = torch.constant.int 0
+  // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]]
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[MIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[],i1>
+  // CHECK: return %[[MIN]] : !torch.vtensor<[],i1>
+  %0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1>
+  return %0 : !torch.vtensor<[],i1>
+}
+
+// -----
+
+func.func @test_reduce_min_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4],i1>
+  // CHECK: return %[[AMIN]]
+  %0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.axes=[1 : si64]} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1>
+  return %0 : !torch.vtensor<[4],i1>
 }
 
 // -----
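One pattern worth calling out: every test with a tensor axes operand checks the same select.int / item / lt.int / Int.bool / mul.int / add.int chain. That chain is simply branch-free normalization of a possibly negative axis against the input rank. In Python terms (a paraphrase of the emitted IR, not code from the commit):

def normalize_axis(axis: int, rank: int) -> int:
    """Branch-free negative-axis wrap, mirroring the emitted torch ops."""
    is_negative = int(axis < 0)        # aten.lt.int + aten.Int.bool
    return axis + is_negative * rank   # aten.mul.int + aten.add.int

assert normalize_axis(-2, 3) == 1   # -2 wraps to rank + (-2)
assert normalize_axis(1, 3) == 1    # non-negative axes pass through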