[Torch Dialect] require dtype exists when decompose to aten.where.self (#2094)

* [Torch Dialect] require dtype exists when decompose to aten.where.self

* update
pull/2133/head
Yuanqiang Liu 2023-05-18 00:04:26 +08:00 committed by GitHub
parent 0302cf1d92
commit e98f2ba04a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 42 additions and 6 deletions

View File

@ -129,20 +129,26 @@ static Value createTensorSub(PatternRewriter &rewriter, Location loc,
// Helper to create a tensor filled with the given scalar. Scalar would be
// converted the to the element type of the given tensor type.
static Value createInitTensor(PatternRewriter &rewriter, Location loc,
Type resultType, Value scalar, Value sizeList) {
BaseTensorType resultType, Value scalar,
Value sizeList) {
assert(resultType.hasDtype() && "result must have dtype");
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
return rewriter.create<AtenFullOp>(
loc, resultType, sizeList, scalar, /*dtype=*/noneVal, /*layout=*/noneVal,
/*device=*/noneVal, /*memory_format=*/noneVal);
Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
return rewriter.create<AtenFullOp>(loc, resultType, sizeList, scalar, dtype,
/*layout=*/noneVal,
/*device=*/noneVal,
/*memory_format=*/noneVal);
}
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
// would be converted to the element type of the given `inputType`.
static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
BaseTensorType inputType, Value scalar) {
assert(inputType.hasDtype() && "input must have dtype");
SmallVector<int64_t> sizes;
Type rank0TensorTy = inputType.getWithSizesAndDtype(
ArrayRef(sizes), inputType.getOptionalDtype());
BaseTensorType rank0TensorTy =
inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype())
.cast<BaseTensorType>();
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
ValueRange{});
@ -895,6 +901,10 @@ public:
LogicalResult matchAndRewrite(AtenRelu6Op op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto resType = op.getType().cast<BaseTensorType>();
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value relu6 = getRelu6Results(rewriter, loc, op.getSelf());
rewriter.replaceOp(op, relu6);
return success();
@ -944,6 +954,9 @@ public:
Value input = op.getSelf();
Value negativeSlope = op.getNegativeSlope();
auto resType = op.getType().cast<BaseTensorType>();
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value constantZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
@ -978,6 +991,9 @@ public:
Value input = op.getSelf();
Value negativeSlope = op.getNegativeSlope();
auto resType = op.getType().cast<BaseTensorType>();
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
bool selfIsResult = false;
if (!matchPattern(op.getSelfIsResult(),
@ -1372,6 +1388,9 @@ public:
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto resType = op.getType().cast<BaseTensorType>();
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
@ -1391,6 +1410,9 @@ public:
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto resType = op.getType().cast<BaseTensorType>();
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
op.getSelf(), otherTensor);
@ -1409,6 +1431,9 @@ public:
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto resType = op.getType().cast<BaseTensorType>();
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
selfTensor, op.getOther());
@ -1427,6 +1452,9 @@ public:
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto resType = op.getType().cast<BaseTensorType>();
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value mask = op.getMask();
Value value = createRank0Tensor(rewriter, loc, resType, op.getValue());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask,
@ -2236,6 +2264,10 @@ public:
Location loc = op.getLoc();
Value input = op.getSelf();
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
auto resType = op.getType().cast<BaseTensorType>();
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
// outputTensor = (input + 3) / 6.
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
@ -2273,6 +2305,10 @@ public:
Location loc = op.getLoc();
Value input = op.getSelf();
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
auto resType = op.getType().cast<BaseTensorType>();
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
// result = min(maxVal, max(minVal, x))
Value minVal = createRank0Tensor(rewriter, loc, inputType, op.getMinVal());