//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
// Recompose an AtenCopy_Op whose destination is produced by an
// AtenSliceTensorOp into a single Aten_IndexPutImplOp on the slice's source
// tensor, indexing the sliced dimension with an arange over the slice range.
class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenCopy_Op op,
                                PatternRewriter &rewriter) const override {
    if (!op.getSelf().getDefiningOp() ||
        !isa<AtenSliceTensorOp>(op.getSelf().getDefiningOp()))
      return failure();
    auto sliceOp = cast<AtenSliceTensorOp>(op.getSelf().getDefiningOp());

    // Get the constant dim and end indices of the slice.
    int64_t dim;
    if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim)))
      return failure();
    int64_t end;
    if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
      return failure();
    // Only handle non-negative dims; normalizing a negative dim would
    // require knowing the tensor rank.
    if (dim < 0)
      return failure();

    // Normalize a negative end index by adding the dimension size.
    Value newEnd = sliceOp.getEnd();
    if (end < 0) {
      Value dimSize = rewriter.create<AtenSizeIntOp>(
          op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
      newEnd =
          rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
    }

    Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
    Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);

    // Create an index tensor covering the sliced range.
    BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
    Value range = rewriter.create<AtenArangeStartStepOp>(
        op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(),
        /*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
        /*pin_memory=*/noneVal);

    // The i-th entry of the indices list indexes dimension i, so `dim`
    // leading None values are needed before the range tensor to target
    // dimension `dim`.
    SmallVector<Value> indicesVector(dim, noneVal);
    indicesVector.push_back(range);
    Value indices = rewriter.create<PrimListConstructOp>(
        op.getLoc(),
        Torch::ListType::get(op->getContext(),
                             Torch::OptionalType::get(tensorType)),
        indicesVector);

    rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
        op, op->getResultTypes(), sliceOp.getSelf(), indices, op.getSrc(),
        /*accumulate=*/falseVal, /*unsafe=*/falseVal);

    return success();
  }
};
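// Recompose an AtenFill_TensorOp whose destination is produced by an
// AtenSelectIntOp into a single Aten_IndexPutImplOp on the select's source
// tensor, indexing the selected dimension with the select index wrapped in a
// rank-0 tensor.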
class RecomposeSelectFill_ : public OpRewritePattern<AtenFill_TensorOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenFill_TensorOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getSelf().getDefiningOp() ||
        !isa<AtenSelectIntOp>(op.getSelf().getDefiningOp()))
      return failure();
    auto selectOp = cast<AtenSelectIntOp>(op.getSelf().getDefiningOp());

    // Get the constant dim index of the select.
    int64_t dim;
    if (!matchPattern(selectOp.getDim(), m_TorchConstantInt(&dim)))
      return failure();
    // Only handle non-negative dims; normalizing a negative dim would
    // require knowing the tensor rank.
    if (dim < 0)
      return failure();

    Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
    Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);

    // Wrap the scalar select index in a rank-0 tensor so it can be used as an
    // index for the index_put.
    BaseTensorType selectOutTy =
        selectOp.getType().template cast<BaseTensorType>();
    SmallVector<int64_t> empty;
    auto dtype = getTypeForTorchType(selectOp.getContext(),
                                     selectOp.getIndex().getType());
    Type emptyTensorType =
        selectOutTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
    Value indexTensor = rewriter.create<PrimNumToTensorScalarOp>(
        selectOp.getLoc(), emptyTensorType, selectOp.getIndex());

    // Build the indices list: `dim` leading None values, then the index
    // tensor, so that dimension `dim` is the one being written.
    BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
    SmallVector<Value> indicesVector(dim, noneVal);
    indicesVector.push_back(indexTensor);
    Value indices = rewriter.create<PrimListConstructOp>(
        op.getLoc(),
        Torch::ListType::get(op->getContext(),
                             Torch::OptionalType::get(tensorType)),
        indicesVector);

    rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
        op, op->getResultTypes(), selectOp.getSelf(), indices, op.getValue(),
        /*accumulate=*/falseVal, /*unsafe=*/falseVal);

    return success();
  }
};
} // namespace

namespace {
class RecomposeComplexOpsPass
    : public RecomposeComplexOpsBase<RecomposeComplexOpsPass> {
public:
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    // Add the recompose patterns.
    patterns.add<RecomposeSliceCopy_>(context);
    patterns.add<RecomposeSelectFill_>(context);

    GreedyRewriteConfig config;
    config.useTopDownTraversal = true;
    config.maxIterations = GreedyRewriteConfig::kNoLimit;

    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns), config))) {
      return signalPassFailure();
    }
  }
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createRecomposeComplexOpsPass() {
  return std::make_unique<RecomposeComplexOpsPass>();
}