[torch] Fixed edge conditions for strided slicing (#2929)

Strided slicing can occur with a negative stride. In these cases we need
to bound `end` differently. This includes removing a function that was
generating bad limits.
pull/2944/head
Rob Suderman 2024-02-21 21:28:44 -08:00 committed by GitHub
parent 0f80e75c2e
commit df2aa1a369
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 40 additions and 37 deletions

View File

@@ -309,29 +309,6 @@ inline int64_t getIntAttrAsSigned(IntegerAttr intAttr) {
return intAttr.getValue().getSExtValue();
}
/// Returns the value from an `IntegerAttr` as an integral index.
///
/// @param intAttr the `IntegerAttr` from which to extract the index
/// @param dimSize the size of the dimension that the attribute indexes into
/// @return the index value
///
/// Use this function when the given `IntegerAttr` represents an index into
/// a range, such as an index into a tensor dimension. If `dimSize` is given,
/// negative index values are converted into positive values by counting
/// elements from the "right" side of the dimension, as in python, numpy, etc.
/// For example, an index of -2 and a dimSize of 10 returns 8 because 8 is the
/// 2nd index from the high end of the range 0 to 9. If `dimSize` is not
/// given, any negative indices are returned as negative numbers.
///
/// No bounds checking is performed on the index to ensure that it is within
/// the legal range for `dimSize`.
inline int64_t getIntAttrAsIndex(IntegerAttr intAttr, int dimSize = -1) {
  int64_t signedIndex = getIntAttrAsSigned(intAttr);
  // Only negative indices are wrapped. An index of 0 is already a valid
  // non-negative index and must be returned as-is; the previous test
  // (`signedIndex > 0`) wrongly remapped 0 to dimSize + 0 == dimSize,
  // producing an out-of-range limit.
  if (dimSize < 0 || signedIndex >= 0)
    return signedIndex;
  return dimSize + signedIndex; // count backwards from dimSize
}
} // namespace Torch
} // namespace torch
} // namespace mlir

View File

@@ -51,6 +51,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
@@ -76,27 +77,49 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep());
Value start = toPositiveValidDim(rewriter, loc, torchTypeStart,
builtinTypeStart, zero, dimSize);
Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd,
dimSize, dimSize);
// end >= start ? end : start
Value endSgeStart = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, end, start);
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
// We cannot use to positive valid dim as for negative strides we need to
// clamp to `-1` so that the full tensor bounds are available:
Value end = builtinTypeEnd;
if (torchTypeEnd.getType().isa<Torch::NoneType>()) {
end = dimSize;
} else {
end = castIntToIndex(rewriter, loc, end);
Value endcmp = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, end, zero);
Value endadd = rewriter.create<arith::AddIOp>(loc, end, dimSize);
end = rewriter.create<arith::SelectOp>(loc, endcmp, endadd, end);
endcmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, end,
zero);
end = rewriter.create<arith::SelectOp>(loc, endcmp, negone, end);
endcmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, end,
dimSize);
end = rewriter.create<arith::SelectOp>(loc, endcmp, dimSize, end);
}
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
resultShape = getTensorSizes(rewriter, loc, input);
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
// We check the difference between start and end to determine the total size:
Value stepcmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
stepIndex, zero);
Value stepsign = rewriter.create<arith::SelectOp>(loc, stepcmp, one, negone);
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, stepsign);
resultSize = rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
// Clamp the size to [0, ...]:
Value szcmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
resultSize, zero);
resultSize = rewriter.create<arith::SelectOp>(loc, szcmp, zero, resultSize);
resultShape[dim] = resultSize;
strides.resize(inputType.getRank(), one);
offsets.resize(inputType.getRank(), zero);
offsets[dim] = start;
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
strides[dim] = stepIndex;
return success();
}

View File

@@ -3327,11 +3327,15 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
// Get the single index value for the selected dimension
auto splatValue = indexAttr.getSplatValue<IntegerAttr>();
int64_t indexInt = getIntAttrAsIndex(splatValue, selfSizes[dimInt]);
int64_t indexInt = getIntAttrAsSigned(splatValue);
indexInt = indexInt < 0 && selfSizes[dimInt] ? indexInt + selfSizes[dimInt]
: indexInt;
// Extract the single constant value from the input tensor and turn the
// extracted value into a single-element tensor of the output shape and dtype
auto splattr = selfAttr.getValues<Attribute>()[indexInt];
Attribute splattr = selfAttr.isSplat()
? selfAttr.getSplatValue<Attribute>()
: selfAttr.getValues<Attribute>()[indexInt];
auto dty = resultTy.getDtype();
auto attrTy = resultTy.toBuiltinTensor().clone(dty);

View File

@@ -2162,6 +2162,8 @@ ONNX_XFAIL_SET = {
"EmbeddingModuleI32_basic",
"EmbeddingModuleI64_basic",
"ExpandModule_basic",
"MoveDimIntNegativeIndexModule_basic",
"PermuteNegativeIndexModule_basic",
"ReduceAmaxKeepDim_basic",
"ReduceMaxKeepDimReturnBoth_basic",
"ReduceMaxNegativeDim_basic",
@@ -2184,7 +2186,6 @@ ONNX_XFAIL_SET = {
"ElementwiseUnsqueezeNegDimsModule_basic",
"ElementwiseWhereScalarModule_basic",
"FlattenDynamicModule_basic",
"FlipModule_basic",
"FlipModuleStaticShape_basic",
"GluStaticModule_basic",
"MaskedFillTensorFloatValueModule_basic",
@@ -2194,9 +2195,7 @@ ONNX_XFAIL_SET = {
"ReduceMinAlongDimUnsignedInt_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"FlipModule_basic",
"MoveDimIntNegativeIndexModule_basic",
"PermuteNegativeIndexModule_basic",
}
ONNX_CRASHING_SET = { }