diff --git a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp index 5109a8c57..cfa4e40ee 100644 --- a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp +++ b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp @@ -62,7 +62,8 @@ class RefinePublicReturnPass OpBuilder builder(returnOp); for (auto operand : returnOp.getOperands()) { Value newOperand = operand; - // Look through TensorStaticInfoCastOp's and CopyToNonValueTensorOp's. + // Look through TensorStaticInfoCastOp's, CopyToNonValueTensorOp's, and + // DerefineOp's. for (;;) { if (auto cast = newOperand.getDefiningOp()) { newOperand = cast.getOperand(); @@ -76,6 +77,8 @@ class RefinePublicReturnPass if (users.size() != 1) break; newOperand = copy.getOperand(); + } else if (auto derefine = newOperand.getDefiningOp()) { + newOperand = derefine.getOperand(); } else { break; } diff --git a/test/Dialect/Torch/refine-public-return.mlir b/test/Dialect/Torch/refine-public-return.mlir index ad810ec97..b3a225962 100644 --- a/test/Dialect/Torch/refine-public-return.mlir +++ b/test/Dialect/Torch/refine-public-return.mlir @@ -9,6 +9,14 @@ func.func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { return %2 : !torch.tensor } +// CHECK-LABEL: func.func @refine_optional( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> { +// CHECK: return %[[ARG]] : !torch.vtensor<[2],f32> +func.func @refine_optional(%arg: !torch.vtensor<[2],f32>) -> !torch.optional> { + %res = torch.derefine %arg : !torch.vtensor<[2],f32> to !torch.optional> + return %res : !torch.optional> +} + // CHECK-LABEL: func.func @multiple_use_non_value_tensor( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor { @@ -34,6 +42,17 @@ func.func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.t return %2 : !torch.tensor } +// No conversion on private function. +// CHECK-LABEL: func.func private @dont_refine_private( +// CHECK-SAME: %[[ARG:.+]]: !torch.vtensor<[2],f32>) -> !torch.optional> { +// CHECK: %[[RES:.+]] = torch.derefine %[[ARG]] : !torch.vtensor<[2],f32> to !torch.optional> +// CHECK: return %[[RES]] : !torch.optional> +// CHECK: } +func.func private @dont_refine_private(%arg: !torch.vtensor<[2],f32>) -> !torch.optional> { + %res = torch.derefine %arg : !torch.vtensor<[2],f32> to !torch.optional> + return %res : !torch.optional> +} + // ----- // Call to public function.