//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { // calculate: (a + b - 1) // b // a/b's type should be !torch.int Value getIntCeilDiv(PatternRewriter &rewriter, Location loc, Value a, Value b) { Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value dividend = rewriter.create(loc, a, b); dividend = rewriter.create(loc, dividend, cstOne); Value result = rewriter.create(loc, dividend, b); return result; } } // namespace namespace { class RecomposeSliceCopy_ : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenCopy_Op op, PatternRewriter &rewriter) const override { // This pattern replaces the in-place mutation of a slice of a tensor with // an `index_put` op. Since the slice of the tensor can have a different // shape than the full tensor, this pattern requires the `copy_` op to not // have users to avoid mismached types. This restriction can be removed by // inserting another slice after the `index_put` that creates a tensor of // the same shape as the operand to `copy_`. if (!op.use_empty()) return rewriter.notifyMatchFailure( op, "`AtenCopy_Op` must not have any users"); if (!op.getSelf().getDefiningOp() || !isa(op.getSelf().getDefiningOp())) return rewriter.notifyMatchFailure( op, "defining op is not `AtenSliceTensorOp`"); auto sliceOp = cast(op.getSelf().getDefiningOp()); // Get indices int64_t dim; if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim))) return failure(); int64_t end; if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end))) return failure(); Value newStart = sliceOp.getStart(); Value newEnd = sliceOp.getEnd(); Value dimSize = rewriter.create( op.getLoc(), sliceOp.getSelf(), sliceOp.getDim()); if (end < 0) { newEnd = rewriter.create(op.getLoc(), dimSize, sliceOp.getEnd()); } newStart = rewriter.create(op.getLoc(), newStart, dimSize); newEnd = rewriter.create(op.getLoc(), newEnd, dimSize); Value noneVal = rewriter.create(op.getLoc()); Value falseVal = rewriter.create(op.getLoc(), false); // Create IndexPut_Op BaseTensorType tensorType = cast(op.getType()); Type rangeType = tensorType.getWithSizesAndDtype( {kUnknownSize}, tensorType.getOptionalDtype()); Value range = rewriter.create( op.getLoc(), rangeType, newStart, newEnd, sliceOp.getStep(), /*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); SmallVector indicesVector; for (auto i = 0; i < dim; i++) indicesVector.push_back(noneVal); indicesVector.push_back(range); Type indicesType = tensorType.getWithSizesAndDtype( /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Value indices = rewriter.create( op.getLoc(), Torch::ListType::get(op->getContext(), Torch::OptionalType::get(indicesType)), indicesVector); Value sliceOpInput = sliceOp.getSelf(); rewriter.replaceOpWithNewOp( op, sliceOpInput.getType(), sliceOpInput, indices, op.getSrc(), /*accumulate=*/falseVal, /*unsafe=*/falseVal); if (sliceOp->use_empty()) rewriter.eraseOp(sliceOp); return success(); } }; class RecomposeSelectFill_ : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFill_TensorOp op, PatternRewriter &rewriter) const override { if (!op.getSelf().getDefiningOp() || !isa(op.getSelf().getDefiningOp())) return failure(); auto selectOp = cast(op.getSelf().getDefiningOp()); // Get indices int64_t dim; if (!matchPattern(selectOp.getDim(), m_TorchConstantInt(&dim))) return failure(); Value noneVal = rewriter.create(op.getLoc()); Value falseVal = rewriter.create(op.getLoc(), false); // Create IndexPut_Op // Convert indexNum to indexTensor for the selectOp BaseTensorType selectOutTy = cast(selectOp.getType()); SmallVector empty; auto dtype = getTypeForTorchType(selectOp.getContext(), selectOp.getIndex().getType()); Type emptyTensorType = selectOutTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); Value indexTensor = rewriter.create( selectOp.getLoc(), emptyTensorType, selectOp.getIndex()); // Create indicesVector for IndexPut_Op by TorchNone and indexTensor BaseTensorType tensorType = cast(op->getResultTypes()[0]); SmallVector indicesVector(dim, noneVal); indicesVector.push_back(indexTensor); Value indices = rewriter.create( op.getLoc(), Torch::ListType::get(op->getContext(), Torch::OptionalType::get(tensorType)), indicesVector); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), selectOp.getSelf(), indices, op.getValue(), /*accumulate=*/falseVal, /*unsafe=*/falseVal); return success(); } }; class RecomposeUnbindListUnpack : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PrimListUnpackOp op, PatternRewriter &rewriter) const override { // recompose AtenUnbindOp + PrimListUnpackOp to select.int auto unbindOp = dyn_cast(op.getOperand().getDefiningOp()); if (!unbindOp) return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp"); if (isListPotentiallyMutated(unbindOp.getResult())) return rewriter.notifyMatchFailure( op, "AtenUnbindIntOp result is potentially mutated"); Location loc = op.getLoc(); Value dim = unbindOp.getDim(); Value input = unbindOp.getSelf(); // add runtime.assert to check unbind's dim size == numResults Value totalSize = rewriter.create(loc, input, dim); Value cstNumResults = rewriter.create( loc, rewriter.getI64IntegerAttr(op.getNumResults())); Value eqOrNot = rewriter.create(loc, totalSize, cstNumResults); rewriter.create( loc, eqOrNot, rewriter.getStringAttr("unbind's dim size should equal to " "prim.list_unpack's num results")); SmallVector slices; for (size_t i = 0; i < op.getNumResults(); i++) { // rewrite to select.int op auto resultTy = op.getResult(i).getType(); auto index = rewriter.create( op->getLoc(), rewriter.getI64IntegerAttr(i)); auto newSelect = rewriter.create(op->getLoc(), resultTy, input, dim, index); slices.push_back(newSelect); } rewriter.replaceOp(op, slices); if (unbindOp.getResult().use_empty()) rewriter.eraseOp(unbindOp); return success(); } }; class RecomposeUnbindGetItem : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten__Getitem__TOp op, PatternRewriter &rewriter) const override { // recompose AtenUnbindIntOp + __getitem__t to select.int auto unbind = dyn_cast(op.getList().getDefiningOp()); if (!unbind) return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp"); if (isListPotentiallyMutated(unbind.getResult())) return rewriter.notifyMatchFailure( op, "AtenUnbindIntOp result is potentially mutated"); int64_t index; if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index))) return rewriter.notifyMatchFailure( op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int"); if (index < 0) return rewriter.notifyMatchFailure( op, "Expected `idx` of `Aten__Getitem__TOp` to be a positive int"); Location loc = op.getLoc(); Value dim = unbind.getDim(); Value input = unbind.getSelf(); // add runtime.assert to check: index Value totalSize = rewriter.create(loc, input, dim); Value ltOrNot = rewriter.create(loc, op.getIdx(), totalSize); rewriter.create( loc, ltOrNot, rewriter.getStringAttr("index should less than unbind's dim size")); // rewrite to slice op auto resultTy = op.getResult().getType(); Value newSelect = rewriter.create(loc, resultTy, input, dim, op.getIdx()); rewriter.replaceOp(op, newSelect); if (unbind.getResult().use_empty()) rewriter.eraseOp(unbind); return success(); } }; class RecomposeSplitTensorGetItemOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten__Getitem__TOp op, PatternRewriter &rewriter) const override { // recompose AtenSplitTensorOp + __getitem__t to AtenSliceTensorOp auto splitTensorOp = dyn_cast(op.getList().getDefiningOp()); if (!splitTensorOp) return rewriter.notifyMatchFailure(op, "Input is not AtenSplitTensorOp"); if (isListPotentiallyMutated(splitTensorOp.getResult())) return rewriter.notifyMatchFailure( op, "SplitTensorOp result is potentially mutated"); int64_t index; if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index))) return rewriter.notifyMatchFailure( op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int"); if (index < 0) return rewriter.notifyMatchFailure( op, "Expected `idx` of `Aten__Getitem__TOp` to be a positive int"); int64_t splitSize; if (!matchPattern(splitTensorOp.getSplitSize(), m_TorchConstantInt(&splitSize))) return rewriter.notifyMatchFailure( op, "Expected `SplitSize` of `AtenSplitTensorOp` to be a constant int"); Location loc = op.getLoc(); Value input = splitTensorOp.getSelf(); Value dim = splitTensorOp.getDim(); // add runtime.assert to check rank constraint: index < split_result_size Value totalSize = rewriter.create(loc, input, dim); Value splitResultSize = getIntCeilDiv(rewriter, loc, totalSize, splitTensorOp.getSplitSize()); Value ltOrNot = rewriter.create(loc, op.getIdx(), splitResultSize); rewriter.create( loc, ltOrNot, rewriter.getStringAttr("index should less than split_result_size")); Value step = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value start = rewriter.create( loc, rewriter.getI64IntegerAttr(index * splitSize)); Value end = rewriter.create( loc, rewriter.getI64IntegerAttr(index * splitSize + splitSize)); Value sliceTensorOp = rewriter.create( loc, op.getResult().getType(), input, dim, start, end, step); rewriter.replaceOp(op, sliceTensorOp); if (splitTensorOp.getResult().use_empty()) rewriter.eraseOp(splitTensorOp); return success(); } }; class RecomposeSplitTensorListUnpack : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PrimListUnpackOp op, PatternRewriter &rewriter) const override { // recompose AtenSplitTensorOp + PrimListUnpackOp to AtenSliceTensorOps auto splitTensorOp = dyn_cast(op.getOperand().getDefiningOp()); if (!splitTensorOp) return rewriter.notifyMatchFailure(op, "Input is not AtenSplitTensorOp"); if (isListPotentiallyMutated(splitTensorOp.getResult())) return rewriter.notifyMatchFailure( op, "SplitTensorOp result is potentially mutated"); int64_t splitSize; if (!matchPattern(splitTensorOp.getSplitSize(), m_TorchConstantInt(&splitSize))) return rewriter.notifyMatchFailure( op, "Expected `SplitSize` of `AtenSplitTensorOp` to be a constant int"); Location loc = op.getLoc(); Value input = splitTensorOp.getSelf(); Value dim = splitTensorOp.getDim(); // add runtime.assert to check rank constraint Value totalSize = rewriter.create(loc, input, dim); Value cstNumResults = rewriter.create( loc, rewriter.getI64IntegerAttr(op.getNumResults())); Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); // assert: numResults == floordiv(totalSize + splitSize - 1, splitSize) Value splitResultSize = getIntCeilDiv(rewriter, loc, totalSize, splitTensorOp.getSplitSize()); Value eqOrNot = rewriter.create(loc, splitResultSize, cstNumResults); rewriter.create( loc, eqOrNot, rewriter.getStringAttr("numResults should equal to floordiv(totalSize " "+ splitSize - 1, splitSize)")); SmallVector slices; for (size_t i = 0; i < op.getNumResults(); i++) { auto resultTy = op.getResult(i).getType(); auto start = rewriter.create( loc, rewriter.getI64IntegerAttr(i * splitSize)); auto end = rewriter.create( loc, rewriter.getI64IntegerAttr((i + 1) * splitSize)); Value sliceTensorOp = rewriter.create( loc, resultTy, input, dim, start, end, /*step=*/cstOne); slices.push_back(sliceTensorOp); } rewriter.replaceOp(op, slices); // erase splitTensorOp if no user left if (splitTensorOp.getResult().use_empty()) rewriter.eraseOp(splitTensorOp); return success(); } }; class RecomposeSplitWithSizesListUnpack : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PrimListUnpackOp op, PatternRewriter &rewriter) const override { // recompose AtenSplitWithSizesOp + PrimListUnpackOp to AtenSliceTensorOps auto splitOp = dyn_cast(op.getOperand().getDefiningOp()); if (!splitOp) { return rewriter.notifyMatchFailure(op, "Input is not AtenSplitWithSizesOp"); } if (isListPotentiallyMutated(splitOp.getResult())) { return rewriter.notifyMatchFailure( op, "splitWithSizesOp result is potentially mutated"); } if (isListPotentiallyMutated(splitOp.getSplitSizes())) { return rewriter.notifyMatchFailure( op, "splitWithSizesOp's split_sizes is potentially mutated"); } auto splitSizesConstruct = splitOp.getSplitSizes().getDefiningOp(); if (!splitSizesConstruct) { return rewriter.notifyMatchFailure( op, "split_sizes is not from PrimListConstructOp"); } int64_t sumSplitSize = 0; SmallVector splitSizes; for (auto operand : splitSizesConstruct.getOperands()) { int64_t value = -1; // TODO: support when split_sizes are not constant int if (!matchPattern(operand, m_TorchConstantInt(&value))) { return rewriter.notifyMatchFailure( op, "one of split_sizes is not constant int"); } if (value < 0) { return rewriter.notifyMatchFailure(op, "all of split_sizes must > 0"); } sumSplitSize += value; splitSizes.push_back(value); } if (splitSizes.size() != op.getNumResults()) { return rewriter.notifyMatchFailure( op, "split_sizes must be same as splitOp result size"); } Location loc = op.getLoc(); Value input = splitOp.getSelf(); Value dim = splitOp.getDim(); // add runtime.assert to check rank constraint Value totalSize = rewriter.create(loc, input, dim); Value cstSumSplitSize = rewriter.create( loc, rewriter.getI64IntegerAttr(sumSplitSize)); Value eqOrNot = rewriter.create(loc, totalSize, cstSumSplitSize); rewriter.create( loc, eqOrNot, rewriter.getStringAttr("split dim must be sum of split_sizes")); // calculate slice op's lower bound and up bound SmallVector boundaryOfSliceOp(splitSizes.size() + 1, 0); for (size_t i = 1; i < boundaryOfSliceOp.size(); i++) { boundaryOfSliceOp[i] = boundaryOfSliceOp[i - 1] + splitSizes[i - 1]; } SmallVector slices; Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); for (size_t i = 0; i < op.getNumResults(); i++) { auto resultTy = op.getResult(i).getType(); auto start = rewriter.create( loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[i])); auto end = rewriter.create( loc, rewriter.getI64IntegerAttr((boundaryOfSliceOp[i + 1]))); Value sliceTensorOp = rewriter.create( loc, resultTy, input, dim, start, end, /*step=*/cstOne); slices.push_back(sliceTensorOp); } rewriter.replaceOp(op, slices); // erase splitOp if no user left if (splitOp.getResult().use_empty()) rewriter.eraseOp(splitOp); return success(); } }; class RecomposeChunkListUnpack : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PrimListUnpackOp op, PatternRewriter &rewriter) const override { // recompose AtenChunkOp + PrimListUnpackOp to AtenSliceTensorOps auto chunkOp = dyn_cast(op.getOperand().getDefiningOp()); if (!chunkOp) return rewriter.notifyMatchFailure(op, "Input is not AtenChunkOp"); if (isListPotentiallyMutated(chunkOp.getResult())) return rewriter.notifyMatchFailure( op, "AtenChunkOp result is potentially mutated"); Value dim = chunkOp.getDim(); Value input = chunkOp.getSelf(); Value chunks = chunkOp.getChunks(); Location loc = chunkOp.getLoc(); Value totalSize = rewriter.create(loc, input, dim); // chunkSize = floordiv(totalSize + chunks - 1, chunks) Value chunkSize = getIntCeilDiv(rewriter, loc, totalSize, chunks); // add runtime.assert to check chunks == NumResults Value cstNumResults = rewriter.create( loc, rewriter.getI64IntegerAttr(op.getNumResults())); Value eqOrNot = rewriter.create(loc, chunks, cstNumResults); rewriter.create( loc, eqOrNot, rewriter.getStringAttr( "chunks should equal to prim.list_unpack's num results")); Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); SmallVector slices; for (size_t i = 0; i < op.getNumResults(); i++) { // rewrite to slice op with // start = chunkSize * i, // end = lastIndex ? totalSize : chunkSize * (i+1) auto resultTy = op.getResult(i).getType(); auto index = rewriter.create( op->getLoc(), rewriter.getI64IntegerAttr(i)); auto start = rewriter.create(loc, index, chunkSize); Value end; if (i == op.getNumResults() - 1) { end = totalSize; } else { auto nextIdx = rewriter.create(loc, index, cstOne); end = rewriter.create(loc, nextIdx, chunkSize); } Value sliceTensorOp = rewriter.create( loc, resultTy, input, dim, start, end, /*step=*/cstOne); slices.push_back(sliceTensorOp); } rewriter.replaceOp(op, slices); // erase chunkOp if no user left if (chunkOp.getResult().use_empty()) rewriter.eraseOp(chunkOp); return success(); } }; } // namespace namespace { class RecomposeComplexOpsPass : public RecomposeComplexOpsBase { public: void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); // pattern.add calls go here patterns.add(context); patterns.add(context); patterns.add(context); patterns.add(context); patterns.add(context); patterns.add(context); patterns.add(context); patterns.add(context); GreedyRewriteConfig config; config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config))) { return signalPassFailure(); } } }; } // namespace std::unique_ptr> mlir::torch::Torch::createRecomposeComplexOpsPass() { return std::make_unique(); }