mirror of https://github.com/llvm/torch-mlir
[Torch] Fix AtenSliceTensorOp::fold (#3345)
parent
405f884522
commit
a9edefb3cf
|
@ -3570,17 +3570,17 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
||||||
auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
|
auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
|
||||||
auto outType = dyn_cast<ValueTensorType>(getResult().getType());
|
auto outType = dyn_cast<ValueTensorType>(getResult().getType());
|
||||||
|
|
||||||
|
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
|
||||||
|
!inType.hasDtype() || !outType.hasDtype() ||
|
||||||
|
inType.getDtype() != outType.getDtype())
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
if (start && end && step && step.getValue().getSExtValue() == 1 &&
|
if (start && end && step && step.getValue().getSExtValue() == 1 &&
|
||||||
start.getValue().getSExtValue() == 0 &&
|
start.getValue().getSExtValue() == 0 &&
|
||||||
end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max() &&
|
end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max() &&
|
||||||
inType == outType)
|
inType == outType)
|
||||||
return getOperand(0);
|
return getOperand(0);
|
||||||
|
|
||||||
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
|
|
||||||
!inType.hasDtype() || !outType.hasDtype() ||
|
|
||||||
inType.getDtype() != outType.getDtype())
|
|
||||||
return nullptr;
|
|
||||||
|
|
||||||
if (inType.getSizes().size() != outType.getSizes().size() ||
|
if (inType.getSizes().size() != outType.getSizes().size() ||
|
||||||
!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
|
!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
Loading…
Reference in New Issue