mirror of https://github.com/llvm/torch-mlir
parent
8dcd0b2e76
commit
4df1d8ae2f
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
|
@ -71,6 +72,55 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class RecomposeSelectFill_ : public OpRewritePattern<AtenFill_TensorOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenFill_TensorOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!op.getSelf().getDefiningOp() ||
|
||||
!isa<AtenSelectIntOp>(op.getSelf().getDefiningOp()))
|
||||
return failure();
|
||||
auto selectOp = cast<AtenSelectIntOp>(op.getSelf().getDefiningOp());
|
||||
|
||||
// Get indices
|
||||
int64_t dim;
|
||||
if (!matchPattern(selectOp.getDim(), m_TorchConstantInt(&dim)))
|
||||
return failure();
|
||||
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
|
||||
|
||||
// Create IndexPut_Op
|
||||
// Convert indexNum to indexTensor for the selectOp
|
||||
BaseTensorType selectOutTy =
|
||||
selectOp.getType().template cast<BaseTensorType>();
|
||||
SmallVector<int64_t> empty;
|
||||
auto dtype = getTypeForTorchType(selectOp.getContext(),
|
||||
selectOp.getIndex().getType());
|
||||
Type emptyTensorType =
|
||||
selectOutTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
|
||||
Value indexTensor = rewriter.create<PrimNumToTensorScalarOp>(
|
||||
selectOp.getLoc(), emptyTensorType, selectOp.getIndex());
|
||||
|
||||
// Create indicesVector for IndexPut_Op by TorchNone and indexTensor
|
||||
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
|
||||
SmallVector<Value> indicesVector(dim - 1, noneVal);
|
||||
indicesVector.push_back(indexTensor);
|
||||
|
||||
Value indices = rewriter.create<PrimListConstructOp>(
|
||||
op.getLoc(),
|
||||
Torch::ListType::get(op->getContext(),
|
||||
Torch::OptionalType::get(tensorType)),
|
||||
indicesVector);
|
||||
|
||||
rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
|
||||
op, op->getResultTypes(), selectOp.getSelf(), indices, op.getValue(),
|
||||
/*accumulate=*/falseVal, /*unsafe=*/falseVal);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
@ -83,6 +133,7 @@ public:
|
|||
|
||||
// pattern.add calls go here
|
||||
patterns.add<RecomposeSliceCopy_>(context);
|
||||
patterns.add<RecomposeSelectFill_>(context);
|
||||
|
||||
GreedyRewriteConfig config;
|
||||
config.useTopDownTraversal = true;
|
||||
|
|
Loading…
Reference in New Issue