mirror of https://github.com/llvm/torch-mlir
[Torch] Fix bug of DecomposeAtenSelectIntOp (#3087)
Fix bug of DecomposeAtenSelectIntOp. Because it may use resultTy when resultTy has not been inferred. ``` auto resultTy = op.getType().cast<BaseTensorType>(); if (sliceTy.getSizes().size() == resultTy.getSizes().size()) { rewriter.replaceOp(op, slice); return success(); } ``` So I add restriction.pull/3098/head
parent
76080936d4
commit
da88efad89
|
@ -764,6 +764,12 @@ public:
|
||||||
Value dim = op.getDim();
|
Value dim = op.getDim();
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
|
|
||||||
|
auto resultTy = op.getType().cast<BaseTensorType>();
|
||||||
|
if (!resultTy.hasSizes() || !resultTy.hasDtype()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "expected result type to have sizes and dtype");
|
||||||
|
}
|
||||||
|
|
||||||
// convert `start` to non-negative: start += int(start < 0) * dimSize
|
// convert `start` to non-negative: start += int(start < 0) * dimSize
|
||||||
Value zero =
|
Value zero =
|
||||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||||
|
@ -785,7 +791,6 @@ public:
|
||||||
op.getSelf(), dim, start, startPlusOne, /*step=*/one);
|
op.getSelf(), dim, start, startPlusOne, /*step=*/one);
|
||||||
|
|
||||||
auto sliceTy = cast<BaseTensorType>(slice.getType());
|
auto sliceTy = cast<BaseTensorType>(slice.getType());
|
||||||
auto resultTy = cast<BaseTensorType>(op.getResult().getType());
|
|
||||||
if (sliceTy.getSizes().size() == resultTy.getSizes().size()) {
|
if (sliceTy.getSizes().size() == resultTy.getSizes().size()) {
|
||||||
rewriter.replaceOp(op, slice);
|
rewriter.replaceOp(op, slice);
|
||||||
return success();
|
return success();
|
||||||
|
|
Loading…
Reference in New Issue