diff --git a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp index 4adf61346..5109a8c57 100644 --- a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp +++ b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp @@ -71,7 +71,9 @@ class RefinePublicReturnPass // If the return (or transitively other ops) are not the only users, // then we can't be sure that the tensor hasn't been mutated, so stop // here. - if (!llvm::hasSingleElement(copy->getUsers())) + SetVector users(copy->getUsers().begin(), + copy->getUsers().end()); + if (users.size() != 1) break; newOperand = copy.getOperand(); } else { diff --git a/test/Dialect/Torch/refine-public-return.mlir b/test/Dialect/Torch/refine-public-return.mlir index 0cb97d1bd..ad810ec97 100644 --- a/test/Dialect/Torch/refine-public-return.mlir +++ b/test/Dialect/Torch/refine-public-return.mlir @@ -59,3 +59,16 @@ func.func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> { ^bb2: return %arg0 : tensor<*xf32> } + +// ----- + +// CHECK-LABEL: func.func @return_multiple_copies_of_tensor( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>) { +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[],f32> to !torch.vtensor +// CHECK: %[[TO_TENSOR:.*]] = torch.copy.to_tensor %[[CAST]] : !torch.tensor +// CHECK: return %[[ARG]], %[[ARG]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> +func.func @return_multiple_copies_of_tensor(%arg0: !torch.vtensor<[],f32>) -> (!torch.tensor, !torch.tensor) { + %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],f32> to !torch.vtensor + %1 = torch.copy.to_tensor %0 : !torch.tensor + return %1, %1 : !torch.tensor, !torch.tensor +}