mirror of https://github.com/llvm/torch-mlir
parent
8dcd0b2e76
commit
4df1d8ae2f
|
@ -12,6 +12,7 @@
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
|
@ -71,6 +72,55 @@ public:
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class RecomposeSelectFill_ : public OpRewritePattern<AtenFill_TensorOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenFill_TensorOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
if (!op.getSelf().getDefiningOp() ||
|
||||||
|
!isa<AtenSelectIntOp>(op.getSelf().getDefiningOp()))
|
||||||
|
return failure();
|
||||||
|
auto selectOp = cast<AtenSelectIntOp>(op.getSelf().getDefiningOp());
|
||||||
|
|
||||||
|
// Get indices
|
||||||
|
int64_t dim;
|
||||||
|
if (!matchPattern(selectOp.getDim(), m_TorchConstantInt(&dim)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||||
|
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
|
||||||
|
|
||||||
|
// Create IndexPut_Op
|
||||||
|
// Convert indexNum to indexTensor for the selectOp
|
||||||
|
BaseTensorType selectOutTy =
|
||||||
|
selectOp.getType().template cast<BaseTensorType>();
|
||||||
|
SmallVector<int64_t> empty;
|
||||||
|
auto dtype = getTypeForTorchType(selectOp.getContext(),
|
||||||
|
selectOp.getIndex().getType());
|
||||||
|
Type emptyTensorType =
|
||||||
|
selectOutTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
|
||||||
|
Value indexTensor = rewriter.create<PrimNumToTensorScalarOp>(
|
||||||
|
selectOp.getLoc(), emptyTensorType, selectOp.getIndex());
|
||||||
|
|
||||||
|
// Create indicesVector for IndexPut_Op by TorchNone and indexTensor
|
||||||
|
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
|
||||||
|
SmallVector<Value> indicesVector(dim - 1, noneVal);
|
||||||
|
indicesVector.push_back(indexTensor);
|
||||||
|
|
||||||
|
Value indices = rewriter.create<PrimListConstructOp>(
|
||||||
|
op.getLoc(),
|
||||||
|
Torch::ListType::get(op->getContext(),
|
||||||
|
Torch::OptionalType::get(tensorType)),
|
||||||
|
indicesVector);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
|
||||||
|
op, op->getResultTypes(), selectOp.getSelf(), indices, op.getValue(),
|
||||||
|
/*accumulate=*/falseVal, /*unsafe=*/falseVal);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -83,6 +133,7 @@ public:
|
||||||
|
|
||||||
// pattern.add calls go here
|
// pattern.add calls go here
|
||||||
patterns.add<RecomposeSliceCopy_>(context);
|
patterns.add<RecomposeSliceCopy_>(context);
|
||||||
|
patterns.add<RecomposeSelectFill_>(context);
|
||||||
|
|
||||||
GreedyRewriteConfig config;
|
GreedyRewriteConfig config;
|
||||||
config.useTopDownTraversal = true;
|
config.useTopDownTraversal = true;
|
||||||
|
|
Loading…
Reference in New Issue