mirror of https://github.com/llvm/torch-mlir
Add support for returning more than one copy of the same tensor (#1228)
One of the simplifications made by the pass `RefinePublicReturn` currently only happens if the tensor in question only has one user. However, the current method of checking this does not correctly handle the case of a user having multiple uses of the same tensor. This commit makes sure only unique users are considered.pull/1247/head
parent
1a7fc3915c
commit
9bc606c384
|
@ -71,7 +71,9 @@ class RefinePublicReturnPass
|
||||||
// If the return (or transitively other ops) are not the only users,
|
// If the return (or transitively other ops) are not the only users,
|
||||||
// then we can't be sure that the tensor hasn't been mutated, so stop
|
// then we can't be sure that the tensor hasn't been mutated, so stop
|
||||||
// here.
|
// here.
|
||||||
if (!llvm::hasSingleElement(copy->getUsers()))
|
SetVector<Operation *> users(copy->getUsers().begin(),
|
||||||
|
copy->getUsers().end());
|
||||||
|
if (users.size() != 1)
|
||||||
break;
|
break;
|
||||||
newOperand = copy.getOperand();
|
newOperand = copy.getOperand();
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -59,3 +59,16 @@ func.func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
^bb2:
|
^bb2:
|
||||||
return %arg0 : tensor<*xf32>
|
return %arg0 : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @return_multiple_copies_of_tensor(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>) {
|
||||||
|
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[],f32> to !torch.vtensor
|
||||||
|
// CHECK: %[[TO_TENSOR:.*]] = torch.copy.to_tensor %[[CAST]] : !torch.tensor
|
||||||
|
// CHECK: return %[[ARG]], %[[ARG]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
|
||||||
|
func.func @return_multiple_copies_of_tensor(%arg0: !torch.vtensor<[],f32>) -> (!torch.tensor, !torch.tensor) {
|
||||||
|
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],f32> to !torch.vtensor
|
||||||
|
%1 = torch.copy.to_tensor %0 : !torch.tensor
|
||||||
|
return %1, %1 : !torch.tensor, !torch.tensor
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue