mirror of https://github.com/llvm/torch-mlir
parent
307f49f566
commit
197ef4224b
|
@ -3389,14 +3389,15 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
||||||
IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd());
|
IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd());
|
||||||
IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
|
IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
|
||||||
IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
|
IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
|
||||||
|
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
|
||||||
|
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
|
||||||
|
|
||||||
if (start && end && step && step.getValue().getSExtValue() == 1 &&
|
if (start && end && step && step.getValue().getSExtValue() == 1 &&
|
||||||
start.getValue().getSExtValue() == 0 &&
|
start.getValue().getSExtValue() == 0 &&
|
||||||
end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max())
|
end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max() &&
|
||||||
|
inType == outType)
|
||||||
return getOperand(0);
|
return getOperand(0);
|
||||||
|
|
||||||
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
|
|
||||||
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
|
|
||||||
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
|
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
|
||||||
!inType.hasDtype() || !outType.hasDtype() ||
|
!inType.hasDtype() || !outType.hasDtype() ||
|
||||||
inType.getDtype() != outType.getDtype())
|
inType.getDtype() != outType.getDtype())
|
||||||
|
|
Loading…
Reference in New Issue