mirror of https://github.com/llvm/torch-mlir
Fix scatter op bufferization to alway copy original tensor
parent
3d9ba5e525
commit
3510b2ba9d
|
@ -403,10 +403,9 @@ bool ScatterOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
|
||||||
Value operand = opOperand->get();
|
Value operand = opOperand->get();
|
||||||
if (operand == updates())
|
if (operand == updates())
|
||||||
bbArgNumber = 0; // block arg 0 is `update`.
|
bbArgNumber = 0; // block arg 0 is `update`.
|
||||||
else if (operand == original())
|
|
||||||
bbArgNumber = 1; // block arg 1 is `original`.
|
|
||||||
else {
|
else {
|
||||||
assert(operand == indices() &&
|
bool isValidOperand = operand == indices() || operand == original();
|
||||||
|
assert(isValidOperand &&
|
||||||
"operand must belong to the current tm_tensor.scatter op");
|
"operand must belong to the current tm_tensor.scatter op");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,7 +61,9 @@ func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tenso
|
||||||
// CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
|
// CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
|
||||||
// CHECK: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
|
// CHECK: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
|
||||||
// CHECK: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
|
// CHECK: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
|
||||||
|
// CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
|
||||||
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
||||||
|
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
|
||||||
// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
|
// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
|
||||||
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
|
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
|
||||||
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
|
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
|
||||||
|
|
Loading…
Reference in New Issue