Add support for returning more than one copy of the same tensor (#1228)

One of the simplifications made by the pass `RefinePublicReturn`
currently only happens if the tensor in question only has one
user. However, the current method of checking this does not correctly
handle the case of a user having multiple uses of the same
tensor. This commit makes sure only unique users are considered.
pull/1247/head
Ramiro Leal-Cavazos 2022-08-18 15:41:45 -07:00 committed by GitHub
parent 1a7fc3915c
commit 9bc606c384
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 1 deletions

View File

@ -71,7 +71,9 @@ class RefinePublicReturnPass
// If the return (or transitively other ops) are not the only users, // If the return (or transitively other ops) are not the only users,
// then we can't be sure that the tensor hasn't been mutated, so stop // then we can't be sure that the tensor hasn't been mutated, so stop
// here. // here.
if (!llvm::hasSingleElement(copy->getUsers())) SetVector<Operation *> users(copy->getUsers().begin(),
copy->getUsers().end());
if (users.size() != 1)
break; break;
newOperand = copy.getOperand(); newOperand = copy.getOperand();
} else { } else {

View File

@ -59,3 +59,16 @@ func.func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
^bb2: ^bb2:
return %arg0 : tensor<*xf32> return %arg0 : tensor<*xf32>
} }
// -----
// CHECK-LABEL: func.func @return_multiple_copies_of_tensor(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>) {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[],f32> to !torch.vtensor
// CHECK: %[[TO_TENSOR:.*]] = torch.copy.to_tensor %[[CAST]] : !torch.tensor
// CHECK: return %[[ARG]], %[[ARG]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
func.func @return_multiple_copies_of_tensor(%arg0: !torch.vtensor<[],f32>) -> (!torch.tensor, !torch.tensor) {
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],f32> to !torch.vtensor
%1 = torch.copy.to_tensor %0 : !torch.tensor
return %1, %1 : !torch.tensor, !torch.tensor
}