mirror of https://github.com/llvm/torch-mlir
Support `DerefineOp` in `RefinePublicReturn`.
parent
4847563bed
commit
a20422ce65
|
@ -62,7 +62,8 @@ class RefinePublicReturnPass
|
|||
OpBuilder builder(returnOp);
|
||||
for (auto operand : returnOp.getOperands()) {
|
||||
Value newOperand = operand;
|
||||
// Look through TensorStaticInfoCastOp's and CopyToNonValueTensorOp's.
|
||||
// Look through TensorStaticInfoCastOp's, CopyToNonValueTensorOp's, and
|
||||
// DerefineOp's.
|
||||
for (;;) {
|
||||
if (auto cast = newOperand.getDefiningOp<TensorStaticInfoCastOp>()) {
|
||||
newOperand = cast.getOperand();
|
||||
|
@ -76,6 +77,8 @@ class RefinePublicReturnPass
|
|||
if (users.size() != 1)
|
||||
break;
|
||||
newOperand = copy.getOperand();
|
||||
} else if (auto derefine = newOperand.getDefiningOp<DerefineOp>()) {
|
||||
newOperand = derefine.getOperand();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -9,6 +9,14 @@ func.func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
|||
return %2 : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @refine_optional(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
|
||||
// CHECK: return %[[ARG]] : !torch.vtensor<[2],f32>
|
||||
func.func @refine_optional(%arg: !torch.vtensor<[2],f32>) -> !torch.optional<vtensor<[2],f32>> {
|
||||
%res = torch.derefine %arg : !torch.vtensor<[2],f32> to !torch.optional<vtensor<[2],f32>>
|
||||
return %res : !torch.optional<vtensor<[2],f32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @multiple_use_non_value_tensor(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
|
||||
|
@ -34,6 +42,17 @@ func.func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.t
|
|||
return %2 : !torch.tensor
|
||||
}
|
||||
|
||||
// No conversion on private function.
|
||||
// CHECK-LABEL: func.func private @dont_refine_private(
|
||||
// CHECK-SAME: %[[ARG:.+]]: !torch.vtensor<[2],f32>) -> !torch.optional<vtensor<[2],f32>> {
|
||||
// CHECK: %[[RES:.+]] = torch.derefine %[[ARG]] : !torch.vtensor<[2],f32> to !torch.optional<vtensor<[2],f32>>
|
||||
// CHECK: return %[[RES]] : !torch.optional<vtensor<[2],f32>>
|
||||
// CHECK: }
|
||||
func.func private @dont_refine_private(%arg: !torch.vtensor<[2],f32>) -> !torch.optional<vtensor<[2],f32>> {
|
||||
%res = torch.derefine %arg : !torch.vtensor<[2],f32> to !torch.optional<vtensor<[2],f32>>
|
||||
return %res : !torch.optional<vtensor<[2],f32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Call to public function.
|
||||
|
|
Loading…
Reference in New Issue