diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index e154f5cb9..861a861c5 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -932,9 +932,12 @@ public: // 2.) `values` is mapped to `updates` in scatter op. // 3.) `input` is mapped to `original` in scatter op. bool invalidInputTypeFound = false; + // If accumulate == false, the behavior is undefined if the indicies aren't + // unique. + bool uniqueIndices = !accumulate; Value scatterOp = createTMTensorScatterOp( rewriter, loc, values, indices, input, indicesMap, - /*uniqueIndices=*/false, + /*uniqueIndices=*/uniqueIndices, [&](OpBuilder &b, Location loc, Value valuesElement, Value inputElement) { Value yieldValue = valuesElement;