From 9adad9bc407d92860b99f74b02da3a07b315d6b0 Mon Sep 17 00:00:00 2001 From: kumardeepakamd <123522031+kumardeepakamd@users.noreply.github.com> Date: Tue, 2 Jan 2024 11:05:11 -0800 Subject: [PATCH] Add support for reflection_pad1d (#2706) Adds a lowering to Linalg for reflection_pad1d. Based on ideas/code from draft PR https://github.com/llvm/torch-mlir/pull/2693. --------- Co-authored-by: Kumar Deepak --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 +++ lib/Conversion/TorchToLinalg/DataMovement.cpp | 139 ++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 50 +++++++ .../build_tools/abstract_interp_lib_gen.py | 27 ++++ .../build_tools/torch_ods_gen.py | 1 + .../torch_mlir_e2e_test/test_suite/basic.py | 72 +++++++++ 6 files changed, 313 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 16eb5565b..23e65d75d 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7869,6 +7869,30 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ }]; } +def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$padding + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenReflectionPad1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenReflectionPad1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenPadOp : Torch_Op<"aten.pad", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index dae387422..6534e8598 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -107,6 +107,143 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, return success(); } +// Example: +// input = tensor([[[0., 1., 2., 3.], +// [4., 5., 6., 7.]]]) +// torch.ops.aten.reflection_pad1d(input, (3,1)) ; padding_left = 3, padding_right = 1 +// tensor([[[3., 2., 1., 0., 1., 2., 3., 2.], +// [7., 6., 5., 4., 5., 6., 7., 6.]]]) +// Checks: 1) Each of padding_left and padding_right must be non-negative less than size of last dimension +// Implementation: a) Construct a result tensor of shape of input tensor except for the last dimension. +// The last dimension of the result tensor should be last dimension of input tensor + +// left padding size + right padding size. INitialize result tensor to all zeros +// b) Setup affine map to take slice from input tensor of size left padding starting from +// second column onwards as first column is reflection boundary +// c) Reflect the affine map to have resultant slice reflected +// d) Take the slice and write from begining in result tensor +// e) write the original tensor next into result tensor +// f) Setup affine map to take slice from input tensor of right padding size ending +// at second last column as last column is reflection boundary for right padding +// g) Reflect the affine map to have resultant slice reflected +// h) Take the slice and write from left padding size + orignal tensor last dim size +// into result tensor +// Uses the ideas/code used for AtenReflectionPad2dOp +namespace { +class ConvertAtenReflectionPad1dOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenReflectionPad1dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + SmallVector padInts; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) + return rewriter.notifyMatchFailure( + op, "only constant int padding range is supported"); + + MLIRContext *context = rewriter.getContext(); + Location loc = op.getLoc(); + + // Lambda Unitility Functions + // Create an Integer expression of x + y + auto createIAdd = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + // Create an integer expression of x - y + auto createISub = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + enum PadLocation {PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER=2}; + + Value input = adaptor.getSelf(); + Type indexType = rewriter.getIndexType(); + Value zero = getConstant(rewriter, loc, 0, indexType); + Value one = getConstant(rewriter, loc, 1, indexType); + auto inputType = llvm::cast(input.getType()); + auto outputType = llvm::cast(getTypeConverter()->convertType(op->getResult(0).getType())); + unsigned numDims = inputType.getRank(); + assert(numDims >= 2 && "Not enough input dimensions"); + int64_t lastDim = numDims - 1; + SmallVector inputShape = getTensorSizes(rewriter, loc, input); + Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2, inputShape[2] will give 4 + + Value tileWidth[3], extractOffset[3], insertOffset[3]; + + tileWidth[PAD_LEFT] = getConstant(rewriter, loc, padInts[PAD_LEFT], indexType); + tileWidth[PAD_RIGHT] = getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType); + tileWidth[PAD_CENTER] = lastDimSize; + + extractOffset[PAD_LEFT] = one; + // for (1,2,4) input, padding (3,1) lastDimSize=4, 4 - 1 - 1 = 2 [3,5, 6,7], so start offset to 6, which is right + // lasDimSize - (tileWidth[PAD_RIGHT] + one) + extractOffset[PAD_RIGHT] = createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one)); + extractOffset[PAD_CENTER] = zero; + + insertOffset[PAD_LEFT] = zero; + insertOffset[PAD_RIGHT] = createIAdd(lastDimSize, tileWidth[PAD_LEFT]); + insertOffset[PAD_CENTER] = tileWidth[PAD_LEFT]; + + + SmallVector resultShape{inputShape}; + // Result's last dimension will have shape lastDimSize + left padding size + right padding size + resultShape[lastDim] = createIAdd(resultShape[lastDim], createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT])); + Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, inputType.getElementType()); + + // Helper to reflect/reverse the i-th dimension of an affine map without symbols. This only works if applied on a tensor + // for which the corresponding dimension has a statically known size + auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, int64_t size) { + AffineExpr d = map.getResult(i); + return map.replace(d, size - d - 1, numDims, 0); // left reflect for (3,1) on input shape (1,2,4). size = 3, lastDim=2, numDims=3 + }; + + SmallVector iteratorTypes{numDims, utils::IteratorType::parallel}; + auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context); + SmallVector allOneStrides(numDims, one); + + auto addTileToResult = [&](PadLocation padPosition) { + // Create the tile by extracting a slice from the input tensor. + SmallVector extractShape{inputShape}; + extractShape[lastDim] = tileWidth[padPosition]; + SmallVector extractOffsets(numDims, zero); + extractOffsets[lastDim] = extractOffset[padPosition]; + Value tile = rewriter.create( + loc, input, extractOffsets, extractShape, allOneStrides); + + + auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); + // Setup the affine map function to resverse the tile along the horizontal for left and right slices + if(padPosition < PAD_CENTER) { + inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]); + // Take reflected slice as per inputMap + tile = rewriter.create(loc, llvm::cast(tile.getType()), tile, + tile, ArrayRef({inputMap, idMap}), iteratorTypes, + [](OpBuilder &b, Location nestedLoc, ValueRange args) { + b.create(nestedLoc, args[0]); + }).getResult(0); + } + // Insert the tile in the resultTensor + SmallVector insertOffsets(numDims, zero); + insertOffsets[lastDim] = insertOffset[padPosition]; + resultTensor = rewriter.create(loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); + }; + + if(padInts[PAD_LEFT] > 0) + addTileToResult(PAD_LEFT); + if(padInts[PAD_RIGHT] > 0) + addTileToResult(PAD_RIGHT); + addTileToResult(PAD_CENTER); + + rewriter.replaceOpWithNewOp(op, outputType, resultTensor); + return success(); + } +}; +} + namespace { class ConvertAtenFlattenUsingIntsOp : public OpConversionPattern { @@ -1413,6 +1550,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 25e83899b..4adf55556 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8331,6 +8331,41 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.reflection_pad1d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %int-1 = torch.constant.int -1\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.lt.int %3, %2 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" %8 = torch.aten.lt.int %4, %2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %7 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list, %arg1: !torch.list>>) -> !torch.list {\n" " %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list, !torch.list>>) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8952,6 +8987,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n" +" %int2 = torch.constant.int 2\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %2 = torch.aten.eq.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 2e6094a6f..48949c318 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1271,6 +1271,21 @@ def aten〇constant_pad_nd〡shape(self: List[int], pad: List[int], value: float def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]: return pad_shape_fn(self, pad) +#Padding size must be smaller than the size of the last dimension +@check_shape_function([ErrorInvocation(TensorOfShape(1, 2, 4), padding=[4,1]), + Invocation(TensorOfShape(1, 2, 4), padding=[3,3]), + ErrorInvocation(TensorOfShape(1, 2, 4), padding=[1,4]), + ErrorInvocation(TensorOfShape(1, 4), padding=[4,1]), + Invocation(TensorOfShape(1, 4), padding=[3,3]), + ErrorInvocation(TensorOfShape(1, 4), padding=[1,4])]) +def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List[int]: + assert len(self) >= 2 + hdim = self[-1] + padding_left = padding[0] + padding_right = padding[1] + assert padding_left < hdim and padding_right < hdim + return pad_shape_fn(self, padding) + # TODO: upstream this def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]: assert len(indices) <= len(self), "More indices than dimensions to index" @@ -1804,6 +1819,18 @@ def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[i self_rank, self_dtype = self_rank_dtype return self_dtype + +@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1), + ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]), + ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]), + Invocation(TensorOfShape(2, 3, 4), padding=[2,1]), + Invocation(TensorOfShape(5, 5, 4), padding=[1,2]), + ErrorInvocation(TensorOfShape(2, 3, 4), padding=[3,2,1])]) +def aten〇reflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert len(padding) == 2, 'padding size expected to be 2' + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇contiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index fb458f6a5..4d5b65c1d 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -541,6 +541,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): # Misc tensor ops. emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") + emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)") emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)") emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 971aa1efc..20bc293e7 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -552,8 +552,80 @@ def ConstantPadNdPartialStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ReflectionPad1dModule3dInput(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 2, 4], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad1d(x, (3,1)) +@register_test_case(module_factory=lambda: ReflectionPad1dModule3dInput()) +def ReflectionPad1dModule3dInput_basic(module, tu: TestUtils): + module.forward(tu.rand(1,2,4)) + + +class ReflectionPad1dModule2dInput(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 4], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad1d(x, (3,2)) + + +@register_test_case(module_factory=lambda: ReflectionPad1dModule2dInput()) +def ReflectionPad1dModule2dInput_basic(module, tu: TestUtils): + module.forward(tu.rand(2,4)) + +class ReflectionPad1dModule3dInputLeft(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 4, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad1d(x, (2,0)) + + +@register_test_case(module_factory=lambda: ReflectionPad1dModule3dInputLeft()) +def ReflectionPad1dModule3dInput_Left(module, tu: TestUtils): + module.forward(tu.rand(1,4,5)) + +class ReflectionPad1dModule2dInputRight(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 6], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad1d(x, (0,3)) + + +@register_test_case(module_factory=lambda: ReflectionPad1dModule2dInputRight()) +def ReflectionPad1dModule2dInput_Right(module, tu: TestUtils): + module.forward(tu.rand(3,6)) + +# ============================================================================== class TransposeIntModule(torch.nn.Module): def __init__(self):