From e2b26e5b74babe14e4fe8796cec3edb9eaf05853 Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Fri, 15 Nov 2024 10:54:26 +0000 Subject: [PATCH] Add dynamic pad asserts --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 29 +++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 05fd12f16..bd0215d30 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -403,21 +403,40 @@ public: int64_t vDim = numDims - 2; Value hDimSize = inputShape[hDim]; Value vDimSize = inputShape[vDim]; + Type indexType = rewriter.getIndexType(); + + auto leftPadAssertMsg = "Left padding too large"; + auto rightPadAssertMsg = "Right padding too large"; + auto topPadAssertMsg = "Top padding too large"; + auto bottomPadAssertMsg = "Bottom padding too large"; + + auto addPadDynAssert = [&](int64_t pad, Value dimSize, + const llvm::Twine &msg) { + Value padValue = getConstant(rewriter, loc, pad, indexType); + Value pred = rewriter.create( + loc, arith::CmpIPredicate::slt, padValue, dimSize); + rewriter.create(loc, pred, rewriter.getStringAttr(msg)); + }; if (inputType.getShape()[hDim] != ShapedType::kDynamic) { assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] && - "Left padding too large"); + leftPadAssertMsg); assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] && - "Right padding too large"); + rightPadAssertMsg); + } else { + addPadDynAssert(getHPadArgument(LEFT), hDimSize, leftPadAssertMsg); + addPadDynAssert(getHPadArgument(RIGHT), hDimSize, rightPadAssertMsg); } if (inputType.getShape()[vDim] != ShapedType::kDynamic) { assert(getVPadArgument(TOP) < inputType.getShape()[vDim] && - "Top padding too large"); + topPadAssertMsg); assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] && - "Bottom padding too large"); + bottomPadAssertMsg); + } else { + addPadDynAssert(getVPadArgument(TOP), vDimSize, topPadAssertMsg); + addPadDynAssert(getVPadArgument(BOTTOM), vDimSize, bottomPadAssertMsg); } - Type indexType = rewriter.getIndexType(); Value zero = getConstant(rewriter, loc, 0, indexType); Value one = getConstant(rewriter, loc, 1, indexType);