From f6721e599961a36d67236fce9f58cdd719c9cef4 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 8 Oct 2024 10:34:27 +0530 Subject: [PATCH] [MLIR][TORCH] Add support for negative step in aten.slice.Tensor op (#3763) This commit adds support for negative step values in aten.slice.Tensor op. Although PyTorch itself does not allow a negative step value for the slice op, the Onnx.Slice op does support negative step values and eventually lowers to torch.aten.slice.Tensor op. Hence, support is added for handling such values during the Torch->Linalg lowering of aten.slice.Tensor op. Signed-Off By: Vivek Khandelwal --- .../Conversion/TorchToLinalg/Utils.h | 4 ++ lib/Conversion/TorchToLinalg/DataMovement.cpp | 49 +++++++++++++++---- lib/Conversion/TorchToLinalg/Linear.cpp | 39 +-------------- lib/Conversion/TorchToLinalg/Utils.cpp | 41 ++++++++++++++++ 4 files changed, 86 insertions(+), 47 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h index 14e920222..b59d183b4 100644 --- a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h +++ b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h @@ -101,6 +101,10 @@ LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter, Location loc, SmallVector dimensions, Value input, Value &result); +// Flips an input tensor based on the values of axis list. 
+Value flipTensor(PatternRewriter &rewriter, Location loc, Value input, + SmallVector axis); + } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 5542e0fc6..ac1707ec2 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -40,6 +40,7 @@ static int64_t productReduce(ArrayRef a) { template LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, + int64_t &dim, SmallVector &resultShape, SmallVector &offsets, SmallVector &strides) { @@ -51,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, Value one = rewriter.create(loc, 1); Value negone = rewriter.create(loc, -1); - int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("unimplemented: dim is not constant"); @@ -1857,14 +1857,46 @@ public: RankedTensorType resultType = cast( typeConverter->convertType(op->getResult(0).getType())); - SmallVector resultShape; - SmallVector offsets; - SmallVector strides; + SmallVector resultShape, offsets, strides; + int64_t dim; if (failed(prepareArgumentsForSlicingOp( - op, adaptor, rewriter, resultShape, offsets, strides))) { + op, adaptor, rewriter, dim, resultShape, offsets, strides))) { return failure(); } + + // If stride is negative, then flip the input tensor corresponding to that + // dim, update the stride for flipped tensor by multiplying it by -1, and + // update the offset as follows: + // flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride) + // + // For example: + // Input = [0, 1, 2, 3, 4, 5] + // stride = [-2], result_shape = [2], offset = [3] + // Result = [3, 1] + // After flipping: + // Input = [5, 4, 3, 2, 1, 0] + // stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2] + // Result = [3, 1] + + Value flippedInput = 
torch_to_linalg::flipTensor(rewriter, loc, input, + SmallVector{dim}); + Value cstDim = rewriter.create(loc, dim); + Value zero = rewriter.create(loc, 0); + Value isNegativeStride = rewriter.create( + loc, arith::CmpIPredicate::slt, strides[dim], zero); + strides[dim] = rewriter.create(loc, strides[dim]); + Value resShapeMulStride = + rewriter.create(loc, resultShape[dim], strides[dim]); + Value inputDim = rewriter.create(loc, input, cstDim); + Value flippedOffset = + rewriter.create(loc, inputDim, resShapeMulStride); + offsets[dim] = rewriter.create( + loc, isNegativeStride, flippedOffset, offsets[dim]); + + input = rewriter.create(loc, isNegativeStride, + flippedInput, input); + SmallVector dynShape(resultType.getRank(), ShapedType::kDynamic); auto sliceType = RankedTensorType::get( dynShape, resultType.getElementType(), resultType.getEncoding()); @@ -2095,12 +2127,11 @@ public: RankedTensorType resultType = cast( typeConverter->convertType(op->getResult(0).getType())); - SmallVector resultShape; - SmallVector offsets; - SmallVector strides; + SmallVector resultShape, offsets, strides; + int64_t dim; if (failed(prepareArgumentsForSlicingOp( - op, adaptor, rewriter, resultShape, offsets, strides))) { + op, adaptor, rewriter, dim, resultShape, offsets, strides))) { return failure(); } diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 52765411b..fc910fa9d 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -222,14 +222,9 @@ public: ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - MLIRContext *context = op.getContext(); Value self = adaptor.getSelf(); auto selfRank = cast(adaptor.getSelf().getType()).getRank(); - Type elementType = - cast(adaptor.getSelf().getType()).getElementType(); - Value c1 = - rewriter.create(loc, rewriter.getIndexAttr(1)); SmallVector axis; if (!matchPattern(adaptor.getDims(), 
m_TorchListOfConstantInts(axis))) @@ -242,40 +237,8 @@ public: } } - // Only used to calculate flipped values, i.e. those on the flip axes. Other - // dims won't be used. - SmallVector dims = getTensorSizes(rewriter, loc, self); - for (auto flipDim : axis) - dims[flipDim] = rewriter.create(loc, dims[flipDim], c1); - - Value initTensor = createZeroInitTensor( - rewriter, loc, getTensorSizes(rewriter, loc, self), elementType); - - SmallVector iteratorTypes( - selfRank, utils::IteratorType::parallel); - SmallVector indexingMaps( - 2, AffineMap::getMultiDimIdentityMap(selfRank, context)); - Value flipped = - rewriter - .create( - loc, self.getType(), self, initTensor, indexingMaps, - iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector indices; - for (auto i = 0; i < selfRank; i++) - indices.push_back(b.create(loc, i)); - for (auto flipDim : axis) { - indices[flipDim] = b.create( - loc, dims[flipDim], indices[flipDim]); - } - Value res = b.create(loc, self, indices) - .getResult(); - b.create(loc, res); - }) - .getResult(0); - + Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis); rewriter.replaceOpWithNewOp(op, self.getType(), flipped); - return success(); } }; diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 6ef947d89..18e8fb449 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -620,3 +620,44 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op, .getResult(0); return success(); } + +// Flips an input tensor based on the values of axis list. +Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc, + Value input, SmallVector axis) { + Value c1 = rewriter.create(loc, rewriter.getIndexAttr(1)); + Type elementType = cast(input.getType()).getElementType(); + auto selfRank = cast(input.getType()).getRank(); + + // Only used to calculate flipped values, i.e. those on the flip axes. 
Other + // dims won't be used. + SmallVector dims = getTensorSizes(rewriter, loc, input); + for (auto flipDim : axis) + dims[flipDim] = rewriter.create(loc, dims[flipDim], c1); + + Value initTensor = createZeroInitTensor( + rewriter, loc, getTensorSizes(rewriter, loc, input), elementType); + + SmallVector iteratorTypes(selfRank, + utils::IteratorType::parallel); + SmallVector indexingMaps( + 2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext())); + Value flipped = + rewriter + .create( + loc, input.getType(), input, initTensor, indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indices; + for (auto i = 0; i < selfRank; i++) + indices.push_back(b.create(loc, i)); + for (auto flipDim : axis) { + indices[flipDim] = b.create(loc, dims[flipDim], + indices[flipDim]); + } + Value res = b.create(loc, input, indices) + .getResult(); + b.create(loc, res); + }) + .getResult(0); + return flipped; +}