diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp index 1e3875a3a..dbddcc312 100644 --- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; @@ -71,6 +72,55 @@ public: return success(); } }; + +class RecomposeSelectFill_ : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenFill_TensorOp op, + PatternRewriter &rewriter) const override { + if (!op.getSelf().getDefiningOp() || + !isa(op.getSelf().getDefiningOp())) + return failure(); + auto selectOp = cast(op.getSelf().getDefiningOp()); + + // Get indices + int64_t dim; + if (!matchPattern(selectOp.getDim(), m_TorchConstantInt(&dim))) + return failure(); + + Value noneVal = rewriter.create(op.getLoc()); + Value falseVal = rewriter.create(op.getLoc(), false); + + // Create IndexPut_Op + // Convert indexNum to indexTensor for the selectOp + BaseTensorType selectOutTy = + selectOp.getType().template cast(); + SmallVector empty; + auto dtype = getTypeForTorchType(selectOp.getContext(), + selectOp.getIndex().getType()); + Type emptyTensorType = + selectOutTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); + Value indexTensor = rewriter.create( + selectOp.getLoc(), emptyTensorType, selectOp.getIndex()); + + // Create indicesVector for IndexPut_Op by TorchNone and indexTensor + BaseTensorType tensorType = op->getResultTypes()[0].cast(); + SmallVector indicesVector(dim - 1, noneVal); + indicesVector.push_back(indexTensor); + + Value indices = rewriter.create( + op.getLoc(), + Torch::ListType::get(op->getContext(), + Torch::OptionalType::get(tensorType)), + indicesVector); + + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), selectOp.getSelf(), indices, op.getValue(), + /*accumulate=*/falseVal, /*unsafe=*/falseVal); + + return success(); + } +}; } // namespace namespace { @@ -83,6 +133,7 @@ public: // pattern.add calls go here patterns.add(context); + patterns.add(context); GreedyRewriteConfig config; config.useTopDownTraversal = true;