mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for negative step in aten.slice.Tensor op (#3763)
This commit adds support for negative step values in the aten.slice.Tensor op. Although PyTorch itself does not allow a negative step for the slice op, the Onnx.Slice op does, and it eventually lowers to torch.aten.slice.Tensor. Hence, such values are now handled during the Torch->Linalg lowering of aten.slice.Tensor.

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
parent b08d08682f
commit f6721e5999
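Before the diff itself, here is a minimal standalone sketch (plain C++, not MLIR; the `strided` helper and the `main` driver are illustrative assumptions, not part of the patch) that reproduces the worked example from the lowering comment below: a slice with a negative step is rewritten as a flip of the input followed by a positive-step slice whose offset is recomputed as flipped_offset = input_shape[dim] - (result_shape[dim] * |stride|).

#include <cassert>
#include <cstdint>
#include <vector>

// Extract `count` elements starting at `offset`, stepping by `stride`.
static std::vector<int64_t> strided(const std::vector<int64_t> &v,
                                    int64_t offset, int64_t stride,
                                    int64_t count) {
  std::vector<int64_t> out;
  for (int64_t i = 0; i < count; ++i)
    out.push_back(v[offset + i * stride]);
  return out;
}

int main() {
  std::vector<int64_t> input = {0, 1, 2, 3, 4, 5};
  int64_t offset = 3, stride = -2, count = 2;

  // Direct negative-step slice: indices 3, 1 -> [3, 1].
  std::vector<int64_t> direct = strided(input, offset, stride, count);

  // Flip-based rewrite used by the lowering: flip the input, use |stride|,
  // and recompute the offset as input_size - count * |stride| = 6 - 2*2 = 2.
  std::vector<int64_t> flipped(input.rbegin(), input.rend()); // [5, 4, ..., 0]
  int64_t absStride = -stride;
  int64_t flippedOffset =
      static_cast<int64_t>(input.size()) - count * absStride;
  std::vector<int64_t> viaFlip =
      strided(flipped, flippedOffset, absStride, count);

  assert(direct == viaFlip && direct == std::vector<int64_t>({3, 1}));
  return 0;
}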
@@ -101,6 +101,10 @@ LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter,
                             Location loc, SmallVector<int64_t> dimensions,
                             Value input, Value &result);
 
+// Flips an input tensor based on the values of axis list.
+Value flipTensor(PatternRewriter &rewriter, Location loc, Value input,
+                 SmallVector<int64_t> axis);
+
 } // namespace torch_to_linalg
 } // namespace torch
 } // namespace mlir
@@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef<int64_t> a) {
 template <typename OpTy, typename OpAdaptor>
 LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
                                            ConversionPatternRewriter &rewriter,
+                                           int64_t &dim,
                                            SmallVector<Value> &resultShape,
                                            SmallVector<Value> &offsets,
                                            SmallVector<Value> &strides) {
@@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   Value negone = rewriter.create<arith::ConstantIndexOp>(loc, -1);
 
-  int64_t dim;
   if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
     return op->emitError("unimplemented: dim is not constant");
 
@@ -1857,14 +1857,46 @@ public:
     RankedTensorType resultType = cast<RankedTensorType>(
         typeConverter->convertType(op->getResult(0).getType()));
 
-    SmallVector<Value> resultShape;
-    SmallVector<Value> offsets;
-    SmallVector<Value> strides;
+    SmallVector<Value> resultShape, offsets, strides;
+    int64_t dim;
     if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
                                             AtenSliceTensorOpAdaptor>(
-            op, adaptor, rewriter, resultShape, offsets, strides))) {
+            op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
       return failure();
     }
 
+    // If stride is negative, then flip the input tensor corresponding to that
+    // dim, update the stride for flipped tensor by multiplying it by -1, and
+    // update the offset as follows:
+    // flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride)
+    //
+    // For example:
+    // Input = [0, 1, 2, 3, 4, 5]
+    // stride = [-2], result_shape = [2], offset = [3]
+    // Result = [3, 1]
+    // After flipping:
+    // Input = [5, 4, 3, 2, 1, 0]
+    // stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2]
+    // Result = [3, 1]
+
+    Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input,
+                                                     SmallVector<int64_t>{dim});
+    Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value isNegativeStride = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, strides[dim], zero);
+    strides[dim] = rewriter.create<math::AbsIOp>(loc, strides[dim]);
+    Value resShapeMulStride =
+        rewriter.create<arith::MulIOp>(loc, resultShape[dim], strides[dim]);
+    Value inputDim = rewriter.create<tensor::DimOp>(loc, input, cstDim);
+    Value flippedOffset =
+        rewriter.create<arith::SubIOp>(loc, inputDim, resShapeMulStride);
+    offsets[dim] = rewriter.create<arith::SelectOp>(
+        loc, isNegativeStride, flippedOffset, offsets[dim]);
+
+    input = rewriter.create<arith::SelectOp>(loc, isNegativeStride,
+                                             flippedInput, input);
+
     SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
     auto sliceType = RankedTensorType::get(
         dynShape, resultType.getElementType(), resultType.getEncoding());
@@ -2095,12 +2127,11 @@ public:
     RankedTensorType resultType = cast<RankedTensorType>(
         typeConverter->convertType(op->getResult(0).getType()));
 
-    SmallVector<Value> resultShape;
-    SmallVector<Value> offsets;
-    SmallVector<Value> strides;
+    SmallVector<Value> resultShape, offsets, strides;
+    int64_t dim;
     if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
                                             AtenSliceScatterOpAdaptor>(
-            op, adaptor, rewriter, resultShape, offsets, strides))) {
+            op, adaptor, rewriter, dim, resultShape, offsets, strides))) {
       return failure();
     }
 
@@ -222,14 +222,9 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
 
     Location loc = op->getLoc();
-    MLIRContext *context = op.getContext();
     Value self = adaptor.getSelf();
     auto selfRank =
         cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
-    Type elementType =
-        cast<RankedTensorType>(adaptor.getSelf().getType()).getElementType();
-    Value c1 =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
 
     SmallVector<int64_t> axis;
     if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
@@ -242,40 +237,8 @@ public:
       }
     }
 
-    // Only used to calculate flipped values, i.e. those on the flip axes. Other
-    // dims won't be used.
-    SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
-    for (auto flipDim : axis)
-      dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);
-
-    Value initTensor = createZeroInitTensor(
-        rewriter, loc, getTensorSizes(rewriter, loc, self), elementType);
-
-    SmallVector<utils::IteratorType> iteratorTypes(
-        selfRank, utils::IteratorType::parallel);
-    SmallVector<AffineMap> indexingMaps(
-        2, AffineMap::getMultiDimIdentityMap(selfRank, context));
-    Value flipped =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, self.getType(), self, initTensor, indexingMaps,
-                iteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  SmallVector<Value> indices;
-                  for (auto i = 0; i < selfRank; i++)
-                    indices.push_back(b.create<linalg::IndexOp>(loc, i));
-                  for (auto flipDim : axis) {
-                    indices[flipDim] = b.create<arith::SubIOp>(
-                        loc, dims[flipDim], indices[flipDim]);
-                  }
-                  Value res = b.create<tensor::ExtractOp>(loc, self, indices)
-                                  .getResult();
-                  b.create<linalg::YieldOp>(loc, res);
-                })
-            .getResult(0);
-
+    Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis);
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, self.getType(), flipped);
 
     return success();
   }
 };
@@ -620,3 +620,44 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op,
           .getResult(0);
   return success();
 }
+
+// Flips an input tensor based on the values of axis list.
+Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc,
+                                  Value input, SmallVector<int64_t> axis) {
+  Value c1 = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+  Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
+  auto selfRank = cast<RankedTensorType>(input.getType()).getRank();
+
+  // Only used to calculate flipped values, i.e. those on the flip axes. Other
+  // dims won't be used.
+  SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
+  for (auto flipDim : axis)
+    dims[flipDim] = rewriter.create<arith::SubIOp>(loc, dims[flipDim], c1);
+
+  Value initTensor = createZeroInitTensor(
+      rewriter, loc, getTensorSizes(rewriter, loc, input), elementType);
+
+  SmallVector<utils::IteratorType> iteratorTypes(selfRank,
+                                                 utils::IteratorType::parallel);
+  SmallVector<AffineMap> indexingMaps(
+      2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext()));
+  Value flipped =
+      rewriter
+          .create<linalg::GenericOp>(
+              loc, input.getType(), input, initTensor, indexingMaps,
+              iteratorTypes,
+              [&](OpBuilder &b, Location loc, ValueRange args) {
+                SmallVector<Value> indices;
+                for (auto i = 0; i < selfRank; i++)
+                  indices.push_back(b.create<linalg::IndexOp>(loc, i));
+                for (auto flipDim : axis) {
+                  indices[flipDim] = b.create<arith::SubIOp>(loc, dims[flipDim],
+                                                             indices[flipDim]);
+                }
+                Value res = b.create<tensor::ExtractOp>(loc, input, indices)
+                                .getResult();
+                b.create<linalg::YieldOp>(loc, res);
+              })
+          .getResult(0);
+  return flipped;
+}
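For reference, a minimal standalone sketch (plain C++, not MLIR; the `flip2d` helper and its row-major 2-D layout are assumptions made for illustration, not code from the patch) of the index arithmetic the new flipTensor helper expresses through linalg.generic: along every flipped axis, the output element at index i reads the input element at index size - 1 - i, while the remaining axes pass through unchanged.

#include <cassert>
#include <cstdint>
#include <vector>

// Flip a row-major rows x cols tensor along the axes listed in `axis`
// (0 = rows, 1 = cols), mirroring what the linalg.generic body computes.
static std::vector<int64_t> flip2d(const std::vector<int64_t> &in, int64_t rows,
                                   int64_t cols,
                                   const std::vector<int64_t> &axis) {
  bool flipRows = false, flipCols = false;
  for (int64_t a : axis) {
    if (a == 0)
      flipRows = true;
    else
      flipCols = true;
  }

  std::vector<int64_t> out(in.size());
  for (int64_t r = 0; r < rows; ++r)
    for (int64_t c = 0; c < cols; ++c) {
      int64_t sr = flipRows ? rows - 1 - r : r; // mirrored source row
      int64_t sc = flipCols ? cols - 1 - c : c; // mirrored source column
      out[r * cols + c] = in[sr * cols + sc];
    }
  return out;
}

int main() {
  // 2x3 input [[0, 1, 2], [3, 4, 5]]; flipping axis 1 reverses each row.
  std::vector<int64_t> in = {0, 1, 2, 3, 4, 5};
  assert(flip2d(in, 2, 3, {1}) == std::vector<int64_t>({2, 1, 0, 5, 4, 3}));
  return 0;
}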