Simplify lowering of aten.reflection_pad2d to linalg

The lowering uses a tensor.extract_slice operation to extract the
input slices to be used for the padding followed by a linalg.generic
operation to reverse the order of elements as required for the
reflection.

This commit removes the use of the linalg.generic. It uses the
capability of the tensor.extract_slice operation to extract a reversed
slice by specifying negative strides and offsets that point to the end
of the slice in the input tensor. This simplifies the generated IR
and makes it potentially more optimization friendly.
pull/2772/head
Frederik Harwath 2024-01-19 05:25:43 -08:00
parent 5862854bc8
commit 9f5ddb78c7
1 changed file with 55 additions and 72 deletions

View File

@ -247,8 +247,7 @@ public:
namespace {
// Lower the aten.reflection.pad_2d operator into a sequence of
// tensor.extract_slice, linalg.generic, and tensor_insert_slice
// operations.
// tensor.extract_slice and tensor.insert_slice operations.
// To understand the lowering, consider this pytorch example:
//
@ -282,8 +281,8 @@ namespace {
// center right: [[2,1]]
//
// The lowering uses a tensor.extract_slice operation to create each tile,
// a linalg.generic for the reflection, and a tensor.insert_slice to
// insert the tile in the resulting tensor.
// including the reversal of the order of the elements if necessary,
// and a tensor.insert_slice to insert the tile in the result tensor.
class ConvertAtenReflectionPad2dOp
: public OpConversionPattern<AtenReflectionPad2dOp> {
public:
@ -305,20 +304,12 @@ public:
return rewriter.create<arith::AddIOp>(loc, x, y);
};
auto createAdds = [&](std::initializer_list<Value> values) {
assert(values.size() >= 2);
return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
createAdd);
};
auto createSub = [&](Value x, Value y) {
return rewriter.create<arith::SubIOp>(loc, x, y);
};
auto createSubs = [&](std::initializer_list<Value> values) {
assert(values.size() >= 2);
return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
createSub);
auto getIndexConst = [&](int c) {
return rewriter.create<arith::ConstantIndexOp>(loc, c);
};
// Enums for specifying the coordinates of a tile. An "h" prefix
@ -349,7 +340,6 @@ public:
};
Value input = adaptor.getSelf();
MLIRContext *context = rewriter.getContext();
auto inputType = llvm::cast<RankedTensorType>(input.getType());
auto outputType = llvm::cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
@ -372,37 +362,27 @@ public:
assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
"Bottom padding too large");
Type indexType = rewriter.getIndexType();
Value zero = getConstant(rewriter, loc, 0, indexType);
Value one = getConstant(rewriter, loc, 1, indexType);
Value zero = getIndexConst(0);
Value one = getIndexConst(1);
Value two = getIndexConst(2);
Value minusOne = getIndexConst(-1);
Value tileWidth[3];
tileWidth[HCENTER] = hDimSize;
for (auto h : {LEFT, RIGHT})
tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType);
tileWidth[h] = getIndexConst(getHPadArgument(h));
Value tileHeight[3];
tileHeight[VCENTER] = vDimSize;
for (auto v : {TOP, BOTTOM})
tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType);
// Helper to reflect/reverse the i-th dimension of an affine map
// without symbols. This only works if applied on a tensor
// for which the corresponding dimension has a statically
// known size which is good enough since we only apply
// it to reflect the padding slices.
auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i,
int64_t size) {
AffineExpr d = map.getResult(i);
return map.replace(d, size - d - 1, numDims, 0);
};
tileHeight[v] = getIndexConst(getVPadArgument(v));
// Create output shape and tensor
SmallVector<Value> resultShape{inputShape};
resultShape[vDim] =
createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]});
resultShape[hDim] =
createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]});
resultShape[vDim] = createAdd(createAdd(resultShape[vDim], tileHeight[TOP]),
tileHeight[BOTTOM]);
resultShape[hDim] = createAdd(createAdd(resultShape[hDim], tileWidth[LEFT]),
tileWidth[RIGHT]);
Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape,
inputType.getElementType());
@ -444,18 +424,29 @@ public:
// Setup information about the tiles
// Compute the offsets for extracting the slice from the
// input. We need to skip the row or column through which
// the tile should be reflected, if any (none for the center tile).
// Compute the offsets for extracting the slice from the input. To
// reverse the order of the elements in the non-central tiles,
// extract the slices with negative strides and start from the
// last element of the input that should belong to the slice,
// skipping the "axis" element through which the elements are
// reflected:
//
// - The left tile is obtained by extracting elements
// tileWidth[LEFT] + 1, ..., 2 in this, i.e. reverse order.
//
// - The right tile is obtained by extracting elements
// hDimSize - 1, ..., hDimSize - tileWidth[RIGHT] in this,
// i.e. reverse order.
Value extractHOffset[3];
extractHOffset[LEFT] = one;
extractHOffset[LEFT] = tileWidth[LEFT];
extractHOffset[HCENTER] = zero;
extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one});
extractHOffset[RIGHT] = createSub(hDimSize, two);
Value extractVOffset[3];
extractVOffset[TOP] = one;
extractVOffset[TOP] = tileHeight[TOP];
extractVOffset[VCENTER] = zero;
extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one});
extractVOffset[BOTTOM] = createSub(vDimSize, two);
// Compute the horizontal and vertical offsets for inserting
// the tiles in the resultTensor.
@ -469,47 +460,39 @@ public:
insertVOffset[VCENTER] = tileHeight[TOP];
insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]);
auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; };
auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; };
// Define the strides for the tensor.extract_slice operations.
// Using a negative stride for a dimension reverses the order
// of the extracted elements as necessary for the reflection.
Value extractHStride[3];
extractHStride[LEFT] = minusOne;
extractHStride[HCENTER] = one;
extractHStride[RIGHT] = minusOne;
Value extractVStride[3];
extractVStride[TOP] = minusOne;
extractVStride[VCENTER] = one;
extractVStride[BOTTOM] = minusOne;
SmallVector<utils::IteratorType> iteratorTypes{
numDims, utils::IteratorType::parallel};
auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context);
SmallVector<Value> allOneStrides(numDims, one);
auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) {
// Create the tile by extracting a slice from the input tensor.
SmallVector<Value> extractShape{inputShape};
extractShape[hDim] = tileWidth[horizontalPos];
extractShape[vDim] = tileHeight[verticalPos];
SmallVector<Value> tileShape{inputShape};
tileShape[hDim] = tileWidth[horizontalPos];
tileShape[vDim] = tileHeight[verticalPos];
// Create the tile by extracting a slice from the input tensor.
SmallVector<Value> extractOffsets(numDims, zero);
extractOffsets[hDim] = extractHOffset[horizontalPos];
extractOffsets[vDim] = extractVOffset[verticalPos];
SmallVector<Value> extractStrides(numDims, one);
extractStrides[hDim] = extractHStride[horizontalPos];
extractStrides[vDim] = extractVStride[verticalPos];
Value tile = rewriter.create<tensor::ExtractSliceOp>(
loc, input, extractOffsets, extractShape, allOneStrides);
// Reverse the tile along the horizontal, vertical, or both
// dimensions.
auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
if (shouldHReflect(horizontalPos)) {
inputMap =
reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos));
}
if (shouldVReflect(verticalPos)) {
inputMap =
reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos));
}
tile = rewriter
.create<linalg::GenericOp>(
loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
tile, ArrayRef({inputMap, idMap}), iteratorTypes,
[](OpBuilder &b, Location nestedLoc, ValueRange args) {
b.create<linalg::YieldOp>(nestedLoc, args[0]);
})
.getResult(0);
loc, input, extractOffsets, tileShape, extractStrides);
// Insert the tile in the resultTensor.
SmallVector<Value> insertOffsets(numDims, zero);
@ -517,7 +500,7 @@ public:
insertOffsets[vDim] = insertVOffset[verticalPos];
resultTensor = rewriter.create<tensor::InsertSliceOp>(
loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
loc, tile, resultTensor, insertOffsets, tileShape, allOneStrides);
};
for (auto v : {TOP, BOTTOM, VCENTER})