mirror of https://github.com/llvm/torch-mlir
Implement aten.reflection_pad2d lowering to linalg
parent aee1fca251
commit 345dfd5903
@@ -7893,6 +7893,30 @@ def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [
  }];
}

def Torch_AtenReflectionPad2dOp : Torch_Op<"aten.reflection_pad2d", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchListOfTorchIntType:$padding
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenReflectionPad2dOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenReflectionPad2dOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}

def Torch_AtenPadOp : Torch_Op<"aten.pad", [
    AllowsTypeRefinement,
    HasValueSemantics,
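Aside (not part of the patch): the padding operand follows PyTorch's (left, right, top, bottom) order, so each output dimension grows by the sum of its two pad amounts. A quick Python check:

import torch

t = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]])
out = torch.ops.aten.reflection_pad2d(t, (1, 2, 1, 2))  # (left, right, top, bottom)
assert out.shape == (1, 3 + 1 + 2, 3 + 1 + 2)  # height and width both grow to 6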
@@ -244,6 +244,283 @@ public:
};
}

namespace {

// Lower the aten.reflection_pad2d operator into a sequence of
// tensor.extract_slice, linalg.generic, and tensor.insert_slice
// operations.

// To understand the lowering, consider this pytorch example:
//
// >>> t = torch.tensor([[[1.0,2,3],[4,5,6], [7,8,9]]])
// >>> t
// tensor([[[1., 2., 3.],
//          [4., 5., 6.],
//          [7., 8., 9.]]])
// >>> torch.ops.aten.reflection_pad2d(t, [1,2,1,2])
// tensor([[[5., 4., 5., 6., 5., 4.],
//          [2., 1., 2., 3., 2., 1.],
//          [5., 4., 5., 6., 5., 4.],
//          [8., 7., 8., 9., 8., 7.],
//          [5., 4., 5., 6., 5., 4.],
//          [2., 1., 2., 3., 2., 1.]]])
//
// The result can be subdivided into "tiles" corresponding to either
// the input tensor (in the center) or slices of the input tensor
// whose width and height is determined by the padding sizes and which
// are reflected through the side of the central input tensor that
// they touch.
// In the example above, the tiles are:
//   top left:      [[5]]
//   top center:    [[4,5,6]]
//   top right:     [[5,4]]
//   center left:   [[2],[5],[8]]
//   center:        copy of the input tensor
//   center right:  [[2,1],[5,4],[8,7]]
//   bottom left:   [[5],[2]]
//   bottom center: [[4,5,6],[1,2,3]]
//   bottom right:  [[5,4],[2,1]]
//
// The lowering uses a tensor.extract_slice operation to create each tile,
// a linalg.generic for the reflection, and a tensor.insert_slice to
// insert the tile in the resulting tensor.
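Aside (not part of the patch): the corner tiles are reflections through both adjacent sides. A quick Python check of the top-left tile from the example above:

import torch

t = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]])
left, right, top, bottom = 1, 2, 1, 2
out = torch.ops.aten.reflection_pad2d(t, (left, right, top, bottom))
# The top-left tile is the (top, left)-shaped block starting one past
# the reflection row and column, flipped in both dimensions.
top_left = t[..., 1:1 + top, 1:1 + left].flip(-2, -1)
assert torch.equal(out[..., :top, :left], top_left)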
class ConvertAtenReflectionPad2dOp
    : public OpConversionPattern<AtenReflectionPad2dOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenReflectionPad2dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    SmallVector<int64_t> padInts;
    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
      return rewriter.notifyMatchFailure(
          op, "only support constant int pad ranges");

    Location loc = op.getLoc();

    // Some generic helper functions for creating arithmetic operations.
    auto createAdd = [&](Value x, Value y) {
      return rewriter.create<arith::AddIOp>(loc, x, y);
    };

    auto createAdds = [&](std::initializer_list<Value> values) {
      assert(values.size() >= 2);
      return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
                             createAdd);
    };

    auto createSub = [&](Value x, Value y) {
      return rewriter.create<arith::SubIOp>(loc, x, y);
    };

    auto createSubs = [&](std::initializer_list<Value> values) {
      assert(values.size() >= 2);
      return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
                             createSub);
    };
    // Enums for specifying the coordinates of a tile. An "h" prefix
    // is used to stand for "horizontal" and "v" for "vertical"
    // throughout.
    enum PadHLoc { LEFT = 0, RIGHT = 1, HCENTER = 2 };
    enum PadVLoc { TOP = 0, BOTTOM = 1, VCENTER = 2 };

    // Helper functions for obtaining information about the operator's
    // padding arguments.
    auto getHPadArgument = [&](PadHLoc l) {
      assert(l < HCENTER);
      return padInts[l];
    };

    auto getVPadArgument = [&](PadVLoc l) {
      assert(l < VCENTER);
      return padInts[2 + l];
    };

    auto shouldCreateTile = [&](PadVLoc v, PadHLoc h) {
      if (!(h == HCENTER || getHPadArgument(h) > 0))
        return false;
      if (!(v == VCENTER || getVPadArgument(v) > 0))
        return false;

      return true;
    };
    Type indexType = rewriter.getIndexType();
    Value zero = getConstant(rewriter, loc, 0, indexType);
    Value one = getConstant(rewriter, loc, 1, indexType);

    Value input = adaptor.getSelf();
    MLIRContext *context = rewriter.getContext();
    auto inputType = llvm::cast<RankedTensorType>(input.getType());
    auto outputType = llvm::cast<RankedTensorType>(
        getTypeConverter()->convertType(op->getResult(0).getType()));
    unsigned numDims = inputType.getRank();

    assert(numDims >= 2 && "Not enough input dimensions");

    SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
    int64_t hDim = numDims - 1;
    int64_t vDim = numDims - 2;
    Value hDimSize = inputShape[hDim];
    Value vDimSize = inputShape[vDim];

    Value tileWidth[3];
    tileWidth[HCENTER] = hDimSize;
    for (auto h : {LEFT, RIGHT})
      tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType);

    Value tileHeight[3];
    tileHeight[VCENTER] = vDimSize;
    for (auto v : {TOP, BOTTOM})
      tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType);
    // Helper to reflect/reverse the i-th dimension of an affine map
    // without symbols. This only works if applied on a tensor
    // for which the corresponding dimension has a statically
    // known size, which is good enough since we only apply
    // it to reflect the padding slices.
    auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i,
                         int64_t size) {
      AffineExpr d = map.getResult(i);
      return map.replace(d, size - d - 1, numDims, 0);
    };
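In other words, reflectDim substitutes the index expression d with size - d - 1, which reverses the tile along that dimension when the rewritten map is used as the input indexing map of the linalg.generic below. A minimal Python analogue (illustrative only):

# Reading through the substituted index j -> size - j - 1 reverses
# the dimension, mirroring the AffineExpr rewrite in reflectDim.
def reflect_index(j, size):
    return size - j - 1

row = [10, 20, 30]
assert [row[reflect_index(j, 3)] for j in range(3)] == [30, 20, 10]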
    // Create output shape and tensor.
    SmallVector<Value> resultShape{inputShape};
    resultShape[vDim] =
        createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]});
    resultShape[hDim] =
        createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]});

    Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape,
                                              inputType.getElementType());
    // Construction of the tiles.

    // Example: central left tile
    //
    // Let m be the width of the left padding as returned by
    // getHPadArgument(LEFT) and n the size of the input tensor's
    // "vertical" dimension, i.e. vDimSize. Assume that the first
    // m + 1 columns of the input's subtensor in the relevant
    // (i.e. last two) dimensions are:
    //
    //    x_1,1  x_1,2  ...  x_1,m+1
    //    x_2,1  x_2,2  ...  x_2,m+1
    //      .
    //      .
    //      .
    //    x_n,1  x_n,2  ...  x_n,m+1
    //
    // The padding tile consists of the columns 2, ..., m + 1 of the
    // input in reverse order. The first column gets skipped because
    // it is the column through which the reflection happens:
    //
    //    x_1,m+1  x_1,m  ...  x_1,2
    //    x_2,m+1  x_2,m  ...  x_2,2
    //      .
    //      .
    //      .
    //    x_n,m+1  x_n,m  ...  x_n,2
    //
    // The tile will be inserted to the left of the copy of the input tensor
    // in the output tensor, i.e. with horizontal offset 0. The amount of
    // top padding determines the vertical offset.

    // Tiles on the diagonal (e.g. (TOP, LEFT)) are reflected through
    // two sides, i.e. their columns and rows must be reversed.

    // Setup information about the tiles.
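Checking the central-left-tile description on the running example, with m = 1 (not part of the patch):

import torch

t = torch.tensor([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]])
m = 1  # width of the left padding
out = torch.ops.aten.reflection_pad2d(t, (m, 0, 0, 0))
# Columns 2 .. m+1 (1-based) of the input, in reverse order:
left_tile = t[..., :, 1:1 + m].flip(-1)
assert torch.equal(out[..., :, :m], left_tile)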
    // Compute the offsets for extracting the slice from the
    // input. We need to skip the row or column through which
    // the tile should be reflected, if any (none for the center tile).
    Value extractHOffset[3];
    extractHOffset[LEFT] = one;
    extractHOffset[HCENTER] = zero;
    extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one});

    Value extractVOffset[3];
    extractVOffset[TOP] = one;
    extractVOffset[VCENTER] = zero;
    extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one});

    // Compute the horizontal and vertical offsets for inserting
    // the tiles in the resultTensor.
    Value insertHOffset[3];
    insertHOffset[LEFT] = zero;
    insertHOffset[HCENTER] = tileWidth[LEFT];
    insertHOffset[RIGHT] = createAdd(hDimSize, tileWidth[LEFT]);

    Value insertVOffset[3];
    insertVOffset[TOP] = zero;
    insertVOffset[VCENTER] = tileHeight[TOP];
    insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]);
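Instantiating these formulas on the running example (hDimSize = 3, left = 1, right = 2) shows where each tile is read from and written to; a small sanity sketch (not part of the patch):

# Horizontal offsets for hDimSize = 3, left = 1, right = 2 (0-based).
h_size, left, right = 3, 1, 2
extract_h = {"LEFT": 1, "HCENTER": 0, "RIGHT": h_size - right - 1}
insert_h = {"LEFT": 0, "HCENTER": left, "RIGHT": h_size + left}
# The right tile is read starting at column 0 and written at column 4.
assert (extract_h["RIGHT"], insert_h["RIGHT"]) == (0, 4)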
    auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; };
    auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; };

    SmallVector<utils::IteratorType> iteratorTypes(
        numDims, utils::IteratorType::parallel);
    auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context);
    SmallVector<Value> allOneStrides(numDims, one);
    auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) {
      // Create the tile by extracting a slice from the input tensor.
      SmallVector<Value> extractShape{inputShape};
      extractShape[hDim] = tileWidth[horizontalPos];
      extractShape[vDim] = tileHeight[verticalPos];

      SmallVector<Value> extractOffsets(numDims, zero);
      extractOffsets[hDim] = extractHOffset[horizontalPos];
      extractOffsets[vDim] = extractVOffset[verticalPos];

      Value tile = rewriter.create<tensor::ExtractSliceOp>(
          loc, input, extractOffsets, extractShape, allOneStrides);

      // Reverse the tile along the horizontal, vertical, or both
      // dimensions.
      auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
      if (shouldHReflect(horizontalPos))
        inputMap =
            reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos));
      if (shouldVReflect(verticalPos))
        inputMap =
            reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos));

      tile = rewriter
                 .create<linalg::GenericOp>(
                     loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
                     tile, ArrayRef({inputMap, idMap}), iteratorTypes,
                     [](OpBuilder &b, Location nestedLoc, ValueRange args) {
                       b.create<linalg::YieldOp>(nestedLoc, args[0]);
                     })
                 .getResult(0);

      // Insert the tile in the resultTensor.
      SmallVector<Value> insertOffsets(numDims, zero);
      insertOffsets[hDim] = insertHOffset[horizontalPos];
      insertOffsets[vDim] = insertVOffset[verticalPos];

      resultTensor = rewriter.create<tensor::InsertSliceOp>(
          loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
    };
    for (auto v : {TOP, BOTTOM, VCENTER})
      for (auto h : {LEFT, RIGHT, HCENTER})
        if (shouldCreateTile(v, h))
          createTile(v, h);

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputType, resultTensor);

    return success();
  }
};
} // namespace
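Aside (not part of the patch): shouldCreateTile skips any padding tile whose pad amount is zero, so only the necessary slices are materialized. A Python sketch of that selection, assuming the same (left, right, top, bottom) convention:

def tiles_to_create(left, right, top, bottom):
    h_ok = {"LEFT": left > 0, "RIGHT": right > 0, "HCENTER": True}
    v_ok = {"TOP": top > 0, "BOTTOM": bottom > 0, "VCENTER": True}
    return [(v, h) for v in v_ok for h in h_ok if v_ok[v] and h_ok[h]]

# Top-only padding (as in the ReflectionPad2dModuleTop test below)
# needs just the top strip and the center copy:
assert tiles_to_create(0, 0, 2, 0) == [("TOP", "HCENTER"), ("VCENTER", "HCENTER")]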

namespace {
class ConvertAtenFlattenUsingIntsOp
    : public OpConversionPattern<AtenFlattenUsingIntsOp> {
@@ -1552,6 +1829,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
  MLIRContext *context = patterns.getContext();
  target.addIllegalOp<AtenReflectionPad1dOp>();
  patterns.add<ConvertAtenReflectionPad1dOp>(typeConverter, context);
  target.addIllegalOp<AtenReflectionPad2dOp>();
  patterns.add<ConvertAtenReflectionPad2dOp>(typeConverter, context);
  target.addIllegalOp<AtenFlattenUsingIntsOp>();
  patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
  target.addIllegalOp<AtenViewOp>();
@@ -8366,6 +8366,59 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %7 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
"    return %7 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.reflection_pad2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
"    %false = torch.constant.bool false\n"
"    %int-1 = torch.constant.int -1\n"
"    %int-2 = torch.constant.int -2\n"
"    %none = torch.constant.none\n"
"    %str = torch.constant.str \"AssertionError: \"\n"
"    %int2 = torch.constant.int 2\n"
"    %int1 = torch.constant.int 1\n"
"    %int0 = torch.constant.int 0\n"
"    %int3 = torch.constant.int 3\n"
"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
"    %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
"    torch.prim.If %1 -> () {\n"
"      torch.prim.If.yield\n"
"    } else {\n"
"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
"      torch.prim.If.yield\n"
"    }\n"
"    %2 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
"    %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
"    %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
"    %5 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
"    %6 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
"    %7 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
"    %8 = torch.aten.lt.int %4, %3 : !torch.int, !torch.int -> !torch.bool\n"
"    %9 = torch.prim.If %8 -> (!torch.bool) {\n"
"      %13 = torch.aten.lt.int %5, %3 : !torch.int, !torch.int -> !torch.bool\n"
"      torch.prim.If.yield %13 : !torch.bool\n"
"    } else {\n"
"      torch.prim.If.yield %false : !torch.bool\n"
"    }\n"
"    torch.prim.If %9 -> () {\n"
"      torch.prim.If.yield\n"
"    } else {\n"
"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
"      torch.prim.If.yield\n"
"    }\n"
"    %10 = torch.aten.lt.int %6, %2 : !torch.int, !torch.int -> !torch.bool\n"
"    %11 = torch.prim.If %10 -> (!torch.bool) {\n"
"      %13 = torch.aten.lt.int %7, %2 : !torch.int, !torch.int -> !torch.bool\n"
"      torch.prim.If.yield %13 : !torch.bool\n"
"    } else {\n"
"      torch.prim.If.yield %false : !torch.bool\n"
"    }\n"
"    torch.prim.If %11 -> () {\n"
"      torch.prim.If.yield\n"
"    } else {\n"
"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
"      torch.prim.If.yield\n"
"    }\n"
"    %12 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
"    return %12 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {\n"
"    %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
@@ -9002,6 +9055,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    }\n"
"    return %0#1 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
"    %none = torch.constant.none\n"
"    %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n"
"    %int4 = torch.constant.int 4\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
"    %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
"    %2 = torch.aten.eq.int %1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
"    torch.prim.If %2 -> () {\n"
"      torch.prim.If.yield\n"
"    } else {\n"
"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
"      torch.prim.If.yield\n"
"    }\n"
"    return %0#1 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
"    return %0#1 : !torch.int\n"
@@ -1286,6 +1286,26 @@ def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List
    assert padding_left < hdim and padding_right < hdim
    return pad_shape_fn(self, padding)


# Padding size must be smaller than the corresponding dimension
@check_shape_function([ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,2,1,1]),
                       ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,1]),
                       Invocation(TensorOfShape(2, 2, 2), padding=[1,1,1,1]),
                       ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,1]),
                       ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,2])])
def aten〇reflection_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]:
    assert len(self) >= 2
    vdim = self[-2]
    hdim = self[-1]
    padding_left = padding[0]
    padding_right = padding[1]
    padding_top = padding[2]
    padding_bottom = padding[3]
    assert padding_left < hdim and padding_right < hdim
    assert padding_top < vdim and padding_bottom < vdim

    return pad_shape_fn(self, padding)
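For intuition (not part of the patch): pad_shape_fn grows the last len(padding)//2 dimensions pairwise, so for reflection_pad2d the last dimension gains left + right and the second-to-last gains top + bottom. A minimal sketch under that assumption:

def reflection_pad2d_out_shape(shape, padding):
    left, right, top, bottom = padding
    out = list(shape)
    out[-1] += left + right   # horizontal
    out[-2] += top + bottom   # vertical
    return out

assert reflection_pad2d_out_shape([1, 3, 3], [1, 2, 1, 2]) == [1, 6, 6]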
# TODO: upstream this
def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
    assert len(indices) <= len(self), "More indices than dimensions to index"
@@ -1831,6 +1851,20 @@ def aten〇reflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: L
    assert len(padding) == 2, 'padding size expected to be 2'
    return self_dtype

@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1),
                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]),
                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]),
                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2,1]),
                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[3,2,1]),
                       Invocation(TensorOfShape(5, 5, 4), padding=[1,2,3,4]),
                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[5,4,3,2,1])])
def aten〇reflection_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
    self_rank, self_dtype = self_rank_dtype

    assert len(padding) == 4, 'padding size expected to be 4'

    return self_dtype
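The dtype function simply forwards the input dtype. A quick check at the PyTorch level (not part of the patch):

import torch

x = torch.ones(1, 3, 3, dtype=torch.float64)
out = torch.ops.aten.reflection_pad2d(x, (1, 1, 1, 1))
assert out.dtype == x.dtype  # the op preserves the element type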
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇contiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int:
    self_rank, self_dtype = self_rank_dtype
@@ -542,6 +542,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    # Misc tensor ops.
    emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
    emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)")
    emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)")
    emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)")
    emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
    emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
@@ -59,3 +59,4 @@ def register_all_tests():
    from . import return_types
    from . import control_flow
    from . import stats
    from . import padding
@@ -0,0 +1,113 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import functorch
import torch

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export

# ==============================================================================
class ReflectionPad2dModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 20, 20], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.reflection_pad2d(x, (10, 10, 10, 10))


@register_test_case(module_factory=lambda: ReflectionPad2dModule())
def ReflectionPad2dModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 20, 20, low=-1))

# ==============================================================================

class ReflectionPad2dModuleTop(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 3, 4], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.reflection_pad2d(x, (0, 0, 2, 0))


@register_test_case(module_factory=lambda: ReflectionPad2dModuleTop())
def ReflectionPad2dModule_Top(module, tu: TestUtils):
    module.forward(tu.rand(1, 3, 4))

# ==============================================================================

class ReflectionPad2dModuleBottom(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 3, 10, 10], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.reflection_pad2d(x, (0, 0, 0, 5))


@register_test_case(module_factory=lambda: ReflectionPad2dModuleBottom())
def ReflectionPad2dModule_Bottom(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 10, 10))

# ==============================================================================

class ReflectionPad2dModuleLeft(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 3, 20, 20], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.reflection_pad2d(x, (15, 0, 0, 0))


@register_test_case(module_factory=lambda: ReflectionPad2dModuleLeft())
def ReflectionPad2dModule_Left(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 20, 20))

# ==============================================================================

class ReflectionPad2dModuleRight(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 3, 20, 20], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.reflection_pad2d(x, (0, 11, 0, 0))


@register_test_case(module_factory=lambda: ReflectionPad2dModuleRight())
def ReflectionPad2dModule_Right(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 20, 20))

# ==============================================================================