mirror of https://github.com/llvm/torch-mlir
Add dynamic pad asserts
parent
a2973b0f2f
commit
e2b26e5b74
|
@ -403,21 +403,40 @@ public:
|
||||||
int64_t vDim = numDims - 2;
|
int64_t vDim = numDims - 2;
|
||||||
Value hDimSize = inputShape[hDim];
|
Value hDimSize = inputShape[hDim];
|
||||||
Value vDimSize = inputShape[vDim];
|
Value vDimSize = inputShape[vDim];
|
||||||
|
Type indexType = rewriter.getIndexType();
|
||||||
|
|
||||||
|
auto leftPadAssertMsg = "Left padding too large";
|
||||||
|
auto rightPadAssertMsg = "Right padding too large";
|
||||||
|
auto topPadAssertMsg = "Top padding too large";
|
||||||
|
auto bottomPadAssertMsg = "Bottom padding too large";
|
||||||
|
|
||||||
|
auto addPadDynAssert = [&](int64_t pad, Value dimSize,
|
||||||
|
const llvm::Twine &msg) {
|
||||||
|
Value padValue = getConstant(rewriter, loc, pad, indexType);
|
||||||
|
Value pred = rewriter.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::slt, padValue, dimSize);
|
||||||
|
rewriter.create<cf::AssertOp>(loc, pred, rewriter.getStringAttr(msg));
|
||||||
|
};
|
||||||
|
|
||||||
if (inputType.getShape()[hDim] != ShapedType::kDynamic) {
|
if (inputType.getShape()[hDim] != ShapedType::kDynamic) {
|
||||||
assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
|
assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] &&
|
||||||
"Left padding too large");
|
leftPadAssertMsg);
|
||||||
assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
|
assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] &&
|
||||||
"Right padding too large");
|
rightPadAssertMsg);
|
||||||
|
} else {
|
||||||
|
addPadDynAssert(getHPadArgument(LEFT), hDimSize, leftPadAssertMsg);
|
||||||
|
addPadDynAssert(getHPadArgument(RIGHT), hDimSize, rightPadAssertMsg);
|
||||||
}
|
}
|
||||||
if (inputType.getShape()[vDim] != ShapedType::kDynamic) {
|
if (inputType.getShape()[vDim] != ShapedType::kDynamic) {
|
||||||
assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
|
assert(getVPadArgument(TOP) < inputType.getShape()[vDim] &&
|
||||||
"Top padding too large");
|
topPadAssertMsg);
|
||||||
assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
|
assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] &&
|
||||||
"Bottom padding too large");
|
bottomPadAssertMsg);
|
||||||
|
} else {
|
||||||
|
addPadDynAssert(getVPadArgument(TOP), vDimSize, topPadAssertMsg);
|
||||||
|
addPadDynAssert(getVPadArgument(BOTTOM), vDimSize, bottomPadAssertMsg);
|
||||||
}
|
}
|
||||||
|
|
||||||
Type indexType = rewriter.getIndexType();
|
|
||||||
Value zero = getConstant(rewriter, loc, 0, indexType);
|
Value zero = getConstant(rewriter, loc, 0, indexType);
|
||||||
Value one = getConstant(rewriter, loc, 1, indexType);
|
Value one = getConstant(rewriter, loc, 1, indexType);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue