Review changes

Changes suggested by @qedawkins.
pre_fixup_20240110
Frederik Harwath 2024-01-10 00:53:52 -08:00 committed by Frederik Harwath
parent 4fb58002ab
commit 3f603470a9
2 changed files with 20 additions and 14 deletions

View File

@ -348,10 +348,6 @@ public:
return true;
};
Type indexType = rewriter.getIndexType();
Value zero = getConstant(rewriter, loc, 0, indexType);
Value one = getConstant(rewriter, loc, 1, indexType);
Value input = adaptor.getSelf();
MLIRContext *context = rewriter.getContext();
auto inputType = llvm::cast<RankedTensorType>(input.getType());
@ -367,6 +363,15 @@ public:
Value hDimSize = inputShape[hDim];
Value vDimSize = inputShape[vDim];
assert(getHPadArgument(LEFT) < hDimSize && "Left padding too large");
assert(getHPadArgument(RIGHT) < hDimSize && "Right padding too large");
assert(getVPadArgument(TOP) < vDimSize && "Top padding too large");
assert(getVPadArgument(BOTTOM) < vDimSize && "Bottom padding too large");
Type indexType = rewriter.getIndexType();
Value zero = getConstant(rewriter, loc, 0, indexType);
Value one = getConstant(rewriter, loc, 1, indexType);
Value tileWidth[3];
tileWidth[HCENTER] = hDimSize;
for (auto h : {LEFT, RIGHT})
@ -415,9 +420,9 @@ public:
// x_n,1 x_n,2 ... x_n,m
//
// The padding tile consists of the columns 2, ..., m + 1
// of the input in reverse order. The first columns
// gets skipped because this this is the column trough
// which the reflection happens.
// of the input in reverse order. The first column gets
// skipped because this is the column through which the
// reflection happens.
//
// x_1,m x_1,m-1 ... x_1,2
// x_2,m x_1,m-1 ... x_2,2
@ -428,9 +433,9 @@ public:
//
// The tile will be inserted to the left of the copy of the input tensor
// in the output tensor, i.e. with horizontal offset 0.
// If amount of top padding determines the vertical offset.
// The top padding determines the vertical offset.
// Tiles tiles on the diagonal (e.g. (TOP, LEFT)) are reflected through
// Tiles on the diagonal (e.g. (TOP, LEFT)) are reflected through
// two sides, i.e. their columns and rows must be reversed.
// Setup information about the tiles
@ -484,12 +489,14 @@ public:
// Reverse the tile along the horizontal, vertical, or both
// dimensions
auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
if (shouldHReflect(horizontalPos))
if (shouldHReflect(horizontalPos)) {
inputMap =
reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos));
if (shouldVReflect(verticalPos))
}
if (shouldVReflect(verticalPos)) {
inputMap =
reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos));
}
tile = rewriter
.create<linalg::GenericOp>(

View File

@ -1297,6 +1297,8 @@ def atenreflection_pad2d〡shape(self: List[int], padding: List[int]) -> List
assert len(self) >= 2
vdim = self[-2]
hdim = self[-1]
assert len(padding) == 4, 'padding size expected to be 4'
padding_left = padding[0]
padding_right = padding[1]
padding_top = padding[2]
@ -1860,9 +1862,6 @@ def atenreflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: L
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[5,4,3,2,1])])
def atenreflection_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
assert len(padding) == 4, 'padding size expected to be 4'
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))