mirror of https://github.com/llvm/torch-mlir
parent
307f49f566
commit
197ef4224b
|
@ -3389,14 +3389,15 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
|||
IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd());
|
||||
IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
|
||||
IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
|
||||
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
|
||||
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
|
||||
|
||||
if (start && end && step && step.getValue().getSExtValue() == 1 &&
|
||||
start.getValue().getSExtValue() == 0 &&
|
||||
end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max())
|
||||
end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max() &&
|
||||
inType == outType)
|
||||
return getOperand(0);
|
||||
|
||||
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
|
||||
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
|
||||
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
|
||||
!inType.hasDtype() || !outType.hasDtype() ||
|
||||
inType.getDtype() != outType.getDtype())
|
||||
|
|
Loading…
Reference in New Issue