Add dynamic pad asserts

pull/3758/head
giacs-epic 2024-11-15 10:54:26 +00:00
parent a2973b0f2f
commit e2b26e5b74
1 changed files with 24 additions and 5 deletions

View File

@ -403,21 +403,40 @@ public:
int64_t vDim = numDims - 2; int64_t vDim = numDims - 2;
Value hDimSize = inputShape[hDim]; Value hDimSize = inputShape[hDim];
Value vDimSize = inputShape[vDim]; Value vDimSize = inputShape[vDim];
Type indexType = rewriter.getIndexType();
auto leftPadAssertMsg = "Left padding too large";
auto rightPadAssertMsg = "Right padding too large";
auto topPadAssertMsg = "Top padding too large";
auto bottomPadAssertMsg = "Bottom padding too large";
auto addPadDynAssert = [&](int64_t pad, Value dimSize,
const llvm::Twine &msg) {
Value padValue = getConstant(rewriter, loc, pad, indexType);
Value pred = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, padValue, dimSize);
rewriter.create<cf::AssertOp>(loc, pred, rewriter.getStringAttr(msg));
};
if (inputType.getShape()[hDim] != ShapedType::kDynamic) { if (inputType.getShape()[hDim] != ShapedType::kDynamic) {
assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] && assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
"Left padding too large"); leftPadAssertMsg);
assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] && assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
"Right padding too large"); rightPadAssertMsg);
} else {
addPadDynAssert(getHPadArgument(LEFT), hDimSize, leftPadAssertMsg);
addPadDynAssert(getHPadArgument(RIGHT), hDimSize, rightPadAssertMsg);
} }
if (inputType.getShape()[vDim] != ShapedType::kDynamic) { if (inputType.getShape()[vDim] != ShapedType::kDynamic) {
assert(getVPadArgument(TOP) < inputType.getShape()[vDim] && assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
"Top padding too large"); topPadAssertMsg);
assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] && assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
"Bottom padding too large"); bottomPadAssertMsg);
} else {
addPadDynAssert(getVPadArgument(TOP), vDimSize, topPadAssertMsg);
addPadDynAssert(getVPadArgument(BOTTOM), vDimSize, bottomPadAssertMsg);
} }
Type indexType = rewriter.getIndexType();
Value zero = getConstant(rewriter, loc, 0, indexType); Value zero = getConstant(rewriter, loc, 0, indexType);
Value one = getConstant(rewriter, loc, 1, indexType); Value one = getConstant(rewriter, loc, 1, indexType);