mirror of https://github.com/llvm/torch-mlir
parent 4fb58002ab
commit 3f603470a9
@@ -348,10 +348,6 @@ public:
     return true;
   };
 
-    Type indexType = rewriter.getIndexType();
-    Value zero = getConstant(rewriter, loc, 0, indexType);
-    Value one = getConstant(rewriter, loc, 1, indexType);
-
     Value input = adaptor.getSelf();
     MLIRContext *context = rewriter.getContext();
     auto inputType = llvm::cast<RankedTensorType>(input.getType());
@@ -367,6 +363,15 @@ public:
     Value hDimSize = inputShape[hDim];
     Value vDimSize = inputShape[vDim];
 
+    assert(getHPadArgument(LEFT) < hDimSize && "Left padding too large");
+    assert(getHPadArgument(RIGHT) < hDimSize && "Right padding too large");
+    assert(getVPadArgument(TOP) < vDimSize && "Top padding too large");
+    assert(getVPadArgument(BOTTOM) < vDimSize && "Bottom padding too large");
+
+    Type indexType = rewriter.getIndexType();
+    Value zero = getConstant(rewriter, loc, 0, indexType);
+    Value one = getConstant(rewriter, loc, 1, indexType);
+
     Value tileWidth[3];
     tileWidth[HCENTER] = hDimSize;
     for (auto h : {LEFT, RIGHT})
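The asserts added here encode the precondition PyTorch itself places on reflection padding: each pad amount must be strictly smaller than the dimension it reflects, because the reflection axis itself is never copied. A quick eager-mode illustration (my own sketch, not part of this commit):

    import torch
    import torch.nn.functional as F

    x = torch.arange(12, dtype=torch.float32).reshape(1, 3, 4)  # H=3, W=4

    # Valid: every pad amount is strictly smaller than its dimension.
    y = F.pad(x, (3, 3, 2, 2), mode="reflect")  # (left, right, top, bottom)
    print(y.shape)  # torch.Size([1, 7, 10])

    # Invalid: left padding (4) is not < W (4). PyTorch raises here, so the
    # lowering may safely assert padding < dim size.
    try:
        F.pad(x, (4, 0, 0, 0), mode="reflect")
    except RuntimeError as err:
        print("rejected:", err)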
@@ -415,9 +420,9 @@ public:
     // x_n,1 x_n,2 ... x_n,m
     //
     // The padding tile consists of the columns 2, ..., m + 1
-    // of the input in reverse order. The first columns
-    // gets skipped because this this is the column trough
-    // which the reflection happens.
+    // of the input in reverse order. The first column gets
+    // skipped because this is the column through which the
+    // reflection happens.
     //
     // x_1,m x_1,m-1 ... x_1,2
     // x_2,m x_1,m-1 ... x_2,2
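The reversed-column construction this comment describes is easy to check in NumPy (illustration only; the variable names are mine, not the lowering's):

    import numpy as np

    x = np.arange(1, 13).reshape(3, 4)   # x[i, j] is x_{i+1, j+1}; m = 4
    pad_left = 2

    # 1-based columns 2 .. pad_left + 1 in reverse order; column 1 is the
    # reflection axis and is skipped.
    left_tile = x[:, 1:pad_left + 1][:, ::-1]
    print(left_tile)
    # [[ 3  2]
    #  [ 7  6]
    #  [11 10]]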
@@ -428,9 +433,9 @@ public:
     //
     // The tile will be inserted to the left of the copy of the input tensor
     // in the output tensor, i.e. with horizontal offset 0.
-    // If amount of top padding determines the vertical offset.
+    // The top padding determines the vertical offset.
 
-    // Tiles tiles on the diagonal (e.g. (TOP, LEFT)) are reflected through
+    // Tiles on the diagonal (e.g. (TOP, LEFT)) are reflected through
     // two sides, i.e. their columns and rows must be reversed.
 
     // Setup information about the tiles
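Putting the tile bookkeeping together: the side tiles are reversed columns of the input, the top/bottom tiles are reversed rows of the already-widened center band, and that composition is exactly what makes the diagonal (corner) tiles doubly reflected. A NumPy sketch of the same decomposition (my paraphrase, not the lowering's code; assumes each pad < dim size, as the asserts above require):

    import numpy as np

    def reflect_pad2d(x, left, right, top, bottom):
        h, w = x.shape
        lt = x[:, 1:left + 1][:, ::-1]               # LEFT: reversed columns
        rt = x[:, w - right - 1:w - 1][:, ::-1]      # RIGHT: reversed columns
        mid = np.hstack([lt, x, rt])                 # center band, full width
        tp = mid[1:top + 1, :][::-1, :]              # TOP: reversed rows
        bt = mid[h - bottom - 1:h - 1, :][::-1, :]   # BOTTOM: reversed rows
        return np.vstack([tp, mid, bt])

    x = np.arange(1, 13).reshape(3, 4)
    print(reflect_pad2d(x, 1, 1, 1, 1))  # matches torch's 'reflect' mode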
@@ -484,12 +489,14 @@ public:
       // Reverse the tile along the horizontal, vertical, or both
       // dimensions
       auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
-      if (shouldHReflect(horizontalPos))
+      if (shouldHReflect(horizontalPos)) {
         inputMap =
             reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos));
-      if (shouldVReflect(verticalPos))
+      }
+      if (shouldVReflect(verticalPos)) {
         inputMap =
             reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos));
+      }
 
       tile = rewriter
                  .create<linalg::GenericOp>(
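For intuition about what reflecting a dimension of the identity indexing map does: tile element j of a side tile reads input element pad - j, so the tile walks the input backwards while skipping index 0, the reflection axis. The exact affine expression is built inside reflectDim, which this hunk does not show; the following is only my paraphrase:

    # Hypothetical index rewrite d -> pad - d for a LEFT tile of width `pad`.
    pad = 2
    print([pad - j for j in range(pad)])  # [2, 1]: input columns read in reverse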
@@ -1297,6 +1297,8 @@ def aten〇reflection_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]:
     assert len(self) >= 2
     vdim = self[-2]
     hdim = self[-1]
+
+    assert len(padding) == 4, 'padding size expected to be 4'
     padding_left = padding[0]
     padding_right = padding[1]
     padding_top = padding[2]
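The hunk shows only the head of the shape function; the arithmetic it leads up to is the standard reflection_pad2d result shape, sketched here for reference (the tail is outside the hunk, so this is an assumption, not the file's code):

    def reflection_pad2d_shape(self, padding):
        assert len(self) >= 2
        vdim, hdim = self[-2], self[-1]
        assert len(padding) == 4, 'padding size expected to be 4'
        left, right, top, bottom = padding
        # Each padded dimension grows by the two pad amounts on its sides.
        return self[:-2] + [vdim + top + bottom, hdim + left + right]

    print(reflection_pad2d_shape([2, 3, 4], [1, 2, 1, 0]))  # [2, 4, 7]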
@@ -1860,9 +1862,6 @@ def aten〇reflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
                        ErrorInvocation(TensorOfShape(2, 3, 4), padding=[5,4,3,2,1])])
 def aten〇reflection_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
-
-    assert len(padding) == 4, 'padding size expected to be 4'
-
     return self_dtype
 
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))