Avoid Type Mismatch in Slice Folder (#3154)

Fixes issue #3153
pull/3155/head
zjgarvey 2024-04-12 13:43:45 -05:00 committed by GitHub
parent 307f49f566
commit 197ef4224b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 3 deletions

View File

@ -3389,14 +3389,15 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd()); IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd());
IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep()); IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim()); IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
if (start && end && step && step.getValue().getSExtValue() == 1 && if (start && end && step && step.getValue().getSExtValue() == 1 &&
start.getValue().getSExtValue() == 0 && start.getValue().getSExtValue() == 0 &&
end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max()) end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max() &&
inType == outType)
return getOperand(0); return getOperand(0);
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
!inType.hasDtype() || !outType.hasDtype() || !inType.hasDtype() || !outType.hasDtype() ||
inType.getDtype() != outType.getDtype()) inType.getDtype() != outType.getDtype())