[TORCH][MLIR] Fix some comments in slice_scatter/select_scatter

lowering.

This commit addresses the remaining comments on lowering of
slice_scatter and select_scatter.

Signed-Off-By: Prateek Gupta <gprateek93@gmail.com>
pull/1051/head snapshot-20220713.532
Prateek Gupta 2022-07-12 12:43:03 +05:30
parent ac4d7d10e0
commit 3592e0ba7f
1 changed files with 13 additions and 13 deletions

View File

@ -34,28 +34,28 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;
static Value toPositiveValidDim(ConversionPatternRewriter &rewriter,
Location loc, Value torchType,
Value builtinType, Value valueForNone,
Location loc, Value torchOptionalInt,
Value builtinInt, Value defaultValue,
Value dimSize) {
if (torchType.getType().isa<Torch::NoneType>())
return valueForNone;
if (torchOptionalInt.getType().isa<Torch::NoneType>())
return defaultValue;
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
Value positiveDim =
toPositiveDimDynamic(rewriter, loc, builtinType, dimSizeAsInt);
// startOrEnd < 0 ? 0 : startOrEnd
toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt);
// positveDim < 0 ? 0 : positiveDim
Value cst0 = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, positiveDim, cst0);
Value startOrEndAtLeastZero =
Value atLeastZero =
rewriter.create<arith::SelectOp>(loc, predDimSltZero, cst0, positiveDim);
// startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt);
Value startOrEndBoundedByDimSize = rewriter.create<arith::SelectOp>(
loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero);
// atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero
Value sgtDimSize = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt);
Value boundedByDimSize = rewriter.create<arith::SelectOp>(
loc, sgtDimSize, dimSizeAsInt, atLeastZero);
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
return castIntToIndex(rewriter, loc, boundedByDimSize);
}
template <typename OpTy, typename OpAdaptor>