From 822d763308ac885dd626fdd1ef8f00806a2b9d78 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 18 Jun 2024 19:40:18 +0530 Subject: [PATCH] [ONNX] Add OnnxToTorch lowering for Optional, OptionalGetElement op (#3467) Signed-Off By: Vivek Khandelwal --- .../Conversion/TorchOnnxToTorch/Patterns.h | 65 +++++++++++++++++++ .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 59 +++++++++++++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 48 ++++++++++++++ 3 files changed, 172 insertions(+) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index f296b6dfe..90871110d 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -45,6 +45,18 @@ struct OpBinder { return success(); } + ParseResult optionalTensorOperand(Value &value0) { + if (op->getNumOperands() != 1) + return failure(); + value0 = op->getOperand(0); + auto ot = dyn_cast(value0.getType()); + if (!ot) + return failure(); + if (!toValidTensorType(ot.getContainedType())) + return failure(); + return success(); + } + ParseResult tensorOperands(Value &value0, Value &value1) { if (op->getNumOperands() != 2) return failure(); @@ -110,6 +122,21 @@ struct OpBinder { return success(); } + ParseResult optionalTensorListOperand(Value &value0) { + if (op->getNumOperands() != 1) + return failure(); + value0 = op->getOperand(0); + auto ot = dyn_cast(value0.getType()); + if (!ot) + return failure(); + auto tt = dyn_cast(ot.getContainedType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + return success(); + } + ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) { if (idx >= op->getNumOperands()) return failure(); @@ -144,6 +171,44 @@ struct OpBinder { return success(); } + ParseResult optionalResultType(Torch::OptionalType &type0) { + if (op->getNumResults() != 1) + return failure(); + auto ot = dyn_cast(op->getResult(0).getType()); + if (!ot) + return failure(); + type0 = ot; + return success(); + } + + ParseResult optionalTensorResultType(Torch::ValueTensorType &type0) { + if (op->getNumResults() != 1) + return failure(); + auto ot = dyn_cast(op->getResult(0).getType()); + if (!ot) + return failure(); + auto t = toValidTensorType(ot.getContainedType()); + if (!t) + return failure(); + type0 = t; + return success(); + } + + ParseResult optionalTensorListResultType(Torch::ListType &type0) { + if (op->getNumResults() != 1) + return failure(); + auto ot = dyn_cast(op->getResult(0).getType()); + if (!ot) + return failure(); + auto tt = dyn_cast(ot.getContainedType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + type0 = tt; + return success(); + } + // The importer imports Onnx.GraphProto attributes as regions attached to the // op. ParseResult getRegionAtIndex(mlir::Region *®ion, int64_t idx) { diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 6d4ea74f0..5485f931d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -2672,4 +2672,63 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( /*cudnn_enabled=*/cstFalse); return success(); }); + patterns.onOp( + "Optional", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::OptionalType resultType; + Value input; + + if (binder.getNumOperands() == 0) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented support for missing input element"); + + if (binder.tensorListOperand(input)) + if (binder.tensorOperand(input)) + return failure(); + + if (binder.optionalResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp(binder.op, resultType, + input); + return success(); + }); + patterns.onOp("OptionalGetElement", 15, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType tensorListResultType; + Torch::ValueTensorType tensorResultType; + Value input; + + if (binder.tensorListResultType(tensorListResultType)) { + if (binder.tensorResultType(tensorResultType)) + return failure(); + + if (binder.optionalTensorOperand(input)) { + if (binder.tensorOperand(input)) + return failure(); + + // It means the input is a tensor. + rewriter.replaceOp(binder.op, input); + return success(); + } + + // It means the input is an optional tensor. + rewriter.replaceOpWithNewOp( + binder.op, tensorResultType, input); + return success(); + } + + if (binder.optionalTensorListOperand(input)) { + if (binder.tensorListOperand(input)) + return failure(); + + // It means the input is a tensor list. + rewriter.replaceOp(binder.op, input); + return success(); + } + + // It means the input is an optional tensor list. + rewriter.replaceOpWithNewOp( + binder.op, tensorListResultType, input); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 19c519082..8ed1a9a91 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -1495,3 +1495,51 @@ func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32> %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> return %0 : !torch.vtensor<[3,4,2,2],f32> } + +// ----- + +// CHECK-LABEL: @test_optional +func.func @test_optional(%arg0: !torch.list>) -> !torch.optional>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64} { + // CHECK: %[[RESULT:.*]] = torch.derefine %arg0 : !torch.list> to !torch.optional>> + // CHECK: return %[[RESULT]] : !torch.optional>> + %0 = torch.operator "onnx.Optional"(%arg0) : (!torch.list>) -> !torch.optional>> + return %0 : !torch.optional>> +} + +// ----- + +// CHECK-LABEL: @test_optional_get_element_optional_sequence +func.func @test_optional_get_element_optional_sequence(%arg0: !torch.optional>>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RESULT:.*]] = torch.prim.unchecked_cast %arg0 : !torch.optional>> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.optional>>) -> !torch.list> + return %0 : !torch.list> +} + +// ----- + +// CHECK-LABEL: @test_optional_get_element_optional_tensor +func.func @test_optional_get_element_optional_tensor(%arg0: !torch.optional>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RESULT:.*]] = torch.prim.unchecked_cast %arg0 : !torch.optional> -> !torch.vtensor<[4],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[4],f32> + %0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.optional>) -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @test_optional_get_element_sequence +func.func @test_optional_get_element_sequence(%arg0: !torch.list>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: return %arg0 : !torch.list> + %0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.list>) -> !torch.list> + return %0 : !torch.list> +} + +// ----- + +// CHECK-LABEL: @test_optional_get_element_tensor +func.func @test_optional_get_element_tensor(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: return %arg0 : !torch.vtensor<[4],f32> + %0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +}