diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
index bba8b7438..e3168004b 100644
--- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
+++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
@@ -855,8 +855,8 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
   auto outType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
   auto inputType = cast<RankedTensorType>(input.getType());
-  int64_t inputRank = inputType.getRank();
   auto valuesType = cast<RankedTensorType>(values.getType());
+  int64_t valueRank = valuesType.getRank();
   auto valuesShape = valuesType.getShape();
   bool accumulate;
   if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) {
@@ -868,6 +868,7 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
   if (!getListConstructElements(indexList, indicesTorchType))
     return op.emitError(
         "unimplemented: the tensor list is not from list construct");
+  int64_t indexCnt = indicesTorchType.size();
   auto indexTensors = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
                                              indicesTorchType);

@@ -886,11 +887,11 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
   SmallVector<int64_t> scatterDimOperandDimMap;
   SmallVector<int64_t> insertedWindowDims;
   SmallVector<int64_t> updateWindowDims;
-  for (int64_t i = 0; i < maxIndexRank; ++i) {
+  for (int64_t i = 0; i < indexCnt; ++i) {
     scatterDimOperandDimMap.push_back(i);
     insertedWindowDims.push_back(i);
   }
-  for (int64_t i = maxIndexRank; i < inputRank; ++i) {
+  for (int64_t i = maxIndexRank; i < valueRank; ++i) {
     updateWindowDims.push_back(i);
   }
   auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get(
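
A minimal standalone sketch (not part of the patch, using hypothetical shapes) of the dimension-number vectors the patched loops compute. Assume an index_put on a 4x5x6 input with a single 2x3 index tensor and a 2x3x5x6 values tensor, i.e. indexCnt = 1, maxIndexRank = 2, valueRank = 4, while inputRank = 3; the values chosen here are only for illustration:

```cpp
// Sketch: mirrors the two patched loops and prints the resulting vectors.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical example: one 2x3 index tensor, 2x3x5x6 values tensor.
  int64_t indexCnt = 1;     // number of index tensors in the list
  int64_t maxIndexRank = 2; // rank of the broadcasted index tensors
  int64_t valueRank = 4;    // rank of the values (updates) tensor

  std::vector<int64_t> scatterDimOperandDimMap, insertedWindowDims,
      updateWindowDims;
  // One scattered/inserted operand dim per index tensor
  // (previously bounded by maxIndexRank).
  for (int64_t i = 0; i < indexCnt; ++i) {
    scatterDimOperandDimMap.push_back(i);
    insertedWindowDims.push_back(i);
  }
  // Window dims are the trailing dims of the values tensor
  // (previously bounded by the input rank).
  for (int64_t i = maxIndexRank; i < valueRank; ++i)
    updateWindowDims.push_back(i);

  auto dump = [](const char *name, const std::vector<int64_t> &v) {
    std::cout << name << ": [";
    for (size_t i = 0; i < v.size(); ++i)
      std::cout << v[i] << (i + 1 < v.size() ? ", " : "");
    std::cout << "]\n";
  };
  dump("scatter_dims_to_operand_dims", scatterDimOperandDimMap); // [0]
  dump("inserted_window_dims", insertedWindowDims);              // [0]
  dump("update_window_dims", updateWindowDims);                  // [2, 3]
  return 0;
}
```

With the previous bounds (maxIndexRank and inputRank), the same case would yield inserted_window_dims = [0, 1] and update_window_dims = [2], which does not correspond to a single index tensor and drops the last window dimension of the updates; the new bounds tie these vectors to the number of index tensors and to the rank of the values tensor instead.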