mirror of https://github.com/llvm/torch-mlir
[ONNX] Add OnnxToTorch lowering for OptionalHasElement op (#3472)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3482/head
parent
d29ad4dfbd
commit
83bfb6fb19
|
@ -2733,4 +2733,41 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
binder.op, tensorListResultType, input);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"OptionalHasElement", 15,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
if (binder.tensorResultType(resultType))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"result type bind failed");
|
||||
|
||||
Value input;
|
||||
bool output;
|
||||
if (!binder.tensorListOperand(input) || !binder.tensorOperand(input) ||
|
||||
!binder.optionalTensorListOperand(input) ||
|
||||
!binder.optionalTensorOperand(input))
|
||||
output = true;
|
||||
else
|
||||
output = false;
|
||||
|
||||
Value cstOutput = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr((int64_t)output));
|
||||
Value cstDtype = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getI64IntegerAttr((int)torch_upstream::ScalarType::Bool));
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
|
||||
binder.getLoc(), rewriter.getBoolAttr(false));
|
||||
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
|
||||
Value dataList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::IntType>()),
|
||||
SmallVector<Value>{cstOutput});
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
|
||||
binder.op, resultType, dataList, /*dtype=*/cstDtype,
|
||||
/*layout=*/cstNone, /*requires_grad=*/cstFalse);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
|
|
@ -1545,3 +1545,110 @@ func.func @test_optional_get_element_tensor(%arg0: !torch.vtensor<[4],f32>) -> !
|
|||
%0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32>
|
||||
return %0 : !torch.vtensor<[4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_optional_has_element_empty_none_input
|
||||
func.func @test_optional_has_element_empty_none_input() -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE_0]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
|
||||
%none = torch.constant.none
|
||||
%0 = torch.operator "onnx.OptionalHasElement"(%none) : (!torch.none) -> !torch.vtensor<[],i1>
|
||||
return %0 : !torch.vtensor<[],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_optional_has_element_empty_no_input
|
||||
func.func @test_optional_has_element_empty_no_input() -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
|
||||
%0 = torch.operator "onnx.OptionalHasElement"() : () -> !torch.vtensor<[],i1>
|
||||
return %0 : !torch.vtensor<[],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_optional_has_element_empty_optional_input
|
||||
func.func @test_optional_has_element_empty_optional_input(%arg0: !torch.optional<vtensor<[],si32>>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
|
||||
%0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.optional<vtensor<[],si32>>) -> !torch.vtensor<[],i1>
|
||||
return %0 : !torch.vtensor<[],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_optional_has_element_optional_tensor_input
|
||||
func.func @test_optional_has_element_optional_tensor_input(%arg0: !torch.optional<vtensor<[4],f32>>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
|
||||
%0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.optional<vtensor<[4],f32>>) -> !torch.vtensor<[],i1>
|
||||
return %0 : !torch.vtensor<[],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_optional_has_element_optional_list_tensor_input
|
||||
func.func @test_optional_has_element_optional_list_tensor_input(%arg0: !torch.optional<list<vtensor<[4],f32>>>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
|
||||
%0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.optional<list<vtensor<[4],f32>>>) -> !torch.vtensor<[],i1>
|
||||
return %0 : !torch.vtensor<[],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_optional_has_element_tensor_input
|
||||
func.func @test_optional_has_element_tensor_input(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
|
||||
%0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[],i1>
|
||||
return %0 : !torch.vtensor<[],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_optional_has_element_list_tensor_input
|
||||
func.func @test_optional_has_element_list_tensor_input(%arg0: !torch.list<vtensor<[4],f32>>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[DTYPE:.*]] = torch.constant.int 11
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[],i1>
|
||||
%0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.list<vtensor<[4],f32>>) -> !torch.vtensor<[],i1>
|
||||
return %0 : !torch.vtensor<[],i1>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue