[ONNX] Add OnnxToTorch lowering for OptionalHasElement op (#3472)

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/3482/head
Vivek Khandelwal 2024-06-21 11:19:00 +05:30 committed by GitHub
parent d29ad4dfbd
commit 83bfb6fb19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 144 additions and 0 deletions

View File

@ -2733,4 +2733,41 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, tensorListResultType, input);
return success();
});
patterns.onOp(
"OptionalHasElement", 15,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
if (binder.tensorResultType(resultType))
return rewriter.notifyMatchFailure(binder.op,
"result type bind failed");
Value input;
bool output;
if (!binder.tensorListOperand(input) || !binder.tensorOperand(input) ||
!binder.optionalTensorListOperand(input) ||
!binder.optionalTensorOperand(input))
output = true;
else
output = false;
Value cstOutput = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr((int64_t)output));
Value cstDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr((int)torch_upstream::ScalarType::Bool));
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(false));
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value dataList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
SmallVector<Value>{cstOutput});
rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
binder.op, resultType, dataList, /*dtype=*/cstDtype,
/*layout=*/cstNone, /*requires_grad=*/cstFalse);
return success();
});
}

View File

@ -1545,3 +1545,110 @@ func.func @test_optional_get_element_tensor(%arg0: !torch.vtensor<[4],f32>) -> !
%0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}
// -----
// CHECK-LABEL: @test_optional_has_element_empty_none_input
func.func @test_optional_has_element_empty_none_input() -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE_0]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
%none = torch.constant.none
%0 = torch.operator "onnx.OptionalHasElement"(%none) : (!torch.none) -> !torch.vtensor<[],i1>
return %0 : !torch.vtensor<[],i1>
}
// -----
// CHECK-LABEL: @test_optional_has_element_empty_no_input
func.func @test_optional_has_element_empty_no_input() -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
%0 = torch.operator "onnx.OptionalHasElement"() : () -> !torch.vtensor<[],i1>
return %0 : !torch.vtensor<[],i1>
}
// -----
// CHECK-LABEL: @test_optional_has_element_empty_optional_input
func.func @test_optional_has_element_empty_optional_input(%arg0: !torch.optional<vtensor<[],si32>>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
%0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.optional<vtensor<[],si32>>) -> !torch.vtensor<[],i1>
return %0 : !torch.vtensor<[],i1>
}
// -----
// CHECK-LABEL: @test_optional_has_element_optional_tensor_input
func.func @test_optional_has_element_optional_tensor_input(%arg0: !torch.optional<vtensor<[4],f32>>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
%0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.optional<vtensor<[4],f32>>) -> !torch.vtensor<[],i1>
return %0 : !torch.vtensor<[],i1>
}
// -----
// CHECK-LABEL: @test_optional_has_element_optional_list_tensor_input
func.func @test_optional_has_element_optional_list_tensor_input(%arg0: !torch.optional<list<vtensor<[4],f32>>>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
%0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.optional<list<vtensor<[4],f32>>>) -> !torch.vtensor<[],i1>
return %0 : !torch.vtensor<[],i1>
}
// -----
// CHECK-LABEL: @test_optional_has_element_tensor_input
func.func @test_optional_has_element_tensor_input(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
%0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[],i1>
return %0 : !torch.vtensor<[],i1>
}
// -----
// CHECK-LABEL: @test_optional_has_element_list_tensor_input
func.func @test_optional_has_element_list_tensor_input(%arg0: !torch.list<vtensor<[4],f32>>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
%0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.list<vtensor<[4],f32>>) -> !torch.vtensor<[],i1>
return %0 : !torch.vtensor<[],i1>
}