Implement aten.reflection_pad2d lowering to linalg

pre_fixup_20240110
Frederik Harwath 2023-12-22 06:25:15 -08:00 committed by Frederik Harwath
parent aee1fca251
commit 345dfd5903
7 changed files with 520 additions and 0 deletions

View File

@ -7893,6 +7893,30 @@ def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [
}]; }];
} }
def Torch_AtenReflectionPad2dOp : Torch_Op<"aten.reflection_pad2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$padding
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenReflectionPad2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenReflectionPad2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenPadOp : Torch_Op<"aten.pad", [ def Torch_AtenPadOp : Torch_Op<"aten.pad", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -244,6 +244,283 @@ public:
}; };
} }
namespace {
// Lower the aten.reflection.pad_2d operator into a sequence of
// tensor.extract_slice, linalg.generic, and tensor_insert_slice
// operations.
// To understand the lowering, consider this pytorch example:
//
// >>> t = torch.tensor([[[1.0,2,3],[4,5,6], [7,8,9]]])
// >>> t
// tensor([[[1., 2., 3.],
// [4., 5., 6.],
// [7., 8., 9.]]])
// >>> torch.ops.aten.reflection_pad2d(t, [1,2,1,2])
// tensor([[[5., 4., 5., 6., 5., 4.],
// [2., 1., 2., 3., 2., 1.],
// [5., 4., 5., 6., 5., 4.],
// [8., 7., 8., 9., 8., 7.],
// [5., 4., 5., 6., 5., 4.],
// [2., 1., 2., 3., 2., 1.]]])
//
// The result can be subdivided into "tiles" corresponding to either
// the input tensor (in the center) or slices of the input tensor
// whose width and height is determined by the padding sizes and which
// are reflected through the side of the central input tensor that
// they touch.
// In the example above, the tiles are:
// top left: [[5]]
// top center: [[4,5,6]]
// top right: [[5,4]]
// center left [[2,1],[5,4],[8,7]]
// center: copy of the input tensor
// center right: [[2,1],[5,4],[8,7]]
// bottom left: [[5,4],[2,1]]
// center bottom: [[2,3,2]]
// center right: [[2,1]]
//
// The lowering uses a tensor.extract_slice operation to create each tile,
// a linalg.generic for the reflection, and a tensor.insert_slice to
// insert the tile in the resulting tensor.
class ConvertAtenReflectionPad2dOp
: public OpConversionPattern<AtenReflectionPad2dOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenReflectionPad2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
SmallVector<int64_t> padInts;
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
return rewriter.notifyMatchFailure(
op, "only support constant int pad ranges");
Location loc = op.getLoc();
// Some generic helper functions for creating arithmetic operations.
auto createAdd = [&](Value x, Value y) {
return rewriter.create<arith::AddIOp>(loc, x, y);
};
auto createAdds = [&](std::initializer_list<Value> values) {
assert(values.size() >= 2);
return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
createAdd);
};
auto createSub = [&](Value x, Value y) {
return rewriter.create<arith::SubIOp>(loc, x, y);
};
auto createSubs = [&](std::initializer_list<Value> values) {
assert(values.size() >= 2);
return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
createSub);
};
// Enums for specifying the coordinates of a tile. An "h" prefix
// is used to stand for "horizontal" and "v" for "vertical"
// throughout.
enum PadHLoc { LEFT = 0, RIGHT = 1, HCENTER = 2 };
enum PadVLoc { TOP = 0, BOTTOM = 1, VCENTER = 2 };
// Helper functions for obtaining information about the operator's
// padding arguments.
auto getHPadArgument = [&](PadHLoc l) {
assert(l < HCENTER);
return padInts[l];
};
auto getVPadArgument = [&](PadVLoc l) {
assert(l < VCENTER);
return padInts[2 + l];
};
auto shouldCreateTile = [&](PadVLoc v, PadHLoc h) {
if (!(h == HCENTER || getHPadArgument(h) > 0))
return false;
if (!(v == VCENTER || getVPadArgument(v) > 0))
return false;
return true;
};
Type indexType = rewriter.getIndexType();
Value zero = getConstant(rewriter, loc, 0, indexType);
Value one = getConstant(rewriter, loc, 1, indexType);
Value input = adaptor.getSelf();
MLIRContext *context = rewriter.getContext();
auto inputType = llvm::cast<RankedTensorType>(input.getType());
auto outputType = llvm::cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
unsigned numDims = inputType.getRank();
assert(numDims >= 2 && "Not enough input dimensions");
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
int64_t hDim = numDims - 1;
int64_t vDim = numDims - 2;
Value hDimSize = inputShape[hDim];
Value vDimSize = inputShape[vDim];
Value tileWidth[3];
tileWidth[HCENTER] = hDimSize;
for (auto h : {LEFT, RIGHT})
tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType);
Value tileHeight[3];
tileHeight[VCENTER] = vDimSize;
for (auto v : {TOP, BOTTOM})
tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType);
// Helper to reflect/reverse the i-th dimension of an affine map
// without symbols. This only works if applied on a tensor
// for which the corresponding dimension has a statically
// known size which is good enough since we only apply
// it to reflect the padding slices.
auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i,
int64_t size) {
AffineExpr d = map.getResult(i);
return map.replace(d, size - d - 1, numDims, 0);
};
// Create output shape and tensor
SmallVector<Value> resultShape{inputShape};
resultShape[vDim] =
createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]});
resultShape[hDim] =
createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]});
Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape,
inputType.getElementType());
// Construction of the tiles
// Example: central left tile
//
// Let m the width of the left padding as returned by getHPadargument(LEFT)
// and n the size of the input tensor's "horizontal" dimension, i.e.
// hDimSize. Assume that the subtensor of the input tensor in the relevant
// (i.e. last two) dimensions is:
//
// x_1,1 x_1,2 ... x_1,m
// x_2,1 x_2,2 ... x_2,m
// .
// .
// .
// x_n,1 x_n,2 ... x_n,m
//
// The padding tile consists of the columns 2, ..., m + 1
// of the input in reverse order. The first columns
// gets skipped because this this is the column trough
// which the reflection happens.
//
// x_1,m x_1,m-1 ... x_1,2
// x_2,m x_1,m-1 ... x_2,2
// .
// .
// .
// x_n,m x_n,m-1 ... x_n,2
//
// The tile will be inserted to the left of the copy of the input tensor
// in the output tensor, i.e. with horizontal offset 0.
// If amount of top padding determines the vertical offset.
// Tiles tiles on the diagonal (e.g. (TOP, LEFT)) are reflected through
// two sides, i.e. their columns and rows must be reversed.
// Setup information about the tiles
// Compute the offsets for extracting the slice from the
// input. We need to skip the row or column through which
// the tile should be reflected, if any (none for the center tile).
Value extractHOffset[3];
extractHOffset[LEFT] = one;
extractHOffset[HCENTER] = zero;
extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one});
Value extractVOffset[3];
extractVOffset[TOP] = one;
extractVOffset[VCENTER] = zero;
extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one});
// Compute the horizontal and vertical offsets for inserting
// the tiles in the resultTensor.
Value insertHOffset[3];
insertHOffset[LEFT] = zero;
insertHOffset[HCENTER] = tileWidth[LEFT];
insertHOffset[RIGHT] = createAdd(hDimSize, tileWidth[LEFT]);
Value insertVOffset[3];
insertVOffset[TOP] = zero;
insertVOffset[VCENTER] = tileHeight[TOP];
insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]);
auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; };
auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; };
SmallVector<utils::IteratorType> iteratorTypes{
numDims, utils::IteratorType::parallel};
auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context);
SmallVector<Value> allOneStrides(numDims, one);
auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) {
// Create the tile by extracting a slice from the input tenor.
SmallVector<Value> extractShape{inputShape};
extractShape[hDim] = tileWidth[horizontalPos];
extractShape[vDim] = tileHeight[verticalPos];
SmallVector<Value> extractOffsets(numDims, zero);
extractOffsets[hDim] = extractHOffset[horizontalPos];
extractOffsets[vDim] = extractVOffset[verticalPos];
Value tile = rewriter.create<tensor::ExtractSliceOp>(
loc, input, extractOffsets, extractShape, allOneStrides);
// Reverse the tile along the horizontal, vertical, or both
// dimensions
auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
if (shouldHReflect(horizontalPos))
inputMap =
reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos));
if (shouldVReflect(verticalPos))
inputMap =
reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos));
tile = rewriter
.create<linalg::GenericOp>(
loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
tile, ArrayRef({inputMap, idMap}), iteratorTypes,
[](OpBuilder &b, Location nestedLoc, ValueRange args) {
b.create<linalg::YieldOp>(nestedLoc, args[0]);
})
.getResult(0);
// Insert the tile in the resultTensor
SmallVector<Value> insertOffsets(numDims, zero);
insertOffsets[hDim] = insertHOffset[horizontalPos];
insertOffsets[vDim] = insertVOffset[verticalPos];
resultTensor = rewriter.create<tensor::InsertSliceOp>(
loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
};
for (auto v : {TOP, BOTTOM, VCENTER})
for (auto h : {LEFT, RIGHT, HCENTER})
if (shouldCreateTile(v, h))
createTile(v, h);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputType, resultTensor);
return success();
}
};
} // namespace
namespace { namespace {
class ConvertAtenFlattenUsingIntsOp class ConvertAtenFlattenUsingIntsOp
: public OpConversionPattern<AtenFlattenUsingIntsOp> { : public OpConversionPattern<AtenFlattenUsingIntsOp> {
@ -1552,6 +1829,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenReflectionPad1dOp>(); target.addIllegalOp<AtenReflectionPad1dOp>();
patterns.add<ConvertAtenReflectionPad1dOp>(typeConverter, context); patterns.add<ConvertAtenReflectionPad1dOp>(typeConverter, context);
target.addIllegalOp<AtenReflectionPad2dOp>();
patterns.add<ConvertAtenReflectionPad2dOp>(typeConverter, context);
target.addIllegalOp<AtenFlattenUsingIntsOp>(); target.addIllegalOp<AtenFlattenUsingIntsOp>();
patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context); patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
target.addIllegalOp<AtenViewOp>(); target.addIllegalOp<AtenViewOp>();

View File

@ -8366,6 +8366,59 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %7 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %7 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %7 : !torch.list<int>\n" " return %7 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.reflection_pad2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %int-1 = torch.constant.int -1\n"
" %int-2 = torch.constant.int -2\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %int3 = torch.constant.int 3\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %5 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %6 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %7 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %8 = torch.aten.lt.int %4, %3 : !torch.int, !torch.int -> !torch.bool\n"
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
" %13 = torch.aten.lt.int %5, %3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %9 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %10 = torch.aten.lt.int %6, %2 : !torch.int, !torch.int -> !torch.bool\n"
" %11 = torch.prim.If %10 -> (!torch.bool) {\n"
" %13 = torch.aten.lt.int %7, %2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %11 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %12 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %12 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {\n"
" %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>\n" " %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -9002,6 +9055,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %0#1 : !torch.int\n" " return %0#1 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n"
" %int4 = torch.constant.int 4\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %2 = torch.aten.eq.int %1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n" " return %0#1 : !torch.int\n"

View File

@ -1286,6 +1286,26 @@ def atenreflection_pad1d〡shape(self: List[int], padding: List[int]) -> List
assert padding_left < hdim and padding_right < hdim assert padding_left < hdim and padding_right < hdim
return pad_shape_fn(self, padding) return pad_shape_fn(self, padding)
# Padding size must be smaller than corresponding dimension
@check_shape_function([ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,2,1,1]),
ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,1]),
Invocation(TensorOfShape(2, 2, 2), padding=[1,1,1,1]),
ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,1]),
ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,2])])
def atenreflection_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]:
assert len(self) >= 2
vdim = self[-2]
hdim = self[-1]
padding_left = padding[0]
padding_right = padding[1]
padding_top = padding[2]
padding_bottom = padding[3]
assert padding_left < hdim and padding_right < hdim
assert padding_top < vdim and padding_bottom < vdim
return pad_shape_fn(self, padding)
# TODO: upstream this # TODO: upstream this
def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]: def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
assert len(indices) <= len(self), "More indices than dimensions to index" assert len(indices) <= len(self), "More indices than dimensions to index"
@ -1831,6 +1851,20 @@ def atenreflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: L
assert len(padding) == 2, 'padding size expected to be 2' assert len(padding) == 2, 'padding size expected to be 2'
return self_dtype return self_dtype
@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1),
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]),
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]),
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2,1]),
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[3,2,1]),
Invocation(TensorOfShape(5, 5, 4), padding=[1,2,3,4]),
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[5,4,3,2,1])])
def atenreflection_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
assert len(padding) == 4, 'padding size expected to be 4'
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atencontiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int: def atencontiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int:
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype

View File

@ -542,6 +542,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
# Misc tensor ops. # Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)") emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)")
emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)")
emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)") emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)")
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)

View File

@ -59,3 +59,4 @@ def register_all_tests():
from . import return_types from . import return_types
from . import control_flow from . import control_flow
from . import stats from . import stats
from . import padding

View File

@ -0,0 +1,113 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import functorch
import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class ReflectionPad2dModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([1, 20, 20], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.reflection_pad2d(x, (10,10,10,10))
@register_test_case(module_factory=lambda: ReflectionPad2dModule())
def ReflectionPad2dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 20, 20, low=-1))
# ==============================================================================
class ReflectionPad2dModuleTop(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([1, 3, 4], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.reflection_pad2d(x, (0,0,2,0))
@register_test_case(module_factory=lambda: ReflectionPad2dModuleTop())
def ReflectionPad2dModule_Top(module, tu: TestUtils):
module.forward(tu.rand(1, 3, 4))
# ==============================================================================
class ReflectionPad2dModuleBottom(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3, 10, 10], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.reflection_pad2d(x, (0,0,0,5))
@register_test_case(module_factory=lambda: ReflectionPad2dModuleBottom())
def ReflectionPad2dModule_Bottom(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 10, 10))
# ==============================================================================
class ReflectionPad2dModuleLeft(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3, 20, 20], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.reflection_pad2d(x, (15,0,0,0))
@register_test_case(module_factory=lambda: ReflectionPad2dModuleLeft())
def ReflectionPad2dModule_Left(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 20, 20))
# ==============================================================================
class ReflectionPad2dModuleRight(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3, 20, 20], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.reflection_pad2d(x, (0,11,0,0))
@register_test_case(module_factory=lambda: ReflectionPad2dModuleRight())
def ReflectionPad2dModule_Right(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 20, 20))
# ==============================================================================