mirror of https://github.com/llvm/torch-mlir
Add dynamic pad asserts
parent
a2973b0f2f
commit
e2b26e5b74
|
@ -403,21 +403,40 @@ public:
|
|||
int64_t vDim = numDims - 2;
|
||||
Value hDimSize = inputShape[hDim];
|
||||
Value vDimSize = inputShape[vDim];
|
||||
Type indexType = rewriter.getIndexType();
|
||||
|
||||
auto leftPadAssertMsg = "Left padding too large";
|
||||
auto rightPadAssertMsg = "Right padding too large";
|
||||
auto topPadAssertMsg = "Top padding too large";
|
||||
auto bottomPadAssertMsg = "Bottom padding too large";
|
||||
|
||||
auto addPadDynAssert = [&](int64_t pad, Value dimSize,
|
||||
const llvm::Twine &msg) {
|
||||
Value padValue = getConstant(rewriter, loc, pad, indexType);
|
||||
Value pred = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, padValue, dimSize);
|
||||
rewriter.create<cf::AssertOp>(loc, pred, rewriter.getStringAttr(msg));
|
||||
};
|
||||
|
||||
if (inputType.getShape()[hDim] != ShapedType::kDynamic) {
|
||||
assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
|
||||
"Left padding too large");
|
||||
leftPadAssertMsg);
|
||||
assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
|
||||
"Right padding too large");
|
||||
rightPadAssertMsg);
|
||||
} else {
|
||||
addPadDynAssert(getHPadArgument(LEFT), hDimSize, leftPadAssertMsg);
|
||||
addPadDynAssert(getHPadArgument(RIGHT), hDimSize, rightPadAssertMsg);
|
||||
}
|
||||
if (inputType.getShape()[vDim] != ShapedType::kDynamic) {
|
||||
assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
|
||||
"Top padding too large");
|
||||
topPadAssertMsg);
|
||||
assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
|
||||
"Bottom padding too large");
|
||||
bottomPadAssertMsg);
|
||||
} else {
|
||||
addPadDynAssert(getVPadArgument(TOP), vDimSize, topPadAssertMsg);
|
||||
addPadDynAssert(getVPadArgument(BOTTOM), vDimSize, bottomPadAssertMsg);
|
||||
}
|
||||
|
||||
Type indexType = rewriter.getIndexType();
|
||||
Value zero = getConstant(rewriter, loc, 0, indexType);
|
||||
Value one = getConstant(rewriter, loc, 1, indexType);
|
||||
|
||||
|
|
Loading…
Reference in New Issue