From a9edefb3cfaf27bb9e7c7a4298588a2c9f880344 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Thu, 16 May 2024 11:42:43 +0800 Subject: [PATCH] [Torch] Fix AtenSliceTensorOp::fold (#3345) --- lib/Dialect/Torch/IR/TorchOps.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 3a3a16fa3..bdd794924 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -3570,17 +3570,17 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { auto inType = dyn_cast(getOperand(0).getType()); auto outType = dyn_cast(getResult().getType()); + if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || + !inType.hasDtype() || !outType.hasDtype() || + inType.getDtype() != outType.getDtype()) + return nullptr; + if (start && end && step && step.getValue().getSExtValue() == 1 && start.getValue().getSExtValue() == 0 && end.getValue().getSExtValue() == std::numeric_limits::max() && inType == outType) return getOperand(0); - if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || - !inType.hasDtype() || !outType.hasDtype() || - inType.getDtype() != outType.getDtype()) - return nullptr; - if (inType.getSizes().size() != outType.getSizes().size() || !inType.areAllSizesKnown() || !outType.areAllSizesKnown()) return nullptr;