From e88faf08ff48742dd5e728fb977ea05611bdcc68 Mon Sep 17 00:00:00 2001 From: Ian Wood <75152913+IanWood1@users.noreply.github.com> Date: Tue, 5 Nov 2024 12:48:34 -0800 Subject: [PATCH] Create scatter op with unique indicies (#3853) For the op `index_put_`, if accumulate == false, the behavior is undefined if the indicies aren't unique (https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html). So, when converting `AtenIndexPutHackedTwinOp` to a TMTensor scatter op, mark the indices as unique if when `accumulate == false`. This should have no functional effect (unless users are relying on UB) and assuming unique indices has the benefit of unlocking better optimizations in further compiler stages. Signed-off-by: Ian Wood --- lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index e154f5cb9..861a861c5 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -932,9 +932,12 @@ public: // 2.) `values` is mapped to `updates` in scatter op. // 3.) `input` is mapped to `original` in scatter op. bool invalidInputTypeFound = false; + // If accumulate == false, the behavior is undefined if the indicies aren't + // unique. + bool uniqueIndices = !accumulate; Value scatterOp = createTMTensorScatterOp( rewriter, loc, values, indices, input, indicesMap, - /*uniqueIndices=*/false, + /*uniqueIndices=*/uniqueIndices, [&](OpBuilder &b, Location loc, Value valuesElement, Value inputElement) { Value yieldValue = valuesElement;