[onnx] Fix ReduceMean lowering to torch (#2956)

Torch lowering only supported the most recent version. Refactored the
lowering so more easily handle default values and optional operands /
attributes.
pull/2960/head
Rob Suderman 2024-02-27 22:48:07 -08:00 committed by GitHub
parent d541779f37
commit 4a7a7d76f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 618 additions and 459 deletions

View File

@ -1104,97 +1104,92 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value axes; Value axes;
int64_t keepDims; int64_t keepDims;
int64_t noop_with_empty_axes; int64_t noop_with_empty_axes;
// Deal with case when no axes arg is passed if (binder.tensorOperandAtIndex(data, 0) ||
if (binder.op->getNumOperands() == 1) {
if (binder.tensorOperand(data) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes,
"noop_with_empty_axes", 0))
return failure();
if (noop_with_empty_axes == 0) {
Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), keepDimsConstInt);
int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
.getSizes()
.size();
SmallVector<Value> axesList;
for (int i = 0; i < numDims; i++) {
Value curr = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
axesList.push_back(curr);
}
Value axesValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
axesList);
rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
binder.op, resultType, data, axesValueList, keepDimsBool);
} else {
rewriter.replaceOp(binder.op, data);
}
return success();
}
if (binder.tensorOperands(data, axes) ||
binder.tensorResultType(resultType) || binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(keepDims, "keepdims", 1) || binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
0)) 0))
return failure(); return failure();
Torch::BaseTensorType axesType =
axes.getType().cast<Torch::BaseTensorType>(); auto dataTy = cast<Torch::BaseTensorType>(data.getType());
SmallVector<Value> dimList; Torch::IntType torchIntTy = rewriter.getType<Torch::IntType>();
SmallVector<int64_t> selectSizes;
selectSizes.push_back(1); // If any of the input dims are 0 we set to the upper limit:
Type selectResultType = axesType.getWithSizesAndDtype( if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) &&
llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); (llvm::any_of(dataTy.getSizes(),
auto sizes = [](int64_t d) { return d == Torch::kUnknownSize; }) ||
dyn_cast<Torch::ValueTensorType>(axes.getType()).getSizes(); keepDims)) {
// deal with case when axes is empty auto dty = dataTy.getDtype();
if (sizes.size() == 1 && sizes[0] == 0) { Value scalar;
if (noop_with_empty_axes == 0) { if (FloatType fpTy = dyn_cast<FloatType>(dty)) {
// create dims list with all dims [0, data.getSizes().size()) auto inf = APFloat::getInf(fpTy.getFloatSemantics());
Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>( scalar = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); rewriter.getFloatAttr(rewriter.getF64Type(),
Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>( inf.convertToDouble()));
binder.getLoc(), keepDimsConstInt);
int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
.getSizes()
.size();
for (int i = 0; i < numDims; i++) {
Value curr = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
dimList.push_back(curr);
} }
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), if (IntegerType intTy = dyn_cast<IntegerType>(dty)) {
Torch::ListType::get( auto mx =
Torch::IntType::get(binder.op->getContext())), intTy.isSigned()
dimList); ? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
rewriter.replaceOpWithNewOp<Torch::AtenAminOp>( : APInt::getMaxValue(intTy.getIntOrFloatBitWidth());
binder.op, resultType, data, dimValueList, keepDimsBool); scalar = rewriter.create<Torch::ConstantIntOp>(
} else { binder.getLoc(), torchIntTy,
rewriter.replaceOp(binder.op, data); rewriter.getIntegerAttr(rewriter.getIntegerType(64),
mx.getSExtValue()));
} }
llvm::SmallVector<Value> fillDims;
for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) {
auto staticDim = resultType.getSizes()[i];
if (staticDim != Torch::kUnknownSize) {
fillDims.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), torchIntTy,
rewriter.getI64IntegerAttr(staticDim)));
continue;
}
Value iv = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
fillDims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), torchIntTy, data, iv));
}
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value fillDimsList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims);
rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
binder.op, resultType, fillDimsList, scalar, none, none, none,
none);
return success(); return success();
} }
// Previous version of the operation had the axes as an attribute:
SmallVector<Value> axesList;
llvm::SmallVector<int64_t> axesAttr;
if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) {
for (int i = 0, s = axesAttr.size(); i < s; ++i) {
axesList.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), torchIntTy,
rewriter.getI64IntegerAttr(axesAttr[i])));
}
}
// Extract the axes values from the axes operand:
if (!binder.tensorOperandAtIndex(axes, 1)) {
Torch::BaseTensorType axesType =
axes.getType().cast<Torch::BaseTensorType>();
SmallVector<int64_t> selectSizes{1};
Type selectResultType = axesType.getWithSizesAndDtype(
selectSizes, axesType.getOptionalDtype());
auto sizes = axesType.getSizes();
Value zero = rewriter.create<Torch::ConstantIntOp>( Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
int64_t adjustmentInt =
cast<Torch::ValueTensorType>(data.getType()).getSizes().size(); // Extract the value of each axes:
Value adjustment = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
adjustmentInt));
// convert axes (tensor) into torch int list while dealing with neg axis
for (int i = 0; i < sizes[0]; i++) { for (int i = 0; i < sizes[0]; i++) {
// Go through the axes list and get each dim in the list // Go through the axes list and get each dim in the list
Value selectIndex = rewriter.create<Torch::ConstantIntOp>( Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
@ -1204,29 +1199,50 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.getLoc(), selectResultType, axes, zero, selectIndex); binder.getLoc(), selectResultType, axes, zero, selectIndex);
Value dim = rewriter.create<Torch::AtenItemOp>( Value dim = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract); binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
// deal with neg axis: if (axis < 0) axis += rank axesList.push_back(dim);
}
}
// Handle the noop case:
if (axesList.empty() && noop_with_empty_axes) {
rewriter.replaceOp(binder.op, data);
return success();
}
// Deal with case when no axes arg is passed but not a noop:
if (axesList.empty()) {
int64_t numDims = dyn_cast<Torch::ValueTensorType>(data.getType())
.getSizes()
.size();
for (int i = 0; i < numDims; i++) {
Value curr = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
axesList.push_back(curr);
}
}
// Handle negative axis:
Value rankVal = rewriter.create<Torch::AtenDimOp>(binder.getLoc(),
torchIntTy, data);
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getI64IntegerAttr(0));
for (Value &axes : axesList) {
Value isNegative = Value isNegative =
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), dim, zero); rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axes, zero);
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
isNegative); isNegative);
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>( Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isNegative, adjustment); binder.getLoc(), isNegative, rankVal);
Value finalDim = rewriter.create<Torch::AtenAddIntOp>( axes = rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), axes,
binder.getLoc(), dim, finalOffset); finalOffset);
dimList.push_back(finalDim);
} }
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>( Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), binder.getLoc(), Torch::ListType::get(torchIntTy), axesList);
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), Value keepDimBool =
dimList); rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), keepDims);
Value keepDimBool;
if (keepDims == 1) {
keepDimBool =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
} else {
keepDimBool =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
}
rewriter.replaceOpWithNewOp<Torch::AtenAminOp>( rewriter.replaceOpWithNewOp<Torch::AtenAminOp>(
binder.op, resultType, data, dimValueList, keepDimBool); binder.op, resultType, data, dimValueList, keepDimBool);
return success(); return success();

View File

@ -60,18 +60,15 @@ public:
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
RankedTensorType valResultType = auto typec = this->getTypeConverter();
getTypeConverter() auto valResultType =
->convertType(op.getResult(0).getType()) cast<RankedTensorType>(typec->convertType(op.getResult(0).getType()));
.template cast<RankedTensorType>(); auto idxResultType =
cast<RankedTensorType>(typec->convertType(op.getResult(1).getType()));
RankedTensorType idxResultType =
this->getTypeConverter()
->convertType(op.getResult(1).getType())
.template cast<RankedTensorType>();
RankedTensorType inputType = RankedTensorType inputType =
input.getType().template cast<RankedTensorType>(); input.getType().template cast<RankedTensorType>();
Type idxElementType = idxResultType.getElementType(); Type idxElementType =
getElementTypeOrSelf(typec->convertType(idxResultType));
if (!idxElementType.isa<IntegerType>()) if (!idxElementType.isa<IntegerType>())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, opName + " to linalg.* requires integer-like result type"); op, opName + " to linalg.* requires integer-like result type");
@ -109,14 +106,12 @@ public:
} }
// Constant op to account for the reduction along dim. // Constant op to account for the reduction along dim.
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape; SmallVector<Value> resultShape;
for (int64_t i = 0; i < inputType.getRank(); i++) { for (int64_t i = 0; i < inputType.getRank(); i++) {
if (dim != i) { if (dim != i) {
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i); auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
resultShape.push_back(currentDimSize); resultShape.push_back(currentDimSize);
} else if (keepDim) }
resultShape.push_back(c1);
} }
// First fill the output buffer for the index. // First fill the output buffer for the index.
Value filledTensorIdx = Value filledTensorIdx =
@ -146,27 +141,23 @@ public:
Value filledTensorVal = Value filledTensorVal =
rewriter.create<linalg::FillOp>(loc, fillValue, initTensorVal).result(); rewriter.create<linalg::FillOp>(loc, fillValue, initTensorVal).result();
SmallVector<utils::IteratorType> iteratorTypes(
inputType.getRank(), utils::IteratorType::parallel);
iteratorTypes[dim] = utils::IteratorType::reduction;
// Create the affine expressions that will be used to // Create the affine expressions that will be used to
// iterate over the input and output tensors. // iterate over the input and output tensors.
// Here we also set the type of iterator: parallel or reduction. // Here we also set the type of iterator: parallel or reduction.
SmallVector<AffineExpr> exprs; SmallVector<AffineExpr> exprs;
SmallVector<utils::IteratorType> iteratorTypes;
SmallVector<AffineExpr> resultExprs; SmallVector<AffineExpr> resultExprs;
for (auto size : for (auto size :
llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) { llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) {
exprs.push_back(rewriter.getAffineDimExpr(size.index())); exprs.push_back(rewriter.getAffineDimExpr(size.index()));
if (unsigned(dim) != size.index())
if (unsigned(dim) == size.index()) {
iteratorTypes.push_back(utils::IteratorType::reduction);
// If `keepDim`, create affine map to the first element
// in the current dimension.
if (keepDim)
resultExprs.push_back(rewriter.getAffineConstantExpr(0));
} else {
iteratorTypes.push_back(utils::IteratorType::parallel);
resultExprs.push_back(rewriter.getAffineDimExpr(size.index())); resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
} }
}
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs}, auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs},
rewriter.getContext()); rewriter.getContext());
auto linalgOp = rewriter.create<linalg::GenericOp>( auto linalgOp = rewriter.create<linalg::GenericOp>(
@ -219,12 +210,58 @@ public:
nestedLoc, ValueRange({resultVal, resultIndex})); nestedLoc, ValueRange({resultVal, resultIndex}));
}); });
// This cast is required to fix the shape in the case of keepDim=True if (!keepDim) {
Value valuesCast = rewriter.create<tensor::CastOp>(loc, valResultType, Value rVal = rewriter.create<tensor::CastOp>(loc, valResultType,
linalgOp.getResult(0)); linalgOp.getResult(0));
Value idxCast = rewriter.create<tensor::CastOp>(loc, idxResultType, Value rIdx = rewriter.create<tensor::CastOp>(loc, idxResultType,
linalgOp.getResult(1)); linalgOp.getResult(1));
rewriter.replaceOp(op, {valuesCast, idxCast}); llvm::SmallVector<Value> res{rVal, rIdx};
rewriter.replaceOp(op, res);
return success();
}
llvm::SmallVector<int64_t> valShape(valResultType.getShape());
llvm::SmallVector<int64_t> idxShape(idxResultType.getShape());
for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
valShape[i] = valShape[i + 1];
idxShape[i] = idxShape[i + 1];
}
valShape.resize(valShape.size() - 1);
idxShape.resize(idxShape.size() - 1);
Value rVal = rewriter.create<tensor::CastOp>(
loc, valResultType.clone(valShape), linalgOp.getResult(0));
Value rIdx = rewriter.create<tensor::CastOp>(
loc, idxResultType.clone(idxShape), linalgOp.getResult(1));
SmallVector<ReassociationIndices> reassociation(valShape.size());
if (reassociation.size() > 0) {
for (int i = 0; i < dim; ++i)
reassociation[i].push_back(i);
reassociation[std::max<int64_t>(0, dim - 1)].push_back(dim);
for (int i = dim, s = reassociation.size(); i < s; ++i)
reassociation[i].push_back(i + 1);
}
valShape.push_back(0);
idxShape.push_back(0);
for (int i = dim, s = valShape.size() - 1; i < s; ++i) {
valShape[i + 1] = valShape[i];
idxShape[i + 1] = idxShape[i];
}
valShape[dim] = 1;
idxShape[dim] = 1;
Value unsqueezeVal = rewriter.create<tensor::ExpandShapeOp>(
loc, valResultType, rVal, reassociation);
Value unsqueezeIdx = rewriter.create<tensor::ExpandShapeOp>(
loc, idxResultType, rIdx, reassociation);
llvm::SmallVector<Value> unsqueezes = {unsqueezeVal, unsqueezeIdx};
rewriter.replaceOp(op, unsqueezes);
return success(); return success();
} }
}; };

View File

@ -1316,6 +1316,57 @@ public:
}; };
} // namespace } // namespace
namespace {
class DecomposeAtenAMinMaxOp : public OpRewritePattern<Torch::AtenAminOp> {
public:
using OpRewritePattern<Torch::AtenAminOp>::OpRewritePattern;
LogicalResult matchAndRewrite(Torch::AtenAminOp op,
PatternRewriter &rewriter) const override {
llvm::SmallVector<int64_t> dimList;
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
return rewriter.notifyMatchFailure(op, "dims not foldable constants");
}
bool keepdim;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) {
return rewriter.notifyMatchFailure(op, "keepdims not foldable constants");
}
auto loc = op.getLoc();
std::sort(dimList.begin(), dimList.end(), std::greater<int64_t>());
Value reduction = op.getSelf();
auto resultTy = cast<Torch::ValueTensorType>(op.getType());
auto reductionTy = cast<Torch::ValueTensorType>(reduction.getType());
llvm::SmallVector<int64_t> reductionShape(reductionTy.getSizes());
for (auto dim : dimList) {
auto dimValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dim));
reductionShape[dim] = 1;
if (!keepdim) {
for (int i = dim, s = reductionShape.size() - 1; i < s; ++i)
reductionShape[i] = reductionShape[i + 1];
reductionShape.resize(reductionShape.size() - 1);
}
reductionTy = rewriter.getType<Torch::ValueTensorType>(
reductionShape, resultTy.getOptionalDtype());
auto idxTy = rewriter.getType<Torch::ValueTensorType>(
reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true));
llvm::SmallVector<Type, 2> types{reductionTy, idxTy};
reduction = rewriter
.create<Torch::AtenMinDimOp>(loc, types, reduction,
dimValue, op.getKeepdim())
.getResult(0);
}
rewriter.replaceOp(op, reduction);
return success();
}
};
} // namespace
// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into // Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into
// `AtenMinDimOp` // `AtenMinDimOp`
namespace { namespace {
@ -6867,6 +6918,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAMinMaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);

View File

@ -77,6 +77,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass()); pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass()); pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass()); pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass()); pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass()); pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToTensorPass()); pm.addNestedPass<func::FuncOp>(createConvertTorchToTensorPass());

View File

@ -1472,6 +1472,62 @@ LTC_XFAIL_SET = {
} }
ONNX_XFAIL_SET = { ONNX_XFAIL_SET = {
# Failure - cast error
"MeanDimNoneDimModule_basic",
"MeanDtypeModule_basic",
"MeanDynamicSizesModule_basic",
"MeanModule_basic",
"MseLossMeanReductionModule_basic",
"PermuteNegativeIndexModule_basic",
"StdBiasedModule_basic",
"VarBiasedModule_basic",
"VarMeanBiasedModule_basic",
# Failure - constant int lowering
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"UnbindIntGetItem_Module_basic",
"UnbindIntListUnpack_Module_basic",
# Failure - incorrect numerics
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
"ElementwiseAtan2TensorIntModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog2IntModule_basic",
"ElementwiseSeluModule_basic",
"FlipModuleStaticShape_basic",
"FlipNegativeIndexModule_basic",
"HardsigmoidModule_basic",
"HardsigmoidRandomModule_basic",
"IndexSelectDynamicInputSizeModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"ResNet18Module_basic",
"SliceCopyEndGreaterThanDimSize_Module_basic",
"SliceCopyNegative_Module_basic",
"SliceCopyNonZeroDim_Module_basic",
"SliceCopy_Module_basic",
"TupleModule_basic",
# Failure - incorrect shape
"ArangeStartOutDtypeModule_basic",
"ArangeStartOutViewModule_basic",
"BroadcastDynamicDimModule_basic",
"BroadcastToModule_basic",
"ExpandModule_basic",
"MoveDimIntNegativeIndexModule_basic",
"ReduceAmaxKeepDim_basic",
"ReduceMaxKeepDimReturnBoth_basic",
"ReduceMaxNegativeDim_basic",
"ViewSizeFromOtherTensor_basic",
# Failure - onnx_export # Failure - onnx_export
"AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
@ -1594,6 +1650,7 @@ ONNX_XFAIL_SET = {
"EmptyStridedSizeIntStrideModule_basic", "EmptyStridedSizeIntStrideModule_basic",
"EqIntModule_basic", "EqIntModule_basic",
"ExponentialModule_basic", "ExponentialModule_basic",
"FloatImplicitModule_basic",
"GeFloatIntModule_basic", "GeFloatIntModule_basic",
"GeFloatModule_basic", "GeFloatModule_basic",
"GeIntModule_basic", "GeIntModule_basic",
@ -1613,6 +1670,7 @@ ONNX_XFAIL_SET = {
"IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic", "IndexPutImplIndexWithNoneModule_basic",
"IntFloatModule_basic", "IntFloatModule_basic",
"IntImplicitModule_basic",
"IouOfModule_basic", "IouOfModule_basic",
"IsFloatingPointFloat_True", "IsFloatingPointFloat_True",
"IsFloatingPointInt_False", "IsFloatingPointInt_False",
@ -1820,11 +1878,6 @@ ONNX_XFAIL_SET = {
"_SoftmaxModule_basic", "_SoftmaxModule_basic",
# Failure - onnx_import # Failure - onnx_import
"BucketizeTensorFloatModule_basic",
"BucketizeTensorModule_basic",
"BucketizeTensorOutInt32RightModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"DiagonalModule_basic", "DiagonalModule_basic",
"DiagonalModule_nonsquare", "DiagonalModule_nonsquare",
"DiagonalModule_transposed", "DiagonalModule_transposed",
@ -1832,31 +1885,6 @@ ONNX_XFAIL_SET = {
"DiagonalModule_with_dims_and_offset", "DiagonalModule_with_dims_and_offset",
"DiagonalModule_with_negative_dims", "DiagonalModule_with_negative_dims",
"DiagonalModule_with_offset", "DiagonalModule_with_offset",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampMinTensorFloatModule_basic",
"ElementwiseClampMinTensorIntModule_basic",
"ElementwiseClampModule_basic",
"ElementwiseClampTensorFloatModule_basic",
"ElementwiseClampTensorInt8Module_basic",
"ElementwiseClampTensorIntModule_basic",
"HBC_basic",
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic",
"IndexPut2DFloatAccumulateModule_basic",
"IndexPut2DIntAccumulateModule_basic",
"IndexPut3DFloatAccumulateModule_basic",
"IndexPut3DIntAccumulateModule_basic",
"IndexPutHackedTwin1DFloatAccumulateModule_basic",
"IndexPutHackedTwin1DIntAccumulateModule_basic",
"IndexPutHackedTwin2DFloatAccumulateModule_basic",
"IndexPutHackedTwin2DIntAccumulateModule_basic",
"IndexPutHackedTwin3DFloatAccumulateModule_basic",
"IndexPutHackedTwin3DIntAccumulateModule_basic",
"NormalizeModule_basic",
"PadWithNoneValModule_basic",
"QuantizedMLP_basic",
"RandModule_basic",
"ScatterReduceFloatMaxModuleIncludeSelf", "ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMinModuleIncludeSelf", "ScatterReduceFloatMinModuleIncludeSelf",
"ScatterReduceFloatProdModuleIncludeSelf", "ScatterReduceFloatProdModuleIncludeSelf",
@ -1867,21 +1895,11 @@ ONNX_XFAIL_SET = {
"ScatterReduceIntSumModuleIncludeSelf", "ScatterReduceIntSumModuleIncludeSelf",
"TileBigDimsSizeModule_basic", "TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic", "TileSmallDimsSizeModule_basic",
"UpSampleNearest2dDynamicSize_basic",
"UpSampleNearest2dStaticSize_basic",
# Failure - onnx_lowering # Failure - onnx_lowering: onnx.AveragePool
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic", "AdaptiveAvgPool1dStaticEvenMultiple_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AtenMmFloatTypes_basic",
"AtenMmIntTypes_basic",
"AtenTrilModule_basic",
"AtenTrilWithNegDiagonalModule_basic",
"AtenTrilWithPosDiagonalModule_basic",
"AtenTriuModule_basic",
"AtenTriuWithNegDiagonalModule_basic",
"AtenTriuWithPosDiagonalModule_basic",
"AvgPool1dFloatModule_basic", "AvgPool1dFloatModule_basic",
"AvgPool1dIntModule_basic", "AvgPool1dIntModule_basic",
"AvgPool1dStaticModule_basic", "AvgPool1dStaticModule_basic",
@ -1890,78 +1908,73 @@ ONNX_XFAIL_SET = {
"AvgPool2dFloatModule_basic", "AvgPool2dFloatModule_basic",
"AvgPool2dIntModule_basic", "AvgPool2dIntModule_basic",
"AvgPool2dStaticModule_basic", "AvgPool2dStaticModule_basic",
"BernoulliFloatModule_basic",
"BernoulliModule_basic", # Failure - onnx_lowering: onnx.Cast
"BernoulliPModule_basic", "BucketizeTensorOutInt32RightModule_basic",
"BernoulliTensorModule_basic", "ElementwiseToDtypeI64ToI8Module_basic",
"ConstantPad2dStaticModule_basic", "ElementwiseToDtypeI64ToUI8Module_basic",
"ConstantPadNdModule_basic", "HBC_basic",
"ConstantPadNdPartialStaticModule_basic", "QuantizedMLP_basic",
"ConstantPadNdStaticModule_basic", "TypeConversionI1ToI32Module_basic",
"CrossEntropyLossModule_basic", "TypeConversionI64ToI32Module_basic",
"CrossEntropyLossNoReductionModule_basic",
"DropoutTrainModule_basic", # Failure - onnx_lowering: onnx.Clip
"DropoutTrainStaticShapeModule_basic", "ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampMinTensorFloatModule_basic",
"ElementwiseClampMinTensorIntModule_basic",
"ElementwiseClampModule_basic",
"ElementwiseClampTensorFloatModule_basic",
"ElementwiseClampTensorInt8Module_basic",
"ElementwiseClampTensorIntModule_basic",
"NormalizeModule_basic",
# Failure - onnx_lowering: onnx.Einsum
"EinsumStaticContractRhsModule_basic", "EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic", "EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic", "EinsumStaticModule_basic",
"ElementwiseMishModule_basic",
"ElementwiseRemainderScalarModule_Bool_basic", # Failure - onnx_lowering: onnx.Gemm
"ElementwiseRemainderScalarModule_Int_basic", "AtenMmFloatTypes_basic",
"ElementwiseToDtypeI64ToI8Module_basic", "AtenMmIntTypes_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"GroupNormModule_basic",
"GroupNormNoWeightAndBiasModule_basic",
"HardswishModule_basic",
"HardswishRandomModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic",
"IndexPut1DIntNonAccumulateModule_basic",
"IndexPut2DFloatNonAccumulateModule_basic",
"IndexPut2DIntNonAccumulateModule_basic",
"IndexPut3DFloatNonAccumulateModule_basic",
"IndexPut3DIntNonAccumulateModule_basic",
"IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin1DIntNonAccumulateModule_basic",
"IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin2DIntNonAccumulateModule_basic",
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
"LogSoftmaxIntModule_basic",
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",
"MmDagModule_basic", "MmDagModule_basic",
"MmModule_basic", "MmModule_basic",
"MmModule_chained", "MmModule_chained",
"MmTanhModule_basic", "MmTanhModule_basic",
# Failure - onnx_lowering: onnx.HardSwish
"HardswishModule_basic",
"HardswishRandomModule_basic",
"MobilenetV3Module_basic", "MobilenetV3Module_basic",
"MseLossSumReductionWithDifferentElemTypeModule_basic",
"NativeDropoutTrainModule_basic", # Failure - onnx_lowering: onnx.LogSoftmax
"NativeDropoutTrainStaticShapeModule_basic", "LogSoftmaxIntModule_basic",
"_LogSoftmaxModuleStable_basic",
"_LogSoftmaxModule_basic",
# Failure - onnx_lowering: onnx.MaxPool
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",
# Failure - onnx_lowering: onnx.Mod
"ElementwiseRemainderScalarModule_Bool_basic",
"ElementwiseRemainderScalarModule_Int_basic",
"UnflattenIntNegativeOneDimStaticModule_basic",
"UnflattenIntNegativeOneSizeStaticModule_basic",
"UnflattenIntStaticModule_basic",
"UnflattenStaticModule_basic",
# Failure - onnx_lowering: onnx.OneHot
"OneHotModule_basic", "OneHotModule_basic",
# Failure - onnx_lowering: onnx.Pad
"ConstantPad2dStaticModule_basic",
"ConstantPadNdModule_basic",
"ConstantPadNdPartialStaticModule_basic",
"ConstantPadNdStaticModule_basic",
"PadModule_basic", "PadModule_basic",
"RandIntLowDtypeModule_basic", "PadWithNoneValModule_basic",
"RandIntLowModule_basic",
"RandLikeDtypeModule_basic",
"RandLikeModule_basic",
"RandnDtypeDeviceModule_basic",
"RandnGeneratorF64Module_basic",
"RandnGeneratorModule_basic",
"RandnLikeDtypeModule_basic",
"RandnLikeModule_basic",
"RandnModule_basic",
"ReduceL1NormModule_basic",
"ReduceL1NormWithDTypeModule_basic",
"ReduceL2NormModule_basic",
"ReduceL3NormAllDimsModule_basic",
"ReduceL3NormKeepDimModule_basic",
"ReduceProdDimIntFloatModule_basic",
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"ReduceSumElementTypeBoolModule_basic",
"ReduceSumFloatModule_basic",
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
"ReflectionPad1dModule2dInput_Right", "ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic", "ReflectionPad1dModule2dInput_basic",
"ReflectionPad1dModule3dInput_Left", "ReflectionPad1dModule3dInput_Left",
@ -1976,19 +1989,43 @@ ONNX_XFAIL_SET = {
"ReplicationPad2dModule_left0", "ReplicationPad2dModule_left0",
"ReplicationPad2dModule_right0", "ReplicationPad2dModule_right0",
"ReplicationPad2dModule_top0", "ReplicationPad2dModule_top0",
"ScatterSrcModule_basic",
"ScatterSrcStaticModule_basic", # Failure - onnx_lowering: onnx.RandomNormal
"ScatterValueFloatModule_basic", "RandnDtypeDeviceModule_basic",
"ScatterValueIntModule_basic", "RandnGeneratorF64Module_basic",
"SoftplusModule_basic", "RandnGeneratorModule_basic",
"SortTensorDescending_basic", "RandnModule_basic",
"SortTensorInteger_basic",
"SortTensorNegativeDimension_basic", # Failure - onnx_lowering: onnx.RandomNormalLike
"SortTensorSpecificDimension_basic", "RandnLikeDtypeModule_basic",
"SortTensor_basic", "RandnLikeModule_basic",
"SqueezeModule_allUnitDim",
"SqueezeModule_broadcast", # Failure - onnx_lowering: onnx.RandomUniform
"SqueezeModule_static", "RandIntLowDtypeModule_basic",
"RandIntLowModule_basic",
# Failure - onnx_lowering: onnx.RandomUniformLike
"BernoulliFloatModule_basic",
"BernoulliPModule_basic",
"BernoulliTensorModule_basic",
"RandLikeDtypeModule_basic",
"RandLikeModule_basic",
"RandModule_basic",
# Failure - onnx_lowering: onnx.ReduceL1
"ReduceL1NormModule_basic",
"ReduceL1NormWithDTypeModule_basic",
# Failure - onnx_lowering: onnx.ReduceL2
"ReduceL2NormModule_basic",
# Failure - onnx_lowering: onnx.ReduceProd
"BernoulliModule_basic",
"DropoutTrainModule_basic",
"DropoutTrainStaticShapeModule_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"ReduceProdDimIntFloatModule_basic",
"StdCorrectionAllDimReduceModule_basic", "StdCorrectionAllDimReduceModule_basic",
"StdCorrectionKeepDimModule_basic", "StdCorrectionKeepDimModule_basic",
"StdCorrectionLargeInputModule_basic", "StdCorrectionLargeInputModule_basic",
@ -1999,14 +2036,6 @@ ONNX_XFAIL_SET = {
"StdDimKeepDimTrueModule_basic", "StdDimKeepDimTrueModule_basic",
"StdDimNoneDimModule_basic", "StdDimNoneDimModule_basic",
"StdUnbiasedModule_basic", "StdUnbiasedModule_basic",
"TriuBroadcastModule_basic",
"TriuModule_basic",
"TypeConversionI1ToI32Module_basic",
"TypeConversionI64ToI32Module_basic",
"UnflattenIntNegativeOneDimStaticModule_basic",
"UnflattenIntNegativeOneSizeStaticModule_basic",
"UnflattenIntStaticModule_basic",
"UnflattenStaticModule_basic",
"VarCorrectionAllDimReduceModule_basic", "VarCorrectionAllDimReduceModule_basic",
"VarCorrectionKeepDimModule_basic", "VarCorrectionKeepDimModule_basic",
"VarCorrectionLargeInputModule_basic", "VarCorrectionLargeInputModule_basic",
@ -2025,58 +2054,85 @@ ONNX_XFAIL_SET = {
"VarMeanDimModule_basic", "VarMeanDimModule_basic",
"VarMeanUnbiasedModule_basic", "VarMeanUnbiasedModule_basic",
"VarUnbiasedModule_basic", "VarUnbiasedModule_basic",
"_LogSoftmaxModuleStable_basic",
"_LogSoftmaxModule_basic",
# Failure - cast_error # Failure - onnx_lowering: onnx.ReduceSum
"MeanDimNoneDimModule_basic", "MseLossSumReductionWithDifferentElemTypeModule_basic",
"MeanDtypeModule_basic", "ReduceL3NormAllDimsModule_basic",
"MeanDynamicSizesModule_basic", "ReduceL3NormKeepDimModule_basic",
"MeanModule_basic", "ReduceSumDtypeFloatModule_basic",
"MseLossMeanReductionModule_basic", "ReduceSumDtypeIntModule_basic",
"StdBiasedModule_basic", "ReduceSumElementTypeBoolModule_basic",
"VarBiasedModule_basic", "ReduceSumFloatModule_basic",
"VarMeanBiasedModule_basic", "ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
# Failure - constant_int # Failure - onnx_lowering: onnx.Resize
"ReduceMinAlongDimNegative_basic", "UpSampleNearest2dDynamicSize_basic",
"ReduceMinAlongDimSignedInt_basic", "UpSampleNearest2dStaticSize_basic",
"ReduceMinAlongDim_basic",
"ReduceMinFloatModule_basic",
"ReduceMinKeepDimReturnBoth_basic",
"ReduceMinSignedIntModule_basic",
"ReduceMinUnsignedIntModule_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"UnbindIntGetItem_Module_basic",
"UnbindIntListUnpack_Module_basic",
# Failure - operand_type # Failure - onnx_lowering: onnx.ScatterElements
"ElementwiseAcosIntModule_basic", "ScatterSrcModule_basic",
"ElementwiseAsinIntModule_basic", "ScatterSrcStaticModule_basic",
"ElementwiseAtanTensorIntModule_basic", "ScatterValueFloatModule_basic",
"ElementwiseCosIntModule_basic", "ScatterValueIntModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseExpIntModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog2IntModule_basic",
"ElementwiseLogIntModule_basic",
"ElementwiseSinIntModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseUnaryIntModule_basic",
# Failure - expand_multidim # Failure - onnx_lowering: onnx.ScatterND
"IndexTensorHackedTwinModule3dInput_basic", "IndexPut1DFloatAccumulateModule_basic",
"IndexTensorHackedTwinModule_basic", "IndexPut1DFloatNonAccumulateModule_basic",
"IndexTensorModule3dInput_basic", "IndexPut1DIntAccumulateModule_basic",
"IndexTensorModule_basic", "IndexPut1DIntNonAccumulateModule_basic",
"IndexTensorMultiInputContiguousOneDimDynamic_basic", "IndexPut2DFloatAccumulateModule_basic",
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic", "IndexPut2DFloatNonAccumulateModule_basic",
"IndexPut2DIntAccumulateModule_basic",
"IndexPut2DIntNonAccumulateModule_basic",
"IndexPut3DFloatAccumulateModule_basic",
"IndexPut3DFloatNonAccumulateModule_basic",
"IndexPut3DIntAccumulateModule_basic",
"IndexPut3DIntNonAccumulateModule_basic",
"IndexPutHackedTwin1DFloatAccumulateModule_basic",
"IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin1DIntAccumulateModule_basic",
"IndexPutHackedTwin1DIntNonAccumulateModule_basic",
"IndexPutHackedTwin2DFloatAccumulateModule_basic",
"IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin2DIntAccumulateModule_basic",
"IndexPutHackedTwin2DIntNonAccumulateModule_basic",
"IndexPutHackedTwin3DFloatAccumulateModule_basic",
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
"IndexPutHackedTwin3DIntAccumulateModule_basic",
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
# Failure - rankless_return # Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
# Failure - onnx_lowering: onnx.Softplus
"ElementwiseMishModule_basic",
"SoftplusModule_basic",
# Failure - onnx_lowering: onnx.Squeeze
"SqueezeModule_allUnitDim",
"SqueezeModule_broadcast",
"SqueezeModule_static",
# Failure - onnx_lowering: onnx.TopK
"SortTensorDescending_basic",
"SortTensorInteger_basic",
"SortTensorNegativeDimension_basic",
"SortTensorSpecificDimension_basic",
"SortTensor_basic",
# Failure - onnx_lowering: onnx.Trilu
"AtenTrilModule_basic",
"AtenTrilWithNegDiagonalModule_basic",
"AtenTrilWithPosDiagonalModule_basic",
"AtenTriuModule_basic",
"AtenTriuWithNegDiagonalModule_basic",
"AtenTriuWithPosDiagonalModule_basic",
"TriuBroadcastModule_basic",
"TriuModule_basic",
# Failure - rankless return
"ReduceAmaxMultiDim_basic", "ReduceAmaxMultiDim_basic",
"ReduceAmaxOutOfOrderDim_basic", "ReduceAmaxOutOfOrderDim_basic",
"ReduceAmaxSingleDim_basic", "ReduceAmaxSingleDim_basic",
@ -2089,7 +2145,7 @@ ONNX_XFAIL_SET = {
"ReduceMaxSignedIntModule_basic", "ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic", "ReduceMaxUnsignedIntModule_basic",
# Failure - view_lowering # Failure - torch.aten.view lower
"AddSizeIntModule_basic", "AddSizeIntModule_basic",
"ElementwiseFlattenBroadcastModule_basic", "ElementwiseFlattenBroadcastModule_basic",
"FlattenRank0Module_basic", "FlattenRank0Module_basic",
@ -2097,13 +2153,11 @@ ONNX_XFAIL_SET = {
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorMultiInputContiguousCenter_basic", "IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputNonContiguousDynamic_basic",
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorMultiInputNonContiguous_basic", "IndexTensorMultiInputNonContiguous_basic",
"IndexTensorMultiInputOneDim_basic", "IndexTensorMultiInputOneDim_basic",
"IndexTensorMultiInputThreeIndexers_basic", "IndexTensorMultiInputThreeIndexers_basic",
"IndexTensorMultiInput_basic", "IndexTensorMultiInput_basic",
"IndexTensorSelectDimModule_basic",
"IndexTensorStaticContiguousWithNoneModule_basic", "IndexTensorStaticContiguousWithNoneModule_basic",
"RepeatModule_basic", "RepeatModule_basic",
"SelectIntModule_basic", "SelectIntModule_basic",
@ -2117,62 +2171,49 @@ ONNX_XFAIL_SET = {
"ViewSizeDimLedByCollapsedOnesModule_basic", "ViewSizeDimLedByCollapsedOnesModule_basic",
"ViewSizeDimLedByExpandedOnesModule_basic", "ViewSizeDimLedByExpandedOnesModule_basic",
# Failure - numerical
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
"ElementwiseSeluModule_basic",
"EmbeddingModule1DIndices_basic",
"FlipNegativeIndexModule_basic",
"HardsigmoidModule_basic",
"HardsigmoidRandomModule_basic",
"IndexSelectDynamicIndexSizeModule_basic",
"IndexSelectDynamicInputSizeModule_basic",
"IndexSelectDynamicModulebasic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"ResNet18Module_basic",
"SliceCopyEndGreaterThanDimSize_Module_basic",
"SliceCopyNegative_Module_basic",
"SliceCopyNonZeroDim_Module_basic",
"SliceCopy_Module_basic",
"TupleModule_basic",
# Failure - shape
"ArangeStartOutDtypeModule_basic",
"ArangeStartOutViewModule_basic",
"BroadcastDynamicDimModule_basic",
"BroadcastToModule_basic",
"EmbeddingModuleF16_basic",
"EmbeddingModuleI32_basic",
"EmbeddingModuleI64_basic",
"ExpandModule_basic",
"MoveDimIntNegativeIndexModule_basic",
"PermuteNegativeIndexModule_basic",
"ReduceAmaxKeepDim_basic",
"ReduceMaxKeepDimReturnBoth_basic",
"ReduceMaxNegativeDim_basic",
"ViewSizeFromOtherTensor_basic",
# Failure - onnx traces differently
"ElementwiseSigmoidIntModule_basic",
# Failure - unknown # Failure - unknown
"BucketizeTensorFloatModule_basic",
"BucketizeTensorModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"CopyWithDifferentDTypesAndSizesModule_basic", "CopyWithDifferentDTypesAndSizesModule_basic",
"CopyWithDifferentDTypesModule_basic", "CopyWithDifferentDTypesModule_basic",
"CosineSimilarityStaticBroadcastModule_basic", "CosineSimilarityStaticBroadcastModule_basic",
"CumsumInputDtypeInt32Module_basic", "CumsumInputDtypeInt32Module_basic",
"ElementwiseAtan2TensorIntModule_basic", "ElementwiseAcosIntModule_basic",
"ElementwiseAsinIntModule_basic",
"ElementwiseAtanTensorIntModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseDivRoundingModeTruncModule_basic", "ElementwiseDivRoundingModeTruncModule_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseExpIntModule_basic",
"ElementwiseLogIntModule_basic",
"ElementwisePreluModule_basic", "ElementwisePreluModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseSinIntModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseUnaryIntModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic",
"ElementwiseWhereScalarModule_basic", "ElementwiseWhereScalarModule_basic",
"EmbeddingModule1DIndices_basic",
"EmbeddingModuleF16_basic",
"EmbeddingModuleI32_basic",
"EmbeddingModuleI64_basic",
"FlattenDynamicModule_basic", "FlattenDynamicModule_basic",
"FlipModuleStaticShape_basic",
"GluStaticModule_basic", "GluStaticModule_basic",
"GroupNormModule_basic",
"GroupNormNoWeightAndBiasModule_basic",
"IndexSelectDynamicIndexSizeModule_basic",
"IndexSelectDynamicModulebasic",
"IndexTensorHackedTwinModule3dInput_basic",
"IndexTensorHackedTwinModule_basic",
"IndexTensorModule3dInput_basic",
"IndexTensorModule_basic",
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousDynamic_basic",
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
"IndexTensorSelectDimModule_basic",
"MaskedFillTensorFloatValueModule_basic", "MaskedFillTensorFloatValueModule_basic",
"ReduceAllDimEmpty_basic", "ReduceAllDimEmpty_basic",
"ReduceAllDimFloat_basic", "ReduceAllDimFloat_basic",
@ -2180,8 +2221,6 @@ ONNX_XFAIL_SET = {
"ReduceMinAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic",
"TensorsStackNegativeDimModule_basic", "TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic", "TensorsStackPromoteDTypeModule_basic",
"FloatImplicitModule_basic",
"IntImplicitModule_basic",
} }
ONNX_CRASHING_SET = { } ONNX_CRASHING_SET = { }

View File

@ -926,107 +926,121 @@ func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor
// ----- // -----
// CHECK-LABEL: func.func @test_reduce_min_bool_inputs // CHECK-LABEL: func.func @test_reduce_min_empty_set_fp
func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_reduce_min_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000
// CHECK: %[[INT2:.+]] = torch.constant.int 2 // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]]
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]]
// CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int // CHECK: return %[[FULL]]
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,1],i1>
%0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1>
return %0 : !torch.vtensor<[4,1],i1>
}
// CHECK-LABEL: func.func @test_reduce_min_default_axes_keepdims_example
func.func @test_reduce_min_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: torch.prim.ListConstruct %int0, %int1_0, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: torch.aten.amin %arg0, %1, %0 : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,1],f32>
%0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32>
return %0 : !torch.vtensor<[1,1,1],f32>
}
// CHECK-LABEL: func.func @test_reduce_min_do_not_keepdims_example
func.func @test_reduce_min_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: torch.aten.amin %arg0, %6, %false : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,2],f32>
%0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
return %0 : !torch.vtensor<[3,2],f32>
}
// CHECK-LABEL: func.func @test_reduce_min_empty_set
func.func @test_reduce_min_empty_set(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[2,0,4],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[2,1,4],f32>
%0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32>
return %0 : !torch.vtensor<[2,1,4],f32> return %0 : !torch.vtensor<[2,1,4],f32>
} }
// CHECK-LABEL: func.func @test_reduce_min_keepdims_example // -----
func.func @test_reduce_min_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK-LABEL: func.func @test_reduce_min_empty_set_int
// CHECK: %[[INT3:.+]] = torch.constant.int 3 func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK-DAG: %[[INF:.+]] = torch.constant.int 2147483647
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]]
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]]
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int> // CHECK: return %[[FULL]]
// CHECK: %[[TRUE:.+]] = torch.constant.bool true %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32>
// CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,1,2],f32> return %0 : !torch.vtensor<[2,1,4],si32>
%0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32>
return %0 : !torch.vtensor<[3,1,2],f32>
} }
// CHECK-LABEL: func.func @test_reduce_min_negative_axes_keepdims_example // -----
func.func @test_reduce_min_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT3:.+]] = torch.constant.int 3 // CHECK-LABEL: func.func @test_reduce_min_bool_inputs
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0 func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> // CHECK: %[[IDX:.+]] = torch.constant.int 0
// CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int // CHECK: %[[SZ:.+]] = torch.constant.int 0
// CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]]
// CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
// CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
// CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int // CHECK: %[[C0:.+]] = torch.constant.int 0
// CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int> // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list<int>
// CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[TRUE:.+]] = torch.constant.bool true
// CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[3,1,2],f32> // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,1],i1>
%0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> // CHECK: return %[[AMIN]] : !torch.vtensor<[4,1],i1>
return %0 : !torch.vtensor<[3,1,2],f32> %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1>
return %0 : !torch.vtensor<[4,1],i1>
}
// -----
// CHECK-LABEL: func.func @test_reduce_min_bool_inputs_nokeepdims
func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[IDX:.+]] = torch.constant.int 0
// CHECK: %[[SZ:.+]] = torch.constant.int 0
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]]
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
// CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
// CHECK: %[[C0:.+]] = torch.constant.int 0
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4],i1>
// CHECK: return %[[AMIN]] : !torch.vtensor<[4],i1>
%0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1>
return %0 : !torch.vtensor<[4],i1>
}
// -----
// CHECK-LABEL: func.func @test_reduce_all_dims_default
func.func @test_reduce_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[I0:.+]] = torch.constant.int 0
// CHECK: %[[I1:.+]] = torch.constant.int 1
// CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
// CHECK: %[[C0:.+]] = torch.constant.int 0
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]]
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[MIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[],i1>
// CHECK: return %[[MIN]] : !torch.vtensor<[],i1>
%0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1>
return %0 : !torch.vtensor<[],i1>
}
// -----
func.func @test_reduce_min_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4],i1>
// CHECK: return %[[AMIN]]
%0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.axes=[1 : si64]} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1>
return %0 : !torch.vtensor<[4],i1>
} }
// ----- // -----