[MLIR] Fold aten select and fill_ pattern (#2000)

pull/2003/head snapshot-20230407.801
Chi_Liu 2023-04-06 21:16:51 -07:00 committed by GitHub
parent 8dcd0b2e76
commit 4df1d8ae2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 51 additions and 0 deletions

View File

@ -12,6 +12,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
using namespace mlir;
using namespace mlir::torch;
@ -71,6 +72,55 @@ public:
return success();
}
};
class RecomposeSelectFill_ : public OpRewritePattern<AtenFill_TensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenFill_TensorOp op,
PatternRewriter &rewriter) const override {
if (!op.getSelf().getDefiningOp() ||
!isa<AtenSelectIntOp>(op.getSelf().getDefiningOp()))
return failure();
auto selectOp = cast<AtenSelectIntOp>(op.getSelf().getDefiningOp());
// Get indices
int64_t dim;
if (!matchPattern(selectOp.getDim(), m_TorchConstantInt(&dim)))
return failure();
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
// Create IndexPut_Op
// Convert indexNum to indexTensor for the selectOp
BaseTensorType selectOutTy =
selectOp.getType().template cast<BaseTensorType>();
SmallVector<int64_t> empty;
auto dtype = getTypeForTorchType(selectOp.getContext(),
selectOp.getIndex().getType());
Type emptyTensorType =
selectOutTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
Value indexTensor = rewriter.create<PrimNumToTensorScalarOp>(
selectOp.getLoc(), emptyTensorType, selectOp.getIndex());
// Create indicesVector for IndexPut_Op by TorchNone and indexTensor
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
SmallVector<Value> indicesVector(dim - 1, noneVal);
indicesVector.push_back(indexTensor);
Value indices = rewriter.create<PrimListConstructOp>(
op.getLoc(),
Torch::ListType::get(op->getContext(),
Torch::OptionalType::get(tensorType)),
indicesVector);
rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
op, op->getResultTypes(), selectOp.getSelf(), indices, op.getValue(),
/*accumulate=*/falseVal, /*unsafe=*/falseVal);
return success();
}
};
} // namespace
namespace {
@ -83,6 +133,7 @@ public:
// pattern.add calls go here
patterns.add<RecomposeSliceCopy_>(context);
patterns.add<RecomposeSelectFill_>(context);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;