diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 547170cd5..0c7955b1e 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -2733,4 +2733,41 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, tensorListResultType, input); return success(); }); + patterns.onOp( + "OptionalHasElement", 15, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + if (binder.tensorResultType(resultType)) + return rewriter.notifyMatchFailure(binder.op, + "result type bind failed"); + + Value input; + bool output; + if (!binder.tensorListOperand(input) || !binder.tensorOperand(input) || + !binder.optionalTensorListOperand(input) || + !binder.optionalTensorOperand(input)) + output = true; + else + output = false; + + Value cstOutput = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr((int64_t)output)); + Value cstDtype = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr((int)torch_upstream::ScalarType::Bool)); + Value cstFalse = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(false)); + Value cstNone = rewriter.create(binder.getLoc()); + + Value dataList = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{cstOutput}); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, dataList, /*dtype=*/cstDtype, + /*layout=*/cstNone, /*requires_grad=*/cstFalse); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index be07ac634..c60ac654f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -1545,3 +1545,110 @@ func.func @test_optional_get_element_tensor(%arg0: !torch.vtensor<[4],f32>) -> ! %0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> return %0 : !torch.vtensor<[4],f32> } + +// ----- + +// CHECK-LABEL: @test_optional_has_element_empty_none_input +func.func @test_optional_has_element_empty_none_input() -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE_0:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE_0]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %none = torch.constant.none + %0 = torch.operator "onnx.OptionalHasElement"(%none) : (!torch.none) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +// CHECK-LABEL: @test_optional_has_element_empty_no_input +func.func @test_optional_has_element_empty_no_input() -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.OptionalHasElement"() : () -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +// CHECK-LABEL: @test_optional_has_element_empty_optional_input +func.func @test_optional_has_element_empty_optional_input(%arg0: !torch.optional>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.optional>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +// CHECK-LABEL: @test_optional_has_element_optional_tensor_input +func.func @test_optional_has_element_optional_tensor_input(%arg0: !torch.optional>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.optional>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +// CHECK-LABEL: @test_optional_has_element_optional_list_tensor_input +func.func @test_optional_has_element_optional_list_tensor_input(%arg0: !torch.optional>>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.optional>>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +// CHECK-LABEL: @test_optional_has_element_tensor_input +func.func @test_optional_has_element_tensor_input(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +// CHECK-LABEL: @test_optional_has_element_list_tensor_input +func.func @test_optional_has_element_list_tensor_input(%arg0: !torch.list>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.list>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +}