From 0860c41ee2a0bdec41f544f19eba170cf646c3ce Mon Sep 17 00:00:00 2001
From: Frederik Harwath
Date: Fri, 22 Dec 2023 06:25:15 -0800
Subject: [PATCH] Implement aten.reflection_pad2d lowering to linalg

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  24 ++
 lib/Conversion/TorchToLinalg/DataMovement.cpp | 290 ++++++++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      |  67 ++++
 .../base_lazy_backend/shape_inference.cpp     |  28 ++
 .../build_tools/abstract_interp_lib_gen.py    |  29 ++
 .../build_tools/torch_ods_gen.py              |   1 +
 .../test_suite/__init__.py                    |   1 +
 .../torch_mlir_e2e_test/test_suite/padding.py | 113 +++++++
 8 files changed, 553 insertions(+)
 create mode 100644 projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 23e65d75d..74a2e2327 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7893,6 +7893,30 @@ def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [
   }];
 }
 
+def Torch_AtenReflectionPad2dOp : Torch_Op<"aten.reflection_pad2d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$padding
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenReflectionPad2dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenReflectionPad2dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenPadOp : Torch_Op<"aten.pad", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 6534e8598..49f5f0ec3 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -244,6 +244,294 @@ public:
 };
 }
 
+namespace {
+
+// Lower the aten.reflection_pad2d operator into a sequence of
+// tensor.extract_slice, linalg.generic, and tensor.insert_slice
+// operations.
+
+// To understand the lowering, consider this PyTorch example:
+//
+// >>> t = torch.tensor([[[1.0,2,3],[4,5,6],[7,8,9]]])
+// >>> t
+// tensor([[[1., 2., 3.],
+//          [4., 5., 6.],
+//          [7., 8., 9.]]])
+// >>> torch.ops.aten.reflection_pad2d(t, [1,2,1,2])
+// tensor([[[5., 4., 5., 6., 5., 4.],
+//          [2., 1., 2., 3., 2., 1.],
+//          [5., 4., 5., 6., 5., 4.],
+//          [8., 7., 8., 9., 8., 7.],
+//          [5., 4., 5., 6., 5., 4.],
+//          [2., 1., 2., 3., 2., 1.]]])
+//
+// The result can be subdivided into "tiles" corresponding to either
+// the input tensor (in the center) or slices of the input tensor
+// whose width and height are determined by the padding sizes and which
+// are reflected through the side of the central input tensor that
+// they touch.
+//
+// In the example above, the tiles are:
+// top left:      [[5]]
+// top center:    [[4,5,6]]
+// top right:     [[5,4]]
+// center left:   [[2],[5],[8]]
+// center:        copy of the input tensor
+// center right:  [[2,1],[5,4],[8,7]]
+// bottom left:   [[5],[2]]
+// bottom center: [[4,5,6],[1,2,3]]
+// bottom right:  [[5,4],[2,1]]
+//
+// The lowering uses a tensor.extract_slice operation to create each tile,
+// a linalg.generic for the reflection, and a tensor.insert_slice to
+// insert the tile in the resulting tensor.
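+//
+// For intuition, the same decomposition can be written as a rough
+// NumPy-style sketch (illustrative only, not part of the lowering;
+// `inp` is a 2-D array and l, r, t, b are the left/right/top/bottom
+// padding sizes):
+//
+//   h, w = inp.shape
+//   out = np.zeros((h + t + b, w + l + r), dtype=inp.dtype)
+//   out[t:t+h, l:l+w] = inp                              # center
+//   out[t:t+h, :l] = inp[:, 1:l+1][:, ::-1]              # center left
+//   out[t:t+h, l+w:] = inp[:, w-r-1:w-1][:, ::-1]        # center right
+//   out[:t, :] = out[t+1:2*t+1, :][::-1, :]              # top band
+//   out[t+h:, :] = out[t+h-b-1:t+h-1, :][::-1, :]        # bottom band
+//
+// The two band assignments reflect whole rows of the already padded
+// center band, which fills the corner tiles as well; the lowering below
+// instead materializes each of the (up to) nine tiles separately.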
+class ConvertAtenReflectionPad2dOp
+    : public OpConversionPattern<AtenReflectionPad2dOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenReflectionPad2dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    SmallVector<int64_t> padInts;
+    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant int pad ranges");
+
+    Location loc = op.getLoc();
+    // Some generic helper functions for creating arithmetic operations.
+    auto createAdd = [&](Value x, Value y) {
+      return rewriter.create<arith::AddIOp>(loc, x, y);
+    };
+
+    auto createAdds = [&](std::initializer_list<Value> values) {
+      assert(values.size() >= 2);
+      return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
+                             createAdd);
+    };
+
+    auto createSub = [&](Value x, Value y) {
+      return rewriter.create<arith::SubIOp>(loc, x, y);
+    };
+
+    auto createSubs = [&](std::initializer_list<Value> values) {
+      assert(values.size() >= 2);
+      return std::accumulate(values.begin() + 1, values.end(), data(values)[0],
+                             createSub);
+    };
+
+    // Enums for specifying the coordinates of a tile. An "h" prefix
+    // is used to stand for "horizontal" and "v" for "vertical"
+    // throughout.
+    enum PadHLoc { LEFT = 0, RIGHT = 1, HCENTER = 2 };
+    enum PadVLoc { TOP = 0, BOTTOM = 1, VCENTER = 2 };
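+
+    // Note that the padding list follows the PyTorch convention
+    // [left, right, top, bottom]; e.g. the padding [1,2,1,2] from the
+    // example above means pad_left = 1, pad_right = 2, pad_top = 1,
+    // pad_bottom = 2.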
+
+    // Helper functions for obtaining information about the operator's
+    // padding arguments.
+    auto getHPadArgument = [&](PadHLoc l) {
+      assert(l < HCENTER);
+      return padInts[l];
+    };
+
+    auto getVPadArgument = [&](PadVLoc l) {
+      assert(l < VCENTER);
+      return padInts[2 + l];
+    };
+
+    auto shouldCreateTile = [&](PadVLoc v, PadHLoc h) {
+      if (!(h == HCENTER || getHPadArgument(h) > 0))
+        return false;
+      if (!(v == VCENTER || getVPadArgument(v) > 0))
+        return false;
+
+      return true;
+    };
+
+    Value input = adaptor.getSelf();
+    MLIRContext *context = rewriter.getContext();
+    auto inputType = llvm::cast<RankedTensorType>(input.getType());
+    auto outputType = llvm::cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
+    unsigned numDims = inputType.getRank();
+
+    assert(numDims >= 2 && "Not enough input dimensions");
+
+    SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
+    int64_t hDim = numDims - 1;
+    int64_t vDim = numDims - 2;
+    Value hDimSize = inputShape[hDim];
+    Value vDimSize = inputShape[vDim];
+
+    assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
+           "Left padding too large");
+    assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
+           "Right padding too large");
+    assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
+           "Top padding too large");
+    assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
+           "Bottom padding too large");
+
+    Type indexType = rewriter.getIndexType();
+    Value zero = getConstant(rewriter, loc, 0, indexType);
+    Value one = getConstant(rewriter, loc, 1, indexType);
+
+    Value tileWidth[3];
+    tileWidth[HCENTER] = hDimSize;
+    for (auto h : {LEFT, RIGHT})
+      tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType);
+
+    Value tileHeight[3];
+    tileHeight[VCENTER] = vDimSize;
+    for (auto v : {TOP, BOTTOM})
+      tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType);
+
+    // Helper to reflect/reverse the i-th dimension of an affine map
+    // without symbols. This only works if it is applied to a tensor
+    // for which the corresponding dimension has a statically known
+    // size, which is good enough since we only apply it to reflect
+    // the padding slices.
+    auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i,
+                         int64_t size) {
+      AffineExpr d = map.getResult(i);
+      return map.replace(d, size - d - 1, numDims, 0);
+    };
+
+    // Create output shape and tensor.
+    SmallVector<Value> resultShape{inputShape};
+    resultShape[vDim] =
+        createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]});
+    resultShape[hDim] =
+        createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]});
+
+    Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape,
+                                              inputType.getElementType());
+
+    // Construction of the tiles
+
+    // Example: central left tile
+    //
+    // Let m be the width of the left padding as returned by
+    // getHPadArgument(LEFT) and let n be the size of the input tensor's
+    // "vertical" dimension, i.e. vDimSize. Writing only the first m + 1
+    // columns, the subtensor of the input tensor in the relevant
+    // (i.e. last two) dimensions is:
+    //
+    // x_1,1 x_1,2 ... x_1,m+1
+    // x_2,1 x_2,2 ... x_2,m+1
+    //   .
+    //   .
+    //   .
+    // x_n,1 x_n,2 ... x_n,m+1
+    //
+    // The padding tile consists of the columns 2, ..., m + 1
+    // of the input in reverse order. The first column gets
+    // skipped because this is the column through which the
+    // reflection happens.
+    //
+    // x_1,m+1 x_1,m ... x_1,2
+    // x_2,m+1 x_2,m ... x_2,2
+    //   .
+    //   .
+    //   .
+    // x_n,m+1 x_n,m ... x_n,2
+    //
+    // The tile will be inserted to the left of the copy of the input tensor
+    // in the output tensor, i.e. with horizontal offset 0.
+    // The top padding determines the vertical offset.
+
+    // Tiles on the diagonal (e.g. (TOP, LEFT)) are reflected through
+    // two sides, i.e. their columns and rows must be reversed.
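+
+    // Concretely, for the running 3x3 example with padding [1,2,1,2] and
+    // the offset values computed below: the center left tile is extracted
+    // at horizontal offset 1 (skipping the reflection column) with width 1
+    // and inserted at horizontal offset 0, while the center right tile is
+    // extracted at horizontal offset 3 - 2 - 1 = 0 with width 2 and
+    // inserted at horizontal offset 3 + 1 = 4.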
+
+    // Set up information about the tiles.
+
+    // Compute the offsets for extracting the slice from the
+    // input. We need to skip the row or column through which
+    // the tile should be reflected, if any (none for the center tile).
+    Value extractHOffset[3];
+    extractHOffset[LEFT] = one;
+    extractHOffset[HCENTER] = zero;
+    extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one});
+
+    Value extractVOffset[3];
+    extractVOffset[TOP] = one;
+    extractVOffset[VCENTER] = zero;
+    extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one});
+
+    // Compute the horizontal and vertical offsets for inserting
+    // the tiles in the resultTensor.
+    Value insertHOffset[3];
+    insertHOffset[LEFT] = zero;
+    insertHOffset[HCENTER] = tileWidth[LEFT];
+    insertHOffset[RIGHT] = createAdd(hDimSize, tileWidth[LEFT]);
+
+    Value insertVOffset[3];
+    insertVOffset[TOP] = zero;
+    insertVOffset[VCENTER] = tileHeight[TOP];
+    insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]);
+
+    auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; };
+    auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; };
+
+    SmallVector<utils::IteratorType> iteratorTypes(
+        numDims, utils::IteratorType::parallel);
+    auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context);
+    SmallVector<Value> allOneStrides(numDims, one);
+
+    auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) {
+      // Create the tile by extracting a slice from the input tensor.
+      SmallVector<Value> extractShape{inputShape};
+      extractShape[hDim] = tileWidth[horizontalPos];
+      extractShape[vDim] = tileHeight[verticalPos];
+
+      SmallVector<Value> extractOffsets(numDims, zero);
+      extractOffsets[hDim] = extractHOffset[horizontalPos];
+      extractOffsets[vDim] = extractVOffset[verticalPos];
+
+      Value tile = rewriter.create<tensor::ExtractSliceOp>(
+          loc, input, extractOffsets, extractShape, allOneStrides);
+
+      // Reverse the tile along the horizontal, vertical, or both
+      // dimensions.
+      auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
+      if (shouldHReflect(horizontalPos)) {
+        inputMap =
+            reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos));
+      }
+      if (shouldVReflect(verticalPos)) {
+        inputMap =
+            reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos));
+      }
+
+      tile = rewriter
+                 .create<linalg::GenericOp>(
+                     loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
+                     tile, ArrayRef({inputMap, idMap}), iteratorTypes,
+                     [](OpBuilder &b, Location nestedLoc, ValueRange args) {
+                       b.create<linalg::YieldOp>(nestedLoc, args[0]);
+                     })
+                 .getResult(0);
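+
+      // For instance, reflecting the horizontal dimension of a rank-3
+      // tile of width 2 (the right padding of the running example)
+      // rewrites the identity map (d0, d1, d2) -> (d0, d1, d2) into
+      // (d0, d1, d2) -> (d0, d1, -d2 + 1), so the generic op reads the
+      // tile's columns in reverse order.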
+
+      // Insert the tile in the resultTensor.
+      SmallVector<Value> insertOffsets(numDims, zero);
+      insertOffsets[hDim] = insertHOffset[horizontalPos];
+      insertOffsets[vDim] = insertVOffset[verticalPos];
+
+      resultTensor = rewriter.create<tensor::InsertSliceOp>(
+          loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
+    };
+
+    for (auto v : {TOP, BOTTOM, VCENTER})
+      for (auto h : {LEFT, RIGHT, HCENTER})
+        if (shouldCreateTile(v, h))
+          createTile(v, h);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputType, resultTensor);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenFlattenUsingIntsOp
     : public OpConversionPattern<AtenFlattenUsingIntsOp> {
@@ -1552,6 +1840,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   MLIRContext *context = patterns.getContext();
   target.addIllegalOp<AtenReflectionPad1dOp>();
   patterns.add<ConvertAtenReflectionPad1dOp>(typeConverter, context);
+  target.addIllegalOp<AtenReflectionPad2dOp>();
+  patterns.add<ConvertAtenReflectionPad2dOp>(typeConverter, context);
   target.addIllegalOp<AtenFlattenUsingIntsOp>();
   patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
   target.addIllegalOp<AtenViewOp>();
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 4adf55556..55b9638dd 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8366,6 +8366,69 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %7 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %7 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.reflection_pad2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %false = torch.constant.bool false\n"
+"    %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n"
+"    %int-1 = torch.constant.int -1\n"
+"    %int-2 = torch.constant.int -2\n"
+"    %none = torch.constant.none\n"
+"    %str_0 = torch.constant.str \"AssertionError: \"\n"
+"    %int2 = torch.constant.int 2\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int4 = torch.constant.int 4\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int3 = torch.constant.int 3\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %2 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %4 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %5 = torch.aten.eq.int %4, %int4 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %5 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %7 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %8 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %9 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %10 = torch.aten.lt.int %6, %3 : !torch.int, !torch.int -> !torch.bool\n"
+"    %11 = torch.prim.If %10 -> (!torch.bool) {\n"
+"      %15 = torch.aten.lt.int %7, %3 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %15 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %11 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %12 = torch.aten.lt.int %8, %2 : !torch.int, !torch.int -> !torch.bool\n"
+"    %13 = torch.prim.If %12 -> (!torch.bool) {\n"
+"      %15 = torch.aten.lt.int %9, %2 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %15 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %13 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %14 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %14 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -9002,6 +9065,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
index 244ee7b88..3971fdd32 100644
--- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
+++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
@@ -227,6 +227,34 @@ std::vector<torch::lazy::Shape> compute_shape_remainder(
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
+std::vector<torch::lazy::Shape>
+compute_shape_reflection_pad2d(const at::Tensor &self,
+                               at::IntArrayRef padding) {
+  std::vector<int64_t> paddings = padding.vec();
+  std::vector<int64_t> in_sizes = self.sizes().vec();
+  auto num_dims = in_sizes.size();
+
+  TORCH_CHECK(padding.size() == 4);
+  TORCH_CHECK(num_dims >= 2);
+
+  auto vdim = num_dims - 2;
+  auto hdim = num_dims - 1;
+  auto padding_left = padding[0];
+  auto padding_right = padding[1];
+  auto padding_top = padding[2];
+  auto padding_bottom = padding[3];
+  TORCH_CHECK(padding_left < in_sizes[hdim]);
+  TORCH_CHECK(padding_right < in_sizes[hdim]);
+  TORCH_CHECK(padding_top < in_sizes[vdim]);
+  TORCH_CHECK(padding_bottom < in_sizes[vdim]);
+
+  std::vector<int64_t> out_sizes(in_sizes);
+  out_sizes[hdim] += padding_left + padding_right;
+  out_sizes[vdim] += padding_top + padding_bottom;
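+
+  // For instance, an input of shape (2, 3, 10, 10) with padding
+  // {1, 2, 3, 4} yields an output of shape (2, 3, 17, 13).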
+  return {Shape(self.scalar_type(), out_sizes)};
+}
+
 std::vector<torch::lazy::Shape> compute_shape_uniform(
     const at::Tensor& self, double from, double to,
     c10::optional<at::Generator> generator) {
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 48949c318..a16d778c7 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1286,6 +1286,30 @@ def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List[int]:
     assert padding_left < hdim and padding_right < hdim
     return pad_shape_fn(self, padding)
 
+
+# Padding size must be smaller than the size of the corresponding dimension
+@check_shape_function([ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,2,1,1]),
+                       ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,1]),
+                       ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,3]),
+                       ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1]),
+                       Invocation(TensorOfShape(2, 2, 2), padding=[1,1,1,1]),
+                       ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,1]),
+                       ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,2])])
+def aten〇reflection_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]:
+    assert len(self) >= 2
+    vdim = self[-2]
+    hdim = self[-1]
+
+    assert len(padding) == 4, 'padding size expected to be 4'
+    padding_left = padding[0]
+    padding_right = padding[1]
+    padding_top = padding[2]
+    padding_bottom = padding[3]
+    assert padding_left < hdim and padding_right < hdim
+    assert padding_top < vdim and padding_bottom < vdim
+
+    return pad_shape_fn(self, padding)
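+
+# E.g. pad_shape_fn([1, 3, 3], [1, 2, 1, 2]) yields [1, 6, 6]: the last
+# dimension grows by the left and right padding, the second-to-last by
+# the top and bottom padding.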
 
 # TODO: upstream this
 def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
     assert len(indices) <= len(self), "More indices than dimensions to index"
@@ -1831,6 +1855,11 @@ def aten〇reflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
     assert len(padding) == 2, 'padding size expected to be 2'
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4, 2, 2)], padding=[1,1,1,1]))
+def aten〇reflection_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇contiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 4d5b65c1d..9c0a0759b 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -542,6 +542,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     # Misc tensor ops.
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
     emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)")
+    emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)")
     emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)")
     emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
     emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
index 79712a16f..f24266c78 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -59,3 +59,4 @@ def register_all_tests():
     from . import return_types
     from . import control_flow
     from . import stats
+    from . import padding
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py
new file mode 100644
index 000000000..6b7bdeab2
--- /dev/null
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py
@@ -0,0 +1,113 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import functorch
+import torch
+
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
+
+# ==============================================================================
+
+class ReflectionPad2dModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 20, 20], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.reflection_pad2d(x, (10,10,10,10))
+
+
+@register_test_case(module_factory=lambda: ReflectionPad2dModule())
+def ReflectionPad2dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 20, 20, low=-1))
+
+# ==============================================================================
+
+class ReflectionPad2dModuleTop(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 3, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.reflection_pad2d(x, (0,0,2,0))
+
+
+@register_test_case(module_factory=lambda: ReflectionPad2dModuleTop())
+def ReflectionPad2dModule_Top(module, tu: TestUtils):
+    module.forward(tu.rand(1, 3, 4))
+
+# ==============================================================================
+
+class ReflectionPad2dModuleBottom(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, 10, 10], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.reflection_pad2d(x, (0,0,0,5))
+
+
+@register_test_case(module_factory=lambda: ReflectionPad2dModuleBottom())
+def ReflectionPad2dModule_Bottom(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 10, 10))
+
+# ==============================================================================
+
+class ReflectionPad2dModuleLeft(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, 20, 20], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.reflection_pad2d(x, (15,0,0,0))
+
+
+@register_test_case(module_factory=lambda: ReflectionPad2dModuleLeft())
+def ReflectionPad2dModule_Left(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 20, 20))
+
+# ==============================================================================
+
+class ReflectionPad2dModuleRight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, 20, 20], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.reflection_pad2d(x, (0,11,0,0))
+
+
+@register_test_case(module_factory=lambda: ReflectionPad2dModuleRight())
+def ReflectionPad2dModule_Right(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 20, 20))
+
+# ==============================================================================
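+
+# An additional sketch exercising all four sides at once with distinct
+# padding sizes; it follows the same harness conventions as the modules
+# above (the module name and padding values are illustrative, not part
+# of the original suite).
+
+class ReflectionPad2dModuleAllSides(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, 20, 20], torch.float32, True),
+    ])
+    def forward(self, x):
+        # Padding order is (left, right, top, bottom); each value must be
+        # strictly smaller than the size of the corresponding dimension.
+        return torch.ops.aten.reflection_pad2d(x, (1, 2, 3, 4))
+
+
+@register_test_case(module_factory=lambda: ReflectionPad2dModuleAllSides())
+def ReflectionPad2dModule_AllSides(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 20, 20))
+
+# ==============================================================================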