mirror of https://github.com/llvm/torch-mlir
[Torch] Fix bug of DecomposeAtenSelectIntOp (#3087)
Fix a bug in DecomposeAtenSelectIntOp: the pattern could read `resultTy` before the result type had been inferred, in this comparison:

```cpp
auto resultTy = op.getType().cast<BaseTensorType>();
if (sliceTy.getSizes().size() == resultTy.getSizes().size()) {
  rewriter.replaceOp(op, slice);
  return success();
}
```

This change therefore adds a restriction: the pattern fails to match unless the result type has sizes and a dtype.
parent 76080936d4
commit da88efad89
pull/3098/head
```diff
@@ -764,6 +764,12 @@ public:
     Value dim = op.getDim();
     Value self = op.getSelf();
 
+    auto resultTy = op.getType().cast<BaseTensorType>();
+    if (!resultTy.hasSizes() || !resultTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have sizes and dtype");
+    }
+
     // convert `start` to non-negative: start += int(start < 0) * dimSize
     Value zero =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
@@ -785,7 +791,6 @@ public:
           op.getSelf(), dim, start, startPlusOne, /*step=*/one);
 
     auto sliceTy = cast<BaseTensorType>(slice.getType());
-    auto resultTy = cast<BaseTensorType>(op.getResult().getType());
     if (sliceTy.getSizes().size() == resultTy.getSizes().size()) {
       rewriter.replaceOp(op, slice);
       return success();
```
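Read in isolation, the restriction added by the first hunk amounts to the small check sketched below. This is not upstream code: the free-standing helper `checkResultTypeIsRefined` is hypothetical, and the includes assume torch-mlir's in-tree headers. The point is simply that `resultTy.getSizes()` is only reached after `hasSizes()` and `hasDtype()` confirm the result type has been refined.

```cpp
// Minimal sketch (hypothetical helper, not upstream code) of the restriction
// this patch adds: refuse to decompose aten.select.int until shape/dtype
// refinement has populated the result type, so getSizes() is safe to use.
#include "mlir/IR/PatternMatch.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;
using namespace mlir::torch::Torch;

static LogicalResult checkResultTypeIsRefined(AtenSelectIntOp op,
                                              PatternRewriter &rewriter) {
  // Unrefined result types (no sizes/dtype yet) cannot support the rank
  // comparison the pattern performs later, so bail out of the match.
  auto resultTy = cast<BaseTensorType>(op.getType());
  if (!resultTy.hasSizes() || !resultTy.hasDtype())
    return rewriter.notifyMatchFailure(
        op, "expected result type to have sizes and dtype");
  return success();
}
```

In the committed pattern this check runs near the top of `matchAndRewrite` (first hunk), which is also what allows the second hunk to drop the later, now-redundant definition of `resultTy`.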