mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] require dtype exists when decompose to aten.where.self (#2094)
* [Torch Dialect] require dtype exists when decompose to aten.where.self * updatepull/2133/head
parent
0302cf1d92
commit
e98f2ba04a
|
@ -129,20 +129,26 @@ static Value createTensorSub(PatternRewriter &rewriter, Location loc,
|
|||
// Helper to create a tensor filled with the given scalar. Scalar would be
|
||||
// converted the to the element type of the given tensor type.
|
||||
static Value createInitTensor(PatternRewriter &rewriter, Location loc,
|
||||
Type resultType, Value scalar, Value sizeList) {
|
||||
BaseTensorType resultType, Value scalar,
|
||||
Value sizeList) {
|
||||
assert(resultType.hasDtype() && "result must have dtype");
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
||||
return rewriter.create<AtenFullOp>(
|
||||
loc, resultType, sizeList, scalar, /*dtype=*/noneVal, /*layout=*/noneVal,
|
||||
/*device=*/noneVal, /*memory_format=*/noneVal);
|
||||
Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
|
||||
return rewriter.create<AtenFullOp>(loc, resultType, sizeList, scalar, dtype,
|
||||
/*layout=*/noneVal,
|
||||
/*device=*/noneVal,
|
||||
/*memory_format=*/noneVal);
|
||||
}
|
||||
|
||||
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
|
||||
// would be converted to the element type of the given `inputType`.
|
||||
static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
|
||||
BaseTensorType inputType, Value scalar) {
|
||||
assert(inputType.hasDtype() && "input must have dtype");
|
||||
SmallVector<int64_t> sizes;
|
||||
Type rank0TensorTy = inputType.getWithSizesAndDtype(
|
||||
ArrayRef(sizes), inputType.getOptionalDtype());
|
||||
BaseTensorType rank0TensorTy =
|
||||
inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype())
|
||||
.cast<BaseTensorType>();
|
||||
Value dimList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
|
||||
ValueRange{});
|
||||
|
@ -895,6 +901,10 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenRelu6Op op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
Value relu6 = getRelu6Results(rewriter, loc, op.getSelf());
|
||||
rewriter.replaceOp(op, relu6);
|
||||
return success();
|
||||
|
@ -944,6 +954,9 @@ public:
|
|||
Value input = op.getSelf();
|
||||
Value negativeSlope = op.getNegativeSlope();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
||||
Value constantZero =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
|
@ -978,6 +991,9 @@ public:
|
|||
Value input = op.getSelf();
|
||||
Value negativeSlope = op.getNegativeSlope();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
||||
bool selfIsResult = false;
|
||||
if (!matchPattern(op.getSelfIsResult(),
|
||||
|
@ -1372,6 +1388,9 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
|
||||
Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
|
||||
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
|
||||
|
@ -1391,6 +1410,9 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
|
||||
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
|
||||
op.getSelf(), otherTensor);
|
||||
|
@ -1409,6 +1431,9 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
|
||||
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
|
||||
selfTensor, op.getOther());
|
||||
|
@ -1427,6 +1452,9 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
Value mask = op.getMask();
|
||||
Value value = createRank0Tensor(rewriter, loc, resType, op.getValue());
|
||||
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask,
|
||||
|
@ -2236,6 +2264,10 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
||||
// outputTensor = (input + 3) / 6.
|
||||
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
|
@ -2273,6 +2305,10 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
if (!resType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
|
||||
// result = min(maxVal, max(minVal, x))
|
||||
Value minVal = createRank0Tensor(rewriter, loc, inputType, op.getMinVal());
|
||||
|
|
Loading…
Reference in New Issue