[Torch] Fix bug of DecomposeAtenSelectIntOp (#3087)

Fix bug of DecomposeAtenSelectIntOp. Because it may use resultTy when
resultTy has not been inferred.

```
    auto resultTy = op.getType().cast<BaseTensorType>();
    if (sliceTy.getSizes().size() == resultTy.getSizes().size()) {
      rewriter.replaceOp(op, slice);
      return success();
    }

```

So I add restriction.
pull/3098/head
Xinyu Yang 2024-04-01 21:25:02 +08:00 committed by GitHub
parent 76080936d4
commit da88efad89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 6 additions and 1 deletions

View File

@ -764,6 +764,12 @@ public:
Value dim = op.getDim(); Value dim = op.getDim();
Value self = op.getSelf(); Value self = op.getSelf();
auto resultTy = op.getType().cast<BaseTensorType>();
if (!resultTy.hasSizes() || !resultTy.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected result type to have sizes and dtype");
}
// convert `start` to non-negative: start += int(start < 0) * dimSize // convert `start` to non-negative: start += int(start < 0) * dimSize
Value zero = Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0)); rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
@ -785,7 +791,6 @@ public:
op.getSelf(), dim, start, startPlusOne, /*step=*/one); op.getSelf(), dim, start, startPlusOne, /*step=*/one);
auto sliceTy = cast<BaseTensorType>(slice.getType()); auto sliceTy = cast<BaseTensorType>(slice.getType());
auto resultTy = cast<BaseTensorType>(op.getResult().getType());
if (sliceTy.getSizes().size() == resultTy.getSizes().size()) { if (sliceTy.getSizes().size() == resultTy.getSizes().size()) {
rewriter.replaceOp(op, slice); rewriter.replaceOp(op, slice);
return success(); return success();