mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Fix some comments in slice_scatter/select_scatter
lowering. This commit addresses the remaining comments on lowering of slice_scatter and select_scatter. Signed-Off-By: Prateek Gupta <gprateek93@gmail.com>pull/1051/head snapshot-20220713.532
parent
ac4d7d10e0
commit
3592e0ba7f
|
@ -34,28 +34,28 @@ using namespace mlir::torch;
|
|||
using namespace mlir::torch::Torch;
|
||||
|
||||
static Value toPositiveValidDim(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value torchType,
|
||||
Value builtinType, Value valueForNone,
|
||||
Location loc, Value torchOptionalInt,
|
||||
Value builtinInt, Value defaultValue,
|
||||
Value dimSize) {
|
||||
if (torchType.getType().isa<Torch::NoneType>())
|
||||
return valueForNone;
|
||||
if (torchOptionalInt.getType().isa<Torch::NoneType>())
|
||||
return defaultValue;
|
||||
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
|
||||
Value positiveDim =
|
||||
toPositiveDimDynamic(rewriter, loc, builtinType, dimSizeAsInt);
|
||||
// startOrEnd < 0 ? 0 : startOrEnd
|
||||
toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt);
|
||||
// positveDim < 0 ? 0 : positiveDim
|
||||
Value cst0 = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
|
||||
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, positiveDim, cst0);
|
||||
Value startOrEndAtLeastZero =
|
||||
Value atLeastZero =
|
||||
rewriter.create<arith::SelectOp>(loc, predDimSltZero, cst0, positiveDim);
|
||||
// startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
|
||||
Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt);
|
||||
Value startOrEndBoundedByDimSize = rewriter.create<arith::SelectOp>(
|
||||
loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero);
|
||||
// atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero
|
||||
Value sgtDimSize = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt);
|
||||
Value boundedByDimSize = rewriter.create<arith::SelectOp>(
|
||||
loc, sgtDimSize, dimSizeAsInt, atLeastZero);
|
||||
|
||||
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
|
||||
return castIntToIndex(rewriter, loc, boundedByDimSize);
|
||||
}
|
||||
|
||||
template <typename OpTy, typename OpAdaptor>
|
||||
|
|
Loading…
Reference in New Issue