[Torch Dialect] add runtime.assert to check constraint when recomposing complex ops (#2281)

pull/2308/head
Yuanqiang Liu 2023-07-14 10:13:19 +08:00 committed by GitHub
parent 50f5b658b6
commit 7f6b72aec8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 101 additions and 27 deletions

View File

@ -18,6 +18,21 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
// calculate: (a + b - 1) // b
// a/b's type should be !torch.int
Value getIntCeilDiv(PatternRewriter &rewriter, Location loc, Value a, Value b) {
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value dividend = rewriter.create<AtenAddIntOp>(loc, a, b);
dividend = rewriter.create<AtenSubIntOp>(loc, dividend, cstOne);
Value result = rewriter.create<AtenFloordivIntOp>(loc, dividend, b);
return result;
}
} // namespace
namespace {
class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
public:
@ -147,14 +162,26 @@ public:
LogicalResult matchAndRewrite(PrimListUnpackOp op,
PatternRewriter &rewriter) const override {
// recompose AtenUnbindOp + PrimListUnpackOp to select.int
auto unbind = dyn_cast<AtenUnbindIntOp>(op.getOperand().getDefiningOp());
if (!unbind)
auto unbindOp = dyn_cast<AtenUnbindIntOp>(op.getOperand().getDefiningOp());
if (!unbindOp)
return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp");
if (isListPotentiallyMutated(unbind.getResult()))
if (isListPotentiallyMutated(unbindOp.getResult()))
return rewriter.notifyMatchFailure(
op, "AtenUnbindIntOp result is potentially mutated");
Value dim = unbind.getDim();
Value input = unbind.getSelf();
Location loc = op.getLoc();
Value dim = unbindOp.getDim();
Value input = unbindOp.getSelf();
// add runtime.assert to check unbind's dim size == numResults
Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
Value cstNumResults = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(op.getNumResults()));
Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, totalSize, cstNumResults);
rewriter.create<RuntimeAssertOp>(
loc, eqOrNot,
rewriter.getStringAttr("unbind's dim size should equal to "
"prim.list_unpack's num results"));
SmallVector<Value> slices;
for (size_t i = 0; i < op.getNumResults(); i++) {
// rewrite to select.int op
@ -166,8 +193,8 @@ public:
slices.push_back(newSelect);
}
rewriter.replaceOp(op, slices);
if (unbind.getResult().use_empty())
rewriter.eraseOp(unbind);
if (unbindOp.getResult().use_empty())
rewriter.eraseOp(unbindOp);
return success();
}
};
@ -188,10 +215,21 @@ public:
if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
return rewriter.notifyMatchFailure(
op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int");
if (index < 0)
return rewriter.notifyMatchFailure(
op, "Expected `idx` of `Aten__Getitem__TOp` to be a positive int");
Location loc = op.getLoc();
Value dim = unbind.getDim();
Value input = unbind.getSelf();
// add runtime.assert to check: index
Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
Value ltOrNot = rewriter.create<AtenLtIntOp>(loc, op.getIdx(), totalSize);
rewriter.create<RuntimeAssertOp>(
loc, ltOrNot,
rewriter.getStringAttr("index should less than unbind's dim size"));
// rewrite to slice op
auto resultTy = op.getResult().getType();
Value newSelect = rewriter.create<AtenSelectIntOp>(loc, resultTy, input,
@ -221,6 +259,9 @@ public:
if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
return rewriter.notifyMatchFailure(
op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int");
if (index < 0)
return rewriter.notifyMatchFailure(
op, "Expected `idx` of `Aten__Getitem__TOp` to be a positive int");
int64_t splitSize;
if (!matchPattern(splitTensorOp.getSplitSize(),
@ -230,6 +271,19 @@ public:
"Expected `SplitSize` of `AtenSplitTensorOp` to be a constant int");
Location loc = op.getLoc();
Value input = splitTensorOp.getSelf();
Value dim = splitTensorOp.getDim();
// add runtime.assert to check rank constraint: index < split_result_size
Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
Value splitResultSize =
getIntCeilDiv(rewriter, loc, totalSize, splitTensorOp.getSplitSize());
Value ltOrNot =
rewriter.create<AtenLtIntOp>(loc, op.getIdx(), splitResultSize);
rewriter.create<RuntimeAssertOp>(
loc, ltOrNot,
rewriter.getStringAttr("index should less than split_result_size"));
Value step =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value start = rewriter.create<ConstantIntOp>(
@ -237,8 +291,7 @@ public:
Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(index * splitSize + splitSize));
Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
loc, op.getResult().getType(), splitTensorOp.getSelf(),
splitTensorOp.getDim(), start, end, step);
loc, op.getResult().getType(), input, dim, start, end, step);
rewriter.replaceOp(op, sliceTensorOp);
if (splitTensorOp.getResult().use_empty())
rewriter.eraseOp(splitTensorOp);
@ -269,8 +322,24 @@ public:
"Expected `SplitSize` of `AtenSplitTensorOp` to be a constant int");
Location loc = op.getLoc();
Value step =
Value input = splitTensorOp.getSelf();
Value dim = splitTensorOp.getDim();
// add runtime.assert to check rank constraint
Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
Value cstNumResults = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(op.getNumResults()));
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
// assert: numResults == floordiv(totalSize + splitSize - 1, splitSize)
Value splitResultSize =
getIntCeilDiv(rewriter, loc, totalSize, splitTensorOp.getSplitSize());
Value eqOrNot =
rewriter.create<AtenEqIntOp>(loc, splitResultSize, cstNumResults);
rewriter.create<RuntimeAssertOp>(
loc, eqOrNot,
rewriter.getStringAttr("numResults should equal to floordiv(totalSize "
"+ splitSize - 1, splitSize)"));
SmallVector<Value> slices;
for (size_t i = 0; i < op.getNumResults(); i++) {
@ -280,8 +349,7 @@ public:
auto end = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr((i + 1) * splitSize));
Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
loc, resultTy, splitTensorOp.getSelf(), splitTensorOp.getDim(), start,
end, step);
loc, resultTy, input, dim, start, end, /*step=*/cstOne);
slices.push_back(sliceTensorOp);
}
rewriter.replaceOp(op, slices);
@ -298,25 +366,31 @@ public:
LogicalResult matchAndRewrite(PrimListUnpackOp op,
PatternRewriter &rewriter) const override {
// recompose AtenChunkOp + PrimListUnpackOp to AtenSliceTensorOps
auto chunk = dyn_cast<AtenChunkOp>(op.getOperand().getDefiningOp());
if (!chunk)
auto chunkOp = dyn_cast<AtenChunkOp>(op.getOperand().getDefiningOp());
if (!chunkOp)
return rewriter.notifyMatchFailure(op, "Input is not AtenChunkOp");
if (isListPotentiallyMutated(chunk.getResult()))
if (isListPotentiallyMutated(chunkOp.getResult()))
return rewriter.notifyMatchFailure(
op, "AtenChunkOp result is potentially mutated");
Value dim = chunk.getDim();
Value input = chunk.getSelf();
Value chunks = chunk.getChunks();
Location loc = chunk.getLoc();
Value dim = chunkOp.getDim();
Value input = chunkOp.getSelf();
Value chunks = chunkOp.getChunks();
Location loc = chunkOp.getLoc();
Value totalSize = rewriter.create<Torch::AtenSizeIntOp>(loc, input, dim);
// chunkSize = floordiv(totalSize + chunks - 1, chunks)
Value chunkSize = getIntCeilDiv(rewriter, loc, totalSize, chunks);
// add runtime.assert to check chunks == NumResults
Value cstNumResults = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(op.getNumResults()));
Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, chunks, cstNumResults);
rewriter.create<RuntimeAssertOp>(
loc, eqOrNot,
rewriter.getStringAttr(
"chunks should equal to prim.list_unpack's num results"));
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value dividend = rewriter.create<AtenAddIntOp>(loc, totalSize, chunks);
dividend = rewriter.create<AtenSubIntOp>(loc, dividend, cstOne);
Value chunkSize = rewriter.create<AtenFloordivIntOp>(loc, dividend, chunks);
SmallVector<Value> slices;
for (size_t i = 0; i < op.getNumResults(); i++) {
// rewrite to slice op with
@ -334,13 +408,13 @@ public:
end = rewriter.create<AtenMulIntOp>(loc, nextIdx, chunkSize);
}
Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
loc, resultTy, input, dim, start, end, cstOne);
loc, resultTy, input, dim, start, end, /*step=*/cstOne);
slices.push_back(sliceTensorOp);
}
rewriter.replaceOp(op, slices);
// erase chunkOp if no user left
if (chunk.getResult().use_empty())
rewriter.eraseOp(chunk);
if (chunkOp.getResult().use_empty())
rewriter.eraseOp(chunkOp);
return success();
}
};