Support `DerefineOp` in `RefinePublicReturn`.

pull/2328/head
Alexandre Rames 2023-07-18 07:32:26 -07:00 committed by Alexandre Rames
parent 4847563bed
commit a20422ce65
2 changed files with 23 additions and 1 deletions

View File

@ -62,7 +62,8 @@ class RefinePublicReturnPass
OpBuilder builder(returnOp); OpBuilder builder(returnOp);
for (auto operand : returnOp.getOperands()) { for (auto operand : returnOp.getOperands()) {
Value newOperand = operand; Value newOperand = operand;
// Look through TensorStaticInfoCastOp's and CopyToNonValueTensorOp's. // Look through TensorStaticInfoCastOp's, CopyToNonValueTensorOp's, and
// DerefineOp's.
for (;;) { for (;;) {
if (auto cast = newOperand.getDefiningOp<TensorStaticInfoCastOp>()) { if (auto cast = newOperand.getDefiningOp<TensorStaticInfoCastOp>()) {
newOperand = cast.getOperand(); newOperand = cast.getOperand();
@ -76,6 +77,8 @@ class RefinePublicReturnPass
if (users.size() != 1) if (users.size() != 1)
break; break;
newOperand = copy.getOperand(); newOperand = copy.getOperand();
} else if (auto derefine = newOperand.getDefiningOp<DerefineOp>()) {
newOperand = derefine.getOperand();
} else { } else {
break; break;
} }

View File

@ -9,6 +9,14 @@ func.func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
return %2 : !torch.tensor return %2 : !torch.tensor
} }
// CHECK-LABEL: func.func @refine_optional(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
// CHECK: return %[[ARG]] : !torch.vtensor<[2],f32>
func.func @refine_optional(%arg: !torch.vtensor<[2],f32>) -> !torch.optional<vtensor<[2],f32>> {
%res = torch.derefine %arg : !torch.vtensor<[2],f32> to !torch.optional<vtensor<[2],f32>>
return %res : !torch.optional<vtensor<[2],f32>>
}
// CHECK-LABEL: func.func @multiple_use_non_value_tensor( // CHECK-LABEL: func.func @multiple_use_non_value_tensor(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
@ -34,6 +42,17 @@ func.func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.t
return %2 : !torch.tensor return %2 : !torch.tensor
} }
// No conversion on private function.
// CHECK-LABEL: func.func private @dont_refine_private(
// CHECK-SAME: %[[ARG:.+]]: !torch.vtensor<[2],f32>) -> !torch.optional<vtensor<[2],f32>> {
// CHECK: %[[RES:.+]] = torch.derefine %[[ARG]] : !torch.vtensor<[2],f32> to !torch.optional<vtensor<[2],f32>>
// CHECK: return %[[RES]] : !torch.optional<vtensor<[2],f32>>
// CHECK: }
func.func private @dont_refine_private(%arg: !torch.vtensor<[2],f32>) -> !torch.optional<vtensor<[2],f32>> {
%res = torch.derefine %arg : !torch.vtensor<[2],f32> to !torch.optional<vtensor<[2],f32>>
return %res : !torch.optional<vtensor<[2],f32>>
}
// ----- // -----
// Call to public function. // Call to public function.