mirror of https://github.com/llvm/torch-mlir
Simplify lowering of aten.reflection_pad2d to linalg
The lowering uses a linalg.extract_slice operation to extract the input slices to be used for the padding followed by a linalg.generic operation to reverse the order of elements as required for the reflection. This commit removes the use of the linalg.generic. It uses the capability of the linalg.extract_slice operation to extract a reversed slice by specifying negative strides and offsets that point to the end of the slice in the input tensor. This simplifies the generated IR and makes it potentially more optimization friendly.pull/2772/head
parent
5862854bc8
commit
9f5ddb78c7
|
@ -247,8 +247,7 @@ public:
|
|||
namespace {
|
||||
|
||||
// Lower the aten.reflection.pad_2d operator into a sequence of
|
||||
// tensor.extract_slice, linalg.generic, and tensor_insert_slice
|
||||
// operations.
|
||||
// tensor.extract_slice and tensor_insert_slice operations.
|
||||
|
||||
// To understand the lowering, consider this pytorch example:
|
||||
//
|
||||
|
@ -282,8 +281,8 @@ namespace {
|
|||
// center right: [[2,1]]
|
||||
//
|
||||
// The lowering uses a tensor.extract_slice operation to create each tile,
|
||||
// a linalg.generic for the reflection, and a tensor.insert_slice to
|
||||
// insert the tile in the resulting tensor.
|
||||
// including the reversal of the order of the elements if necessary,
|
||||
// and a tensor.insert_slice to insert the tile in the result tensor.
|
||||
class ConvertAtenReflectionPad2dOp
|
||||
: public OpConversionPattern<AtenReflectionPad2dOp> {
|
||||
public:
|
||||
|
@ -305,20 +304,12 @@ public:
|
|||
return rewriter.create<arith::AddIOp>(loc, x, y);
|
||||
};
|
||||
|
||||
auto createAdds = [&](std::initializer_list<Value> values) {
|
||||
assert(values.size() >= 2);
|
||||
return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
|
||||
createAdd);
|
||||
};
|
||||
|
||||
auto createSub = [&](Value x, Value y) {
|
||||
return rewriter.create<arith::SubIOp>(loc, x, y);
|
||||
};
|
||||
|
||||
auto createSubs = [&](std::initializer_list<Value> values) {
|
||||
assert(values.size() >= 2);
|
||||
return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
|
||||
createSub);
|
||||
auto getIndexConst = [&](int c) {
|
||||
return rewriter.create<arith::ConstantIndexOp>(loc, c);
|
||||
};
|
||||
|
||||
// Enums for specifying the coordinates of a tile. An "h" prefix
|
||||
|
@ -349,7 +340,6 @@ public:
|
|||
};
|
||||
|
||||
Value input = adaptor.getSelf();
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
auto inputType = llvm::cast<RankedTensorType>(input.getType());
|
||||
auto outputType = llvm::cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||
|
@ -372,37 +362,27 @@ public:
|
|||
assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
|
||||
"Bottom padding too large");
|
||||
|
||||
Type indexType = rewriter.getIndexType();
|
||||
Value zero = getConstant(rewriter, loc, 0, indexType);
|
||||
Value one = getConstant(rewriter, loc, 1, indexType);
|
||||
Value zero = getIndexConst(0);
|
||||
Value one = getIndexConst(1);
|
||||
Value two = getIndexConst(2);
|
||||
Value minusOne = getIndexConst(-1);
|
||||
|
||||
Value tileWidth[3];
|
||||
tileWidth[HCENTER] = hDimSize;
|
||||
for (auto h : {LEFT, RIGHT})
|
||||
tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType);
|
||||
tileWidth[h] = getIndexConst(getHPadArgument(h));
|
||||
|
||||
Value tileHeight[3];
|
||||
tileHeight[VCENTER] = vDimSize;
|
||||
for (auto v : {TOP, BOTTOM})
|
||||
tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType);
|
||||
|
||||
// Helper to reflect/reverse the i-th dimension of an affine map
|
||||
// without symbols. This only works if applied on a tensor
|
||||
// for which the corresponding dimension has a statically
|
||||
// known size which is good enough since we only apply
|
||||
// it to reflect the padding slices.
|
||||
auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i,
|
||||
int64_t size) {
|
||||
AffineExpr d = map.getResult(i);
|
||||
return map.replace(d, size - d - 1, numDims, 0);
|
||||
};
|
||||
tileHeight[v] = getIndexConst(getVPadArgument(v));
|
||||
|
||||
// Create output shape and tensor
|
||||
SmallVector<Value> resultShape{inputShape};
|
||||
resultShape[vDim] =
|
||||
createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]});
|
||||
resultShape[hDim] =
|
||||
createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]});
|
||||
resultShape[vDim] = createAdd(createAdd(resultShape[vDim], tileHeight[TOP]),
|
||||
tileHeight[BOTTOM]);
|
||||
resultShape[hDim] = createAdd(createAdd(resultShape[hDim], tileWidth[LEFT]),
|
||||
tileWidth[RIGHT]);
|
||||
|
||||
Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape,
|
||||
inputType.getElementType());
|
||||
|
@ -444,18 +424,29 @@ public:
|
|||
|
||||
// Setup information about the tiles
|
||||
|
||||
// Compute the offsets for extracting the slice from the
|
||||
// input. We need to skip the row or column through which
|
||||
// the tile should be reflected, if any (none for the center tile).
|
||||
// Compute the offsets for extracting the slice from the input. To
|
||||
// reverse the order of the elements in the non-central tiles,
|
||||
// extract the slices with negative strides and start from the
|
||||
// last element of the input that should belong to the slice,
|
||||
// skipping the "axis" element through which the elements are
|
||||
// reflected:
|
||||
//
|
||||
// - The left tile is obtained by extracting elements
|
||||
// tileWidth[LEFT] + 1, ..., 2 in this, i.e. reverse order.
|
||||
//
|
||||
// - The right tile is obtained by extracting elements
|
||||
// hDimSize - 1, ..., hDimSize - tileWidth[RIGHT] - 1 in this,
|
||||
// i.e. reverse order.
|
||||
|
||||
Value extractHOffset[3];
|
||||
extractHOffset[LEFT] = one;
|
||||
extractHOffset[LEFT] = tileWidth[LEFT];
|
||||
extractHOffset[HCENTER] = zero;
|
||||
extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one});
|
||||
extractHOffset[RIGHT] = createSub(hDimSize, two);
|
||||
|
||||
Value extractVOffset[3];
|
||||
extractVOffset[TOP] = one;
|
||||
extractVOffset[TOP] = tileHeight[TOP];
|
||||
extractVOffset[VCENTER] = zero;
|
||||
extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one});
|
||||
extractVOffset[BOTTOM] = createSub(vDimSize, two);
|
||||
|
||||
// Compute the horizontal and vertical offsets for inserting
|
||||
// the tiles in the resultTensor.
|
||||
|
@ -469,47 +460,39 @@ public:
|
|||
insertVOffset[VCENTER] = tileHeight[TOP];
|
||||
insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]);
|
||||
|
||||
auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; };
|
||||
auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; };
|
||||
// Define the strides for the tensor.extract_slice operations.
|
||||
// Using a negative stride for a dimension reverses the order
|
||||
// of the extracted elements as necessary for the reflection.
|
||||
Value extractHStride[3];
|
||||
extractHStride[LEFT] = minusOne;
|
||||
extractHStride[HCENTER] = one;
|
||||
extractHStride[RIGHT] = minusOne;
|
||||
|
||||
Value extractVStride[3];
|
||||
extractVStride[TOP] = minusOne;
|
||||
extractVStride[VCENTER] = one;
|
||||
extractVStride[BOTTOM] = minusOne;
|
||||
|
||||
SmallVector<utils::IteratorType> iteratorTypes{
|
||||
numDims, utils::IteratorType::parallel};
|
||||
auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context);
|
||||
SmallVector<Value> allOneStrides(numDims, one);
|
||||
|
||||
auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) {
|
||||
// Create the tile by extracting a slice from the input tenor.
|
||||
SmallVector<Value> extractShape{inputShape};
|
||||
extractShape[hDim] = tileWidth[horizontalPos];
|
||||
extractShape[vDim] = tileHeight[verticalPos];
|
||||
SmallVector<Value> tileShape{inputShape};
|
||||
tileShape[hDim] = tileWidth[horizontalPos];
|
||||
tileShape[vDim] = tileHeight[verticalPos];
|
||||
|
||||
// Create the tile by extracting a slice from the input tenor.
|
||||
SmallVector<Value> extractOffsets(numDims, zero);
|
||||
extractOffsets[hDim] = extractHOffset[horizontalPos];
|
||||
extractOffsets[vDim] = extractVOffset[verticalPos];
|
||||
|
||||
SmallVector<Value> extractStrides(numDims, one);
|
||||
extractStrides[hDim] = extractHStride[horizontalPos];
|
||||
extractStrides[vDim] = extractVStride[verticalPos];
|
||||
|
||||
Value tile = rewriter.create<tensor::ExtractSliceOp>(
|
||||
loc, input, extractOffsets, extractShape, allOneStrides);
|
||||
|
||||
// Reverse the tile along the horizontal, vertical, or both
|
||||
// dimensions.
|
||||
auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
|
||||
if (shouldHReflect(horizontalPos)) {
|
||||
inputMap =
|
||||
reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos));
|
||||
}
|
||||
if (shouldVReflect(verticalPos)) {
|
||||
inputMap =
|
||||
reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos));
|
||||
}
|
||||
|
||||
tile = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
|
||||
tile, ArrayRef({inputMap, idMap}), iteratorTypes,
|
||||
[](OpBuilder &b, Location nestedLoc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(nestedLoc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
loc, input, extractOffsets, tileShape, extractStrides);
|
||||
|
||||
// Insert the tile in the resultTensor.
|
||||
SmallVector<Value> insertOffsets(numDims, zero);
|
||||
|
@ -517,7 +500,7 @@ public:
|
|||
insertOffsets[vDim] = insertVOffset[verticalPos];
|
||||
|
||||
resultTensor = rewriter.create<tensor::InsertSliceOp>(
|
||||
loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
|
||||
loc, tile, resultTensor, insertOffsets, tileShape, allOneStrides);
|
||||
};
|
||||
|
||||
for (auto v : {TOP, BOTTOM, VCENTER})
|
||||
|
|
Loading…
Reference in New Issue