mirror of https://github.com/llvm/torch-mlir
Create scatter op with unique indicies (#3853)
For the op `index_put_`, if accumulate == false, the behavior is undefined if the indicies aren't unique (https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html). So, when converting `AtenIndexPutHackedTwinOp` to a TMTensor scatter op, mark the indices as unique if when `accumulate == false`. This should have no functional effect (unless users are relying on UB) and assuming unique indices has the benefit of unlocking better optimizations in further compiler stages. Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>pull/3804/head
parent
b75d0e3f8b
commit
e88faf08ff
|
@ -932,9 +932,12 @@ public:
|
||||||
// 2.) `values` is mapped to `updates` in scatter op.
|
// 2.) `values` is mapped to `updates` in scatter op.
|
||||||
// 3.) `input` is mapped to `original` in scatter op.
|
// 3.) `input` is mapped to `original` in scatter op.
|
||||||
bool invalidInputTypeFound = false;
|
bool invalidInputTypeFound = false;
|
||||||
|
// If accumulate == false, the behavior is undefined if the indicies aren't
|
||||||
|
// unique.
|
||||||
|
bool uniqueIndices = !accumulate;
|
||||||
Value scatterOp = createTMTensorScatterOp(
|
Value scatterOp = createTMTensorScatterOp(
|
||||||
rewriter, loc, values, indices, input, indicesMap,
|
rewriter, loc, values, indices, input, indicesMap,
|
||||||
/*uniqueIndices=*/false,
|
/*uniqueIndices=*/uniqueIndices,
|
||||||
[&](OpBuilder &b, Location loc, Value valuesElement,
|
[&](OpBuilder &b, Location loc, Value valuesElement,
|
||||||
Value inputElement) {
|
Value inputElement) {
|
||||||
Value yieldValue = valuesElement;
|
Value yieldValue = valuesElement;
|
||||||
|
|
Loading…
Reference in New Issue