mirror of https://github.com/llvm/torch-mlir

[torch] Fixed edge conditions for strided slicing (#2929)

Strided slicing can occur with a negative stride. In these cases we need to
bound `end` differently. This included removing a function that was generating
bad limits.

branch: pull/2944/head
parent: 0f80e75c2e
commit: df2aa1a369
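For context on the edge condition being fixed: with a positive step an
out-of-range `end` can simply be clamped into `[0, dimSize]`, but with a
negative step the slice must be able to run past index 0, which requires an
effective end of `-1`. Plain Python slicing, which shares these semantics,
illustrates the difference:

    x = list(range(5))
    # positive step: end clamps into [0, len(x)]
    assert x[1:100:1] == [1, 2, 3, 4]        # end 100 -> 5
    # negative step: an out-of-range negative end must clamp to -1, not 0,
    # or element 0 would be unreachable
    assert x[4:-6:-1] == [4, 3, 2, 1, 0]     # end -6 -> -1
    assert x[4:0:-1] == [4, 3, 2, 1]         # end 0 still excludes element 0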
@@ -309,29 +309,6 @@ inline int64_t getIntAttrAsSigned(IntegerAttr intAttr) {
   return intAttr.getValue().getSExtValue();
 }
 
-/// Returns the value from an `IntegerAttr` as an integral index.
-///
-/// @param intAttr the `IntegerAttr` from which to extract the index
-/// @param dimSize the size of the dimension that the attribute indexes into
-/// @return the index value
-///
-/// Use this function when the given `IntegerAttr` represents an index into
-/// a range, such as an index into a tensor dimension. If `dimSize` is given,
-/// negative index values are converted into positive values by counting
-/// elements from the "right" side of the dimension, as in python, numpy, etc.
-/// For example, an index of -2 and a dimSize of 10 returns 8 because 8 is the
-/// 2nd index from the high end of the range 0 to 9. If `dimSize` is not
-/// given, any negative indices are returned as negative numbers.
-///
-/// No bounds checking is performed on the index to ensure that it is within
-/// the legal range for `dimSize`.
-inline int64_t getIntAttrAsIndex(IntegerAttr intAttr, int dimSize = -1) {
-  int64_t signedIndex = getIntAttrAsSigned(intAttr);
-  if (dimSize < 0 || signedIndex > 0)
-    return signedIndex;
-  return dimSize + signedIndex; // count backwards from dimSize
-}
-
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
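This is the helper the commit message calls out as generating bad limits:
because the guard tests `signedIndex > 0` rather than `>= 0`, a legal index of
0 falls through to the wrap-around branch and comes back as `dimSize`, one
past the end of the dimension. A plain-Python rendering of the deleted logic
(for illustration only):

    def get_int_attr_as_index(signed_index, dim_size=-1):
        # `> 0` lets a legal index of 0 fall through to the wrapping branch
        if dim_size < 0 or signed_index > 0:
            return signed_index
        return dim_size + signed_index  # count backwards from dim_size

    assert get_int_attr_as_index(-2, 10) == 8   # documented wrap, correct
    assert get_int_attr_as_index(0, 10) == 10   # bad limit: should stay 0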
@@ -51,6 +51,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
 
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);
 
   int64_t dim;
   if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
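The new `negone` constant feeds the two uses introduced in the next hunk:
clamping `end` to `-1` for negative strides, and selecting the sign of the
step when computing the result size.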
@@ -76,27 +77,49 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
   Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep());
   Value start = toPositiveValidDim(rewriter, loc, torchTypeStart,
                                    builtinTypeStart, zero, dimSize);
-  Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd,
-                                 dimSize, dimSize);
-
-  // end >= start ? end : start
-  Value endSgeStart = rewriter.create<arith::CmpIOp>(
-      loc, arith::CmpIPredicate::sge, end, start);
-  end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
+  // We cannot use toPositiveValidDim, as for negative strides we need to
+  // clamp to `-1` so that the full tensor bounds are available:
+  Value end = builtinTypeEnd;
+  if (torchTypeEnd.getType().isa<Torch::NoneType>()) {
+    end = dimSize;
+  } else {
+    end = castIntToIndex(rewriter, loc, end);
+    Value endcmp = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, end, zero);
+    Value endadd = rewriter.create<arith::AddIOp>(loc, end, dimSize);
+    end = rewriter.create<arith::SelectOp>(loc, endcmp, endadd, end);
+    endcmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, end,
+                                            zero);
+    end = rewriter.create<arith::SelectOp>(loc, endcmp, negone, end);
+    endcmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, end,
+                                            dimSize);
+    end = rewriter.create<arith::SelectOp>(loc, endcmp, dimSize, end);
+  }
 
   // Slice logic: resultSize = floordiv(end - start + step - sign(step), step)
   resultShape = getTensorSizes(rewriter, loc, input);
   Value len = rewriter.create<arith::SubIOp>(loc, end, start);
+
+  // We check the difference between start and end to determine the total size:
+  Value stepcmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
+                                                 stepIndex, zero);
+  Value stepsign = rewriter.create<arith::SelectOp>(loc, stepcmp, one, negone);
   Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
-  resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
+  resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, stepsign);
   resultSize = rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
+
+  // Clamp the size to [0, ...]:
+  Value szcmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                               resultSize, zero);
+  resultSize = rewriter.create<arith::SelectOp>(loc, szcmp, zero, resultSize);
   resultShape[dim] = resultSize;
 
   strides.resize(inputType.getRank(), one);
   offsets.resize(inputType.getRank(), zero);
 
   offsets[dim] = start;
-  strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
+  strides[dim] = stepIndex;
   return success();
 }
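Taken together, the builder calls above implement the following end
normalization and size computation. A plain-Python paraphrase (the helper
names `clamp_end` and `strided_size` are invented for this sketch and do not
appear in the patch):

    def clamp_end(end, dim_size):
        if end is None:         # torch `None` end: slice to the boundary
            return dim_size
        if end < 0:             # wrap a negative index once
            end = end + dim_size
        if end < 0:             # clamp below to -1 (not 0) so a negative
            end = -1            #   stride can still consume index 0
        if end > dim_size:      # clamp above to the dimension size
            end = dim_size
        return end

    def strided_size(start, end, step):
        # resultSize = floordiv(end - start + step - sign(step), step)
        sign = 1 if step >= 0 else -1
        size = (end - start + step - sign) // step
        return max(size, 0)     # final clamp to [0, ...)

    # reversing a size-5 dim, x[4:-6:-1], covers all five elements:
    assert clamp_end(-6, 5) == -1
    assert strided_size(4, -1, -1) == 5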
@@ -3327,11 +3327,15 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
 
   // Get the single index value for the selected dimension
   auto splatValue = indexAttr.getSplatValue<IntegerAttr>();
-  int64_t indexInt = getIntAttrAsIndex(splatValue, selfSizes[dimInt]);
+  int64_t indexInt = getIntAttrAsSigned(splatValue);
+  indexInt = indexInt < 0 && selfSizes[dimInt] ? indexInt + selfSizes[dimInt]
+                                               : indexInt;
 
   // Extract the single constant value from the input tensor and turn the
   // extracted value into a single-element tensor of the output shape and dtype
-  auto splattr = selfAttr.getValues<Attribute>()[indexInt];
+  Attribute splattr = selfAttr.isSplat()
+                          ? selfAttr.getSplatValue<Attribute>()
+                          : selfAttr.getValues<Attribute>()[indexInt];
 
   auto dty = resultTy.getDtype();
   auto attrTy = resultTy.toBuiltinTensor().clone(dty);
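With `getIntAttrAsIndex` gone, the fold wraps a negative constant index
inline, and only when the dimension size is known to be nonzero; it also reads
a splat attribute's single stored value directly instead of indexing into its
value range. Paraphrased in plain Python (`values`, `is_splat`, and
`splat_value` stand in for the `ElementsAttr` accessors; this is a sketch, not
the real API):

    def fold_selected_value(index, dim_size, values, is_splat, splat_value):
        # wrap a negative index only for a known, nonzero dimension size
        if index < 0 and dim_size:
            index = index + dim_size
        # a splat holds one value for every element, so read it directly
        return splat_value if is_splat else values[index]

    assert fold_selected_value(-2, 10, list(range(10)), False, None) == 8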
@@ -2162,6 +2162,8 @@ ONNX_XFAIL_SET = {
     "EmbeddingModuleI32_basic",
     "EmbeddingModuleI64_basic",
     "ExpandModule_basic",
+    "MoveDimIntNegativeIndexModule_basic",
+    "PermuteNegativeIndexModule_basic",
     "ReduceAmaxKeepDim_basic",
     "ReduceMaxKeepDimReturnBoth_basic",
     "ReduceMaxNegativeDim_basic",
@@ -2184,7 +2186,6 @@ ONNX_XFAIL_SET = {
     "ElementwiseUnsqueezeNegDimsModule_basic",
     "ElementwiseWhereScalarModule_basic",
     "FlattenDynamicModule_basic",
-    "FlipModule_basic",
     "FlipModuleStaticShape_basic",
     "GluStaticModule_basic",
     "MaskedFillTensorFloatValueModule_basic",
|
@ -2194,9 +2195,7 @@ ONNX_XFAIL_SET = {
|
|||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
"TensorsStackNegativeDimModule_basic",
|
||||
"TensorsStackPromoteDTypeModule_basic",
|
||||
"FlipModule_basic",
|
||||
"MoveDimIntNegativeIndexModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
}
|
||||
|
||||
ONNX_CRASHING_SET = { }