mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] require dtype exists when decompose to aten.where.self (#2094)
* [Torch Dialect] require dtype exists when decompose to aten.where.self
* update

branch: pull/2133/head
parent: 0302cf1d92
commit: e98f2ba04a
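Summary of the change: every decomposition in this file that lowers a scalar-carrying op to aten.where.self (where.Scalar, where.ScalarOther, where.ScalarSelf, masked_fill.Scalar), plus relu6, leaky_relu, leaky_relu_backward, hardsigmoid, and hardtanh, now refuses to match when the result tensor type carries no dtype, because the rank-0 constant tensors these decompositions create need a concrete element type. A minimal sketch of the recurring guard, using the names from the diff below:

    // Guard added near the top of each matchAndRewrite: bail out of the
    // decomposition when the result type does not carry a dtype.
    auto resType = op.getType().cast<BaseTensorType>();
    if (!resType.hasDtype()) {
      // notifyMatchFailure leaves the op untouched and records the reason
      // for pattern-application diagnostics.
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    }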
@@ -129,20 +129,26 @@ static Value createTensorSub(PatternRewriter &rewriter, Location loc,
 // Helper to create a tensor filled with the given scalar. Scalar would be
 // converted the to the element type of the given tensor type.
 static Value createInitTensor(PatternRewriter &rewriter, Location loc,
-                              Type resultType, Value scalar, Value sizeList) {
+                              BaseTensorType resultType, Value scalar,
+                              Value sizeList) {
+  assert(resultType.hasDtype() && "result must have dtype");
   Value noneVal = rewriter.create<ConstantNoneOp>(loc);
-  return rewriter.create<AtenFullOp>(
-      loc, resultType, sizeList, scalar, /*dtype=*/noneVal, /*layout=*/noneVal,
-      /*device=*/noneVal, /*memory_format=*/noneVal);
+  Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
+  return rewriter.create<AtenFullOp>(loc, resultType, sizeList, scalar, dtype,
+                                     /*layout=*/noneVal,
+                                     /*device=*/noneVal,
+                                     /*memory_format=*/noneVal);
 }
 
 // Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
 // would be converted to the element type of the given `inputType`.
 static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
                                BaseTensorType inputType, Value scalar) {
+  assert(inputType.hasDtype() && "input must have dtype");
   SmallVector<int64_t> sizes;
-  Type rank0TensorTy = inputType.getWithSizesAndDtype(
-      ArrayRef(sizes), inputType.getOptionalDtype());
+  BaseTensorType rank0TensorTy =
+      inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype())
+          .cast<BaseTensorType>();
   Value dimList = rewriter.create<PrimListConstructOp>(
       loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
       ValueRange{});
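createInitTensor now requires a BaseTensorType with a known dtype and passes that dtype explicitly to aten.full instead of /*dtype=*/none, so the created tensor can no longer default to an unintended element type; createRank0Tensor gains the matching assertion. A hypothetical call site under the new contract (the surrounding names are illustrative, not from this diff):

    // Assumes `resType` is a Torch tensor type whose dtype is known.
    BaseTensorType tensorType = resType.cast<BaseTensorType>();
    // The helper asserts tensorType.hasDtype() and forwards the dtype
    // to the aten.full op it builds.
    Value filled = createInitTensor(rewriter, loc, tensorType, scalar, sizeList);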
@@ -895,6 +901,10 @@ public:
   LogicalResult matchAndRewrite(AtenRelu6Op op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
+    auto resType = op.getType().cast<BaseTensorType>();
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
     Value relu6 = getRelu6Results(rewriter, loc, op.getSelf());
     rewriter.replaceOp(op, relu6);
     return success();
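Returning notifyMatchFailure here, rather than asserting inside the helpers, means an un-typed aten.relu6 simply stays in Torch dialect form instead of crashing the pass. A sketch of that behavior under MLIR's standard greedy driver (DecomposeAtenRelu6Op is the presumed pattern class name, not spelled out in this diff):

    RewritePatternSet patterns(context);
    patterns.add<DecomposeAtenRelu6Op>(context); // presumed pattern name
    // Ops whose pattern returns failure() are left unrewritten; the
    // recorded match-failure message is visible with pattern debugging.
    (void)applyPatternsAndFoldGreedily(module, std::move(patterns));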
@@ -944,6 +954,9 @@ public:
     Value input = op.getSelf();
     Value negativeSlope = op.getNegativeSlope();
     auto resType = op.getType().cast<BaseTensorType>();
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
 
     Value constantZero =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
@@ -978,6 +991,9 @@ public:
     Value input = op.getSelf();
     Value negativeSlope = op.getNegativeSlope();
     auto resType = op.getType().cast<BaseTensorType>();
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
 
     bool selfIsResult = false;
     if (!matchPattern(op.getSelfIsResult(),
@@ -1372,6 +1388,9 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     auto resType = op.getType().cast<BaseTensorType>();
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
     Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
     Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
     rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
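The reason the guard is needed is visible in the rewrite itself: aten.where.Scalar materializes both scalar operands as rank-0 tensors of the result dtype before emitting the tensor-tensor select, and createRank0Tensor now asserts that the dtype exists. The shape of the decomposition, in where.self terms:

    // result[i] = condition[i] ? self : other, with both scalars first
    // materialized as rank-0 tensors of resType's dtype.
    Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
    Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
                                                 selfTensor, otherTensor);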
@@ -1391,6 +1410,9 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     auto resType = op.getType().cast<BaseTensorType>();
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
     Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
     rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
                                                  op.getSelf(), otherTensor);
@@ -1409,6 +1431,9 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     auto resType = op.getType().cast<BaseTensorType>();
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
     Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
     rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
                                                  selfTensor, op.getOther());
@@ -1427,6 +1452,9 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     auto resType = op.getType().cast<BaseTensorType>();
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
     Value mask = op.getMask();
     Value value = createRank0Tensor(rewriter, loc, resType, op.getValue());
     rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask,
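aten.masked_fill.Scalar takes the same route: the fill value becomes a rank-0 tensor and the op turns into a select on the mask. A sketch of what the pattern emits (the hunk above truncates before the final operands, so their order here is inferred from masked_fill semantics):

    // result[i] = mask[i] ? value : self[i]
    Value value = createRank0Tensor(rewriter, loc, resType, op.getValue());
    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask,
                                                 value, op.getSelf());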
@@ -2236,6 +2264,10 @@ public:
     Location loc = op.getLoc();
     Value input = op.getSelf();
     BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    auto resType = op.getType().cast<BaseTensorType>();
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
 
     // outputTensor = (input + 3) / 6.
     Value constantOne = rewriter.create<Torch::ConstantIntOp>(
@@ -2273,6 +2305,10 @@ public:
     Location loc = op.getLoc();
     Value input = op.getSelf();
     BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    auto resType = op.getType().cast<BaseTensorType>();
+    if (!resType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
 
     // result = min(maxVal, max(minVal, x))
     Value minVal = createRank0Tensor(rewriter, loc, inputType, op.getMinVal());
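For hardsigmoid and hardtanh the rank-0 tensors are built from inputType rather than resType, but the added resType check keeps the failure mode uniform across all the touched patterns. The hardtanh decomposition then clamps in two steps; a sketch of the remainder (op.getMaxVal and the aten.maximum/aten.minimum steps are inferred from the code comment, not shown in this hunk):

    // result = min(maxVal, max(minVal, x))
    Value minVal = createRank0Tensor(rewriter, loc, inputType, op.getMinVal());
    Value maxVal = createRank0Tensor(rewriter, loc, inputType, op.getMaxVal());
    // ...then emit aten.maximum(input, minVal) followed by
    // aten.minimum(..., maxVal) to produce the clamped result.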