mirror of https://github.com/llvm/torch-mlir
[Torch] Fix AtenSliceTensorOp::fold (#3345)
parent
405f884522
commit
a9edefb3cf
|
@ -3570,17 +3570,17 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
|||
auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
|
||||
auto outType = dyn_cast<ValueTensorType>(getResult().getType());
|
||||
|
||||
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
|
||||
!inType.hasDtype() || !outType.hasDtype() ||
|
||||
inType.getDtype() != outType.getDtype())
|
||||
return nullptr;
|
||||
|
||||
if (start && end && step && step.getValue().getSExtValue() == 1 &&
|
||||
start.getValue().getSExtValue() == 0 &&
|
||||
end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max() &&
|
||||
inType == outType)
|
||||
return getOperand(0);
|
||||
|
||||
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
|
||||
!inType.hasDtype() || !outType.hasDtype() ||
|
||||
inType.getDtype() != outType.getDtype())
|
||||
return nullptr;
|
||||
|
||||
if (inType.getSizes().size() != outType.getSizes().size() ||
|
||||
!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
|
||||
return nullptr;
|
||||
|
|
Loading…
Reference in New Issue