[Torch] Fix AtenSliceTensorOp::fold (#3345)

pull/3373/head
Xinyu Yang 2024-05-16 11:42:43 +08:00 committed by GitHub
parent 405f884522
commit a9edefb3cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 5 deletions

View File

@ -3570,17 +3570,17 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
auto outType = dyn_cast<ValueTensorType>(getResult().getType());
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
!inType.hasDtype() || !outType.hasDtype() ||
inType.getDtype() != outType.getDtype())
return nullptr;
if (start && end && step && step.getValue().getSExtValue() == 1 &&
start.getValue().getSExtValue() == 0 &&
end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max() &&
inType == outType)
return getOperand(0);
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
!inType.hasDtype() || !outType.hasDtype() ||
inType.getDtype() != outType.getDtype())
return nullptr;
if (inType.getSizes().size() != outType.getSizes().size() ||
!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
return nullptr;