Fix scatter op bufferization to alway copy original tensor

pull/657/head snapshot-20220310.315
Yi Zhang 2022-03-09 15:53:30 -05:00
parent 3d9ba5e525
commit 3510b2ba9d
2 changed files with 4 additions and 3 deletions

View File

@ -403,10 +403,9 @@ bool ScatterOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
Value operand = opOperand->get(); Value operand = opOperand->get();
if (operand == updates()) if (operand == updates())
bbArgNumber = 0; // block arg 0 is `update`. bbArgNumber = 0; // block arg 0 is `update`.
else if (operand == original())
bbArgNumber = 1; // block arg 1 is `original`.
else { else {
assert(operand == indices() && bool isValidOperand = operand == indices() || operand == original();
assert(isValidOperand &&
"operand must belong to the current tm_tensor.scatter op"); "operand must belong to the current tm_tensor.scatter op");
return true; return true;
} }

View File

@ -61,7 +61,9 @@ func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tenso
// CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> { // CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
// CHECK: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32> // CHECK: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
// CHECK: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32> // CHECK: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
// CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> // CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] // CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) { // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32): // CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):