mirror of https://github.com/llvm/torch-mlir
Support `DerefineOp` in `RefinePublicReturn`.
parent
4847563bed
commit
a20422ce65
|
@ -62,7 +62,8 @@ class RefinePublicReturnPass
|
||||||
OpBuilder builder(returnOp);
|
OpBuilder builder(returnOp);
|
||||||
for (auto operand : returnOp.getOperands()) {
|
for (auto operand : returnOp.getOperands()) {
|
||||||
Value newOperand = operand;
|
Value newOperand = operand;
|
||||||
// Look through TensorStaticInfoCastOp's and CopyToNonValueTensorOp's.
|
// Look through TensorStaticInfoCastOp's, CopyToNonValueTensorOp's, and
|
||||||
|
// DerefineOp's.
|
||||||
for (;;) {
|
for (;;) {
|
||||||
if (auto cast = newOperand.getDefiningOp<TensorStaticInfoCastOp>()) {
|
if (auto cast = newOperand.getDefiningOp<TensorStaticInfoCastOp>()) {
|
||||||
newOperand = cast.getOperand();
|
newOperand = cast.getOperand();
|
||||||
|
@ -76,6 +77,8 @@ class RefinePublicReturnPass
|
||||||
if (users.size() != 1)
|
if (users.size() != 1)
|
||||||
break;
|
break;
|
||||||
newOperand = copy.getOperand();
|
newOperand = copy.getOperand();
|
||||||
|
} else if (auto derefine = newOperand.getDefiningOp<DerefineOp>()) {
|
||||||
|
newOperand = derefine.getOperand();
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,14 @@ func.func @basic(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
|
||||||
return %2 : !torch.tensor
|
return %2 : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @refine_optional(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
|
||||||
|
// CHECK: return %[[ARG]] : !torch.vtensor<[2],f32>
|
||||||
|
func.func @refine_optional(%arg: !torch.vtensor<[2],f32>) -> !torch.optional<vtensor<[2],f32>> {
|
||||||
|
%res = torch.derefine %arg : !torch.vtensor<[2],f32> to !torch.optional<vtensor<[2],f32>>
|
||||||
|
return %res : !torch.optional<vtensor<[2],f32>>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @multiple_use_non_value_tensor(
|
// CHECK-LABEL: func.func @multiple_use_non_value_tensor(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
|
||||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
|
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
|
||||||
|
@ -34,6 +42,17 @@ func.func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.t
|
||||||
return %2 : !torch.tensor
|
return %2 : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// No conversion on private function.
|
||||||
|
// CHECK-LABEL: func.func private @dont_refine_private(
|
||||||
|
// CHECK-SAME: %[[ARG:.+]]: !torch.vtensor<[2],f32>) -> !torch.optional<vtensor<[2],f32>> {
|
||||||
|
// CHECK: %[[RES:.+]] = torch.derefine %[[ARG]] : !torch.vtensor<[2],f32> to !torch.optional<vtensor<[2],f32>>
|
||||||
|
// CHECK: return %[[RES]] : !torch.optional<vtensor<[2],f32>>
|
||||||
|
// CHECK: }
|
||||||
|
func.func private @dont_refine_private(%arg: !torch.vtensor<[2],f32>) -> !torch.optional<vtensor<[2],f32>> {
|
||||||
|
%res = torch.derefine %arg : !torch.vtensor<[2],f32> to !torch.optional<vtensor<[2],f32>>
|
||||||
|
return %res : !torch.optional<vtensor<[2],f32>>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// Call to public function.
|
// Call to public function.
|
||||||
|
|
Loading…
Reference in New Issue