mirror of https://github.com/llvm/torch-mlir
Fix scatter op bufferization to alway copy original tensor
parent
3d9ba5e525
commit
3510b2ba9d
|
@ -403,10 +403,9 @@ bool ScatterOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
|
|||
Value operand = opOperand->get();
|
||||
if (operand == updates())
|
||||
bbArgNumber = 0; // block arg 0 is `update`.
|
||||
else if (operand == original())
|
||||
bbArgNumber = 1; // block arg 1 is `original`.
|
||||
else {
|
||||
assert(operand == indices() &&
|
||||
bool isValidOperand = operand == indices() || operand == original();
|
||||
assert(isValidOperand &&
|
||||
"operand must belong to the current tm_tensor.scatter op");
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -61,7 +61,9 @@ func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tenso
|
|||
// CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> {
|
||||
// CHECK: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32>
|
||||
// CHECK: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32>
|
||||
// CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32>
|
||||
// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32>
|
||||
// CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32>
|
||||
// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]]
|
||||
// CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) {
|
||||
// CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32):
|
||||
|
|
Loading…
Reference in New Issue