From 197ef4224bc41471acd4ccfd8694ed8e0842e716 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Fri, 12 Apr 2024 13:43:45 -0500 Subject: [PATCH] Avoid Type Mismatch in Slice Folder (#3154) Fixes issue #3153 --- lib/Dialect/Torch/IR/TorchOps.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index d16a41aa3..22fece712 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -3389,14 +3389,15 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { IntegerAttr end = dyn_cast_or_null(adaptor.getEnd()); IntegerAttr step = dyn_cast_or_null(adaptor.getStep()); IntegerAttr dim = dyn_cast_or_null(adaptor.getDim()); + auto inType = getOperand(0).getType().dyn_cast(); + auto outType = getResult().getType().dyn_cast(); if (start && end && step && step.getValue().getSExtValue() == 1 && start.getValue().getSExtValue() == 0 && - end.getValue().getSExtValue() == std::numeric_limits::max()) + end.getValue().getSExtValue() == std::numeric_limits::max() && + inType == outType) return getOperand(0); - auto inType = getOperand(0).getType().dyn_cast(); - auto outType = getResult().getType().dyn_cast(); if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || !inType.hasDtype() || !outType.hasDtype() || inType.getDtype() != outType.getDtype())