mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] add runtime.assert to check constraint when recomposing complex ops (#2281)
parent
50f5b658b6
commit
7f6b72aec8
|
@ -18,6 +18,21 @@ using namespace mlir;
|
|||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
|
||||
// calculate: (a + b - 1) // b
|
||||
// a/b's type should be !torch.int
|
||||
Value getIntCeilDiv(PatternRewriter &rewriter, Location loc, Value a, Value b) {
|
||||
Value cstOne =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
Value dividend = rewriter.create<AtenAddIntOp>(loc, a, b);
|
||||
dividend = rewriter.create<AtenSubIntOp>(loc, dividend, cstOne);
|
||||
Value result = rewriter.create<AtenFloordivIntOp>(loc, dividend, b);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
|
||||
public:
|
||||
|
@ -147,14 +162,26 @@ public:
|
|||
LogicalResult matchAndRewrite(PrimListUnpackOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// recompose AtenUnbindOp + PrimListUnpackOp to select.int
|
||||
auto unbind = dyn_cast<AtenUnbindIntOp>(op.getOperand().getDefiningOp());
|
||||
if (!unbind)
|
||||
auto unbindOp = dyn_cast<AtenUnbindIntOp>(op.getOperand().getDefiningOp());
|
||||
if (!unbindOp)
|
||||
return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp");
|
||||
if (isListPotentiallyMutated(unbind.getResult()))
|
||||
if (isListPotentiallyMutated(unbindOp.getResult()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "AtenUnbindIntOp result is potentially mutated");
|
||||
Value dim = unbind.getDim();
|
||||
Value input = unbind.getSelf();
|
||||
Location loc = op.getLoc();
|
||||
Value dim = unbindOp.getDim();
|
||||
Value input = unbindOp.getSelf();
|
||||
|
||||
// add runtime.assert to check unbind's dim size == numResults
|
||||
Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
|
||||
Value cstNumResults = rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(op.getNumResults()));
|
||||
Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, totalSize, cstNumResults);
|
||||
rewriter.create<RuntimeAssertOp>(
|
||||
loc, eqOrNot,
|
||||
rewriter.getStringAttr("unbind's dim size should equal to "
|
||||
"prim.list_unpack's num results"));
|
||||
|
||||
SmallVector<Value> slices;
|
||||
for (size_t i = 0; i < op.getNumResults(); i++) {
|
||||
// rewrite to select.int op
|
||||
|
@ -166,8 +193,8 @@ public:
|
|||
slices.push_back(newSelect);
|
||||
}
|
||||
rewriter.replaceOp(op, slices);
|
||||
if (unbind.getResult().use_empty())
|
||||
rewriter.eraseOp(unbind);
|
||||
if (unbindOp.getResult().use_empty())
|
||||
rewriter.eraseOp(unbindOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -188,10 +215,21 @@ public:
|
|||
if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int");
|
||||
if (index < 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `idx` of `Aten__Getitem__TOp` to be a positive int");
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value dim = unbind.getDim();
|
||||
Value input = unbind.getSelf();
|
||||
|
||||
// add runtime.assert to check: index
|
||||
Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
|
||||
Value ltOrNot = rewriter.create<AtenLtIntOp>(loc, op.getIdx(), totalSize);
|
||||
rewriter.create<RuntimeAssertOp>(
|
||||
loc, ltOrNot,
|
||||
rewriter.getStringAttr("index should less than unbind's dim size"));
|
||||
|
||||
// rewrite to slice op
|
||||
auto resultTy = op.getResult().getType();
|
||||
Value newSelect = rewriter.create<AtenSelectIntOp>(loc, resultTy, input,
|
||||
|
@ -221,6 +259,9 @@ public:
|
|||
if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int");
|
||||
if (index < 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected `idx` of `Aten__Getitem__TOp` to be a positive int");
|
||||
|
||||
int64_t splitSize;
|
||||
if (!matchPattern(splitTensorOp.getSplitSize(),
|
||||
|
@ -230,6 +271,19 @@ public:
|
|||
"Expected `SplitSize` of `AtenSplitTensorOp` to be a constant int");
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value input = splitTensorOp.getSelf();
|
||||
Value dim = splitTensorOp.getDim();
|
||||
|
||||
// add runtime.assert to check rank constraint: index < split_result_size
|
||||
Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
|
||||
Value splitResultSize =
|
||||
getIntCeilDiv(rewriter, loc, totalSize, splitTensorOp.getSplitSize());
|
||||
Value ltOrNot =
|
||||
rewriter.create<AtenLtIntOp>(loc, op.getIdx(), splitResultSize);
|
||||
rewriter.create<RuntimeAssertOp>(
|
||||
loc, ltOrNot,
|
||||
rewriter.getStringAttr("index should less than split_result_size"));
|
||||
|
||||
Value step =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
Value start = rewriter.create<ConstantIntOp>(
|
||||
|
@ -237,8 +291,7 @@ public:
|
|||
Value end = rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(index * splitSize + splitSize));
|
||||
Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
|
||||
loc, op.getResult().getType(), splitTensorOp.getSelf(),
|
||||
splitTensorOp.getDim(), start, end, step);
|
||||
loc, op.getResult().getType(), input, dim, start, end, step);
|
||||
rewriter.replaceOp(op, sliceTensorOp);
|
||||
if (splitTensorOp.getResult().use_empty())
|
||||
rewriter.eraseOp(splitTensorOp);
|
||||
|
@ -269,8 +322,24 @@ public:
|
|||
"Expected `SplitSize` of `AtenSplitTensorOp` to be a constant int");
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value step =
|
||||
Value input = splitTensorOp.getSelf();
|
||||
Value dim = splitTensorOp.getDim();
|
||||
|
||||
// add runtime.assert to check rank constraint
|
||||
Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
|
||||
Value cstNumResults = rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(op.getNumResults()));
|
||||
Value cstOne =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
// assert: numResults == floordiv(totalSize + splitSize - 1, splitSize)
|
||||
Value splitResultSize =
|
||||
getIntCeilDiv(rewriter, loc, totalSize, splitTensorOp.getSplitSize());
|
||||
Value eqOrNot =
|
||||
rewriter.create<AtenEqIntOp>(loc, splitResultSize, cstNumResults);
|
||||
rewriter.create<RuntimeAssertOp>(
|
||||
loc, eqOrNot,
|
||||
rewriter.getStringAttr("numResults should equal to floordiv(totalSize "
|
||||
"+ splitSize - 1, splitSize)"));
|
||||
|
||||
SmallVector<Value> slices;
|
||||
for (size_t i = 0; i < op.getNumResults(); i++) {
|
||||
|
@ -280,8 +349,7 @@ public:
|
|||
auto end = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr((i + 1) * splitSize));
|
||||
Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
|
||||
loc, resultTy, splitTensorOp.getSelf(), splitTensorOp.getDim(), start,
|
||||
end, step);
|
||||
loc, resultTy, input, dim, start, end, /*step=*/cstOne);
|
||||
slices.push_back(sliceTensorOp);
|
||||
}
|
||||
rewriter.replaceOp(op, slices);
|
||||
|
@ -298,25 +366,31 @@ public:
|
|||
LogicalResult matchAndRewrite(PrimListUnpackOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// recompose AtenChunkOp + PrimListUnpackOp to AtenSliceTensorOps
|
||||
auto chunk = dyn_cast<AtenChunkOp>(op.getOperand().getDefiningOp());
|
||||
if (!chunk)
|
||||
auto chunkOp = dyn_cast<AtenChunkOp>(op.getOperand().getDefiningOp());
|
||||
if (!chunkOp)
|
||||
return rewriter.notifyMatchFailure(op, "Input is not AtenChunkOp");
|
||||
if (isListPotentiallyMutated(chunk.getResult()))
|
||||
if (isListPotentiallyMutated(chunkOp.getResult()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "AtenChunkOp result is potentially mutated");
|
||||
Value dim = chunk.getDim();
|
||||
Value input = chunk.getSelf();
|
||||
Value chunks = chunk.getChunks();
|
||||
Location loc = chunk.getLoc();
|
||||
Value dim = chunkOp.getDim();
|
||||
Value input = chunkOp.getSelf();
|
||||
Value chunks = chunkOp.getChunks();
|
||||
Location loc = chunkOp.getLoc();
|
||||
Value totalSize = rewriter.create<Torch::AtenSizeIntOp>(loc, input, dim);
|
||||
|
||||
// chunkSize = floordiv(totalSize + chunks - 1, chunks)
|
||||
Value chunkSize = getIntCeilDiv(rewriter, loc, totalSize, chunks);
|
||||
|
||||
// add runtime.assert to check chunks == NumResults
|
||||
Value cstNumResults = rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(op.getNumResults()));
|
||||
Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, chunks, cstNumResults);
|
||||
rewriter.create<RuntimeAssertOp>(
|
||||
loc, eqOrNot,
|
||||
rewriter.getStringAttr(
|
||||
"chunks should equal to prim.list_unpack's num results"));
|
||||
|
||||
Value cstOne =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
Value dividend = rewriter.create<AtenAddIntOp>(loc, totalSize, chunks);
|
||||
dividend = rewriter.create<AtenSubIntOp>(loc, dividend, cstOne);
|
||||
Value chunkSize = rewriter.create<AtenFloordivIntOp>(loc, dividend, chunks);
|
||||
|
||||
SmallVector<Value> slices;
|
||||
for (size_t i = 0; i < op.getNumResults(); i++) {
|
||||
// rewrite to slice op with
|
||||
|
@ -334,13 +408,13 @@ public:
|
|||
end = rewriter.create<AtenMulIntOp>(loc, nextIdx, chunkSize);
|
||||
}
|
||||
Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
|
||||
loc, resultTy, input, dim, start, end, cstOne);
|
||||
loc, resultTy, input, dim, start, end, /*step=*/cstOne);
|
||||
slices.push_back(sliceTensorOp);
|
||||
}
|
||||
rewriter.replaceOp(op, slices);
|
||||
// erase chunkOp if no user left
|
||||
if (chunk.getResult().use_empty())
|
||||
rewriter.eraseOp(chunk);
|
||||
if (chunkOp.getResult().use_empty())
|
||||
rewriter.eraseOp(chunkOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue