From 9f5ddb78c71dce574338c87bab562a9118344b5e Mon Sep 17 00:00:00 2001 From: Frederik Harwath Date: Fri, 19 Jan 2024 05:25:43 -0800 Subject: [PATCH] Simplify lowering of aten.reflection_pad2d to linalg The lowering uses a tensor.extract_slice operation to extract the input slices to be used for the padding followed by a linalg.generic operation to reverse the order of elements as required for the reflection. This commit removes the use of the linalg.generic. It uses the capability of the tensor.extract_slice operation to extract a reversed slice by specifying negative strides and offsets that point to the end of the slice in the input tensor. This simplifies the generated IR and makes it potentially more optimization-friendly. --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 127 ++++++++---------- 1 file changed, 55 insertions(+), 72 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 49f5f0ec3..79ed79025 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -247,8 +247,7 @@ public: namespace { // Lower the aten.reflection.pad_2d operator into a sequence of -// tensor.extract_slice, linalg.generic, and tensor_insert_slice -// operations. +// tensor.extract_slice and tensor_insert_slice operations. // To understand the lowering, consider this pytorch example: // @@ -282,8 +281,8 @@ namespace { // center right: [[2,1]] // // The lowering uses a tensor.extract_slice operation to create each tile, -// a linalg.generic for the reflection, and a tensor.insert_slice to -// insert the tile in the resulting tensor. +// including the reversal of the order of the elements if necessary, +// and a tensor.insert_slice to insert the tile in the result tensor. 
class ConvertAtenReflectionPad2dOp : public OpConversionPattern { public: @@ -305,20 +304,12 @@ public: return rewriter.create(loc, x, y); }; - auto createAdds = [&](std::initializer_list values) { - assert(values.size() >= 2); - return std::accumulate(values.begin() + 1, values.end(), data(values)[0], - createAdd); - }; - auto createSub = [&](Value x, Value y) { return rewriter.create(loc, x, y); }; - auto createSubs = [&](std::initializer_list values) { - assert(values.size() >= 2); - return std::accumulate(values.begin() + 1, values.end(), data(values)[0], - createSub); + auto getIndexConst = [&](int c) { + return rewriter.create(loc, c); }; // Enums for specifying the coordinates of a tile. An "h" prefix @@ -349,7 +340,6 @@ public: }; Value input = adaptor.getSelf(); - MLIRContext *context = rewriter.getContext(); auto inputType = llvm::cast(input.getType()); auto outputType = llvm::cast( getTypeConverter()->convertType(op->getResult(0).getType())); @@ -372,37 +362,27 @@ public: assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] && "Bottom padding too large"); - Type indexType = rewriter.getIndexType(); - Value zero = getConstant(rewriter, loc, 0, indexType); - Value one = getConstant(rewriter, loc, 1, indexType); + Value zero = getIndexConst(0); + Value one = getIndexConst(1); + Value two = getIndexConst(2); + Value minusOne = getIndexConst(-1); Value tileWidth[3]; tileWidth[HCENTER] = hDimSize; for (auto h : {LEFT, RIGHT}) - tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType); + tileWidth[h] = getIndexConst(getHPadArgument(h)); Value tileHeight[3]; tileHeight[VCENTER] = vDimSize; for (auto v : {TOP, BOTTOM}) - tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType); - - // Helper to reflect/reverse the i-th dimension of an affine map - // without symbols. 
This only works if applied on a tensor - // for which the corresponding dimension has a statically - // known size which is good enough since we only apply - // it to reflect the padding slices. - auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, - int64_t size) { - AffineExpr d = map.getResult(i); - return map.replace(d, size - d - 1, numDims, 0); - }; + tileHeight[v] = getIndexConst(getVPadArgument(v)); // Create output shape and tensor SmallVector resultShape{inputShape}; - resultShape[vDim] = - createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]}); - resultShape[hDim] = - createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]}); + resultShape[vDim] = createAdd(createAdd(resultShape[vDim], tileHeight[TOP]), + tileHeight[BOTTOM]); + resultShape[hDim] = createAdd(createAdd(resultShape[hDim], tileWidth[LEFT]), + tileWidth[RIGHT]); Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, inputType.getElementType()); @@ -444,18 +424,29 @@ public: // Setup information about the tiles - // Compute the offsets for extracting the slice from the - // input. We need to skip the row or column through which - // the tile should be reflected, if any (none for the center tile). + // Compute the offsets for extracting the slice from the input. To + // reverse the order of the elements in the non-central tiles, + // extract the slices with negative strides and start from the + // last element of the input that should belong to the slice, + // skipping the "axis" element through which the elements are + // reflected: + // + // - The left tile is obtained by extracting elements + // tileWidth[LEFT] + 1, ..., 2 in this, i.e. reverse order. + // + // - The right tile is obtained by extracting elements + // hDimSize - 1, ..., hDimSize - tileWidth[RIGHT] - 1 in this, + // i.e. reverse order. 
+ Value extractHOffset[3]; - extractHOffset[LEFT] = one; + extractHOffset[LEFT] = tileWidth[LEFT]; extractHOffset[HCENTER] = zero; - extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one}); + extractHOffset[RIGHT] = createSub(hDimSize, two); Value extractVOffset[3]; - extractVOffset[TOP] = one; + extractVOffset[TOP] = tileHeight[TOP]; extractVOffset[VCENTER] = zero; - extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one}); + extractVOffset[BOTTOM] = createSub(vDimSize, two); // Compute the horizontal and vertical offsets for inserting // the tiles in the resultTensor. @@ -469,47 +460,39 @@ public: insertVOffset[VCENTER] = tileHeight[TOP]; insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]); - auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; }; - auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; }; + // Define the strides for the tensor.extract_slice operations. + // Using a negative stride for a dimension reverses the order + // of the extracted elements as necessary for the reflection. + Value extractHStride[3]; + extractHStride[LEFT] = minusOne; + extractHStride[HCENTER] = one; + extractHStride[RIGHT] = minusOne; + + Value extractVStride[3]; + extractVStride[TOP] = minusOne; + extractVStride[VCENTER] = one; + extractVStride[BOTTOM] = minusOne; SmallVector iteratorTypes{ numDims, utils::IteratorType::parallel}; - auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context); SmallVector allOneStrides(numDims, one); auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) { - // Create the tile by extracting a slice from the input tenor. - SmallVector extractShape{inputShape}; - extractShape[hDim] = tileWidth[horizontalPos]; - extractShape[vDim] = tileHeight[verticalPos]; + SmallVector tileShape{inputShape}; + tileShape[hDim] = tileWidth[horizontalPos]; + tileShape[vDim] = tileHeight[verticalPos]; + // Create the tile by extracting a slice from the input tenor. 
SmallVector extractOffsets(numDims, zero); extractOffsets[hDim] = extractHOffset[horizontalPos]; extractOffsets[vDim] = extractVOffset[verticalPos]; + SmallVector extractStrides(numDims, one); + extractStrides[hDim] = extractHStride[horizontalPos]; + extractStrides[vDim] = extractVStride[verticalPos]; + Value tile = rewriter.create( - loc, input, extractOffsets, extractShape, allOneStrides); - - // Reverse the tile along the horizontal, vertical, or both - // dimensions. - auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); - if (shouldHReflect(horizontalPos)) { - inputMap = - reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos)); - } - if (shouldVReflect(verticalPos)) { - inputMap = - reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos)); - } - - tile = rewriter - .create( - loc, llvm::cast(tile.getType()), tile, - tile, ArrayRef({inputMap, idMap}), iteratorTypes, - [](OpBuilder &b, Location nestedLoc, ValueRange args) { - b.create(nestedLoc, args[0]); - }) - .getResult(0); + loc, input, extractOffsets, tileShape, extractStrides); // Insert the tile in the resultTensor. SmallVector insertOffsets(numDims, zero); @@ -517,7 +500,7 @@ public: insertOffsets[vDim] = insertVOffset[verticalPos]; resultTensor = rewriter.create( - loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); + loc, tile, resultTensor, insertOffsets, tileShape, allOneStrides); }; for (auto v : {TOP, BOTTOM, VCENTER})