[ONNX] Add OnnxToTorch lowering for Optional, OptionalGetElement op (#3467)

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/3470/head
Vivek Khandelwal 2024-06-18 19:40:18 +05:30 committed by GitHub
parent 676fa8cc09
commit 822d763308
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 172 additions and 0 deletions

View File

@ -45,6 +45,18 @@ struct OpBinder {
return success();
}
ParseResult optionalTensorOperand(Value &value0) {
if (op->getNumOperands() != 1)
return failure();
value0 = op->getOperand(0);
auto ot = dyn_cast<Torch::OptionalType>(value0.getType());
if (!ot)
return failure();
if (!toValidTensorType(ot.getContainedType()))
return failure();
return success();
}
ParseResult tensorOperands(Value &value0, Value &value1) {
if (op->getNumOperands() != 2)
return failure();
@ -110,6 +122,21 @@ struct OpBinder {
return success();
}
ParseResult optionalTensorListOperand(Value &value0) {
if (op->getNumOperands() != 1)
return failure();
value0 = op->getOperand(0);
auto ot = dyn_cast<Torch::OptionalType>(value0.getType());
if (!ot)
return failure();
auto tt = dyn_cast<Torch::ListType>(ot.getContainedType());
if (!tt)
return failure();
if (!toValidTensorType(tt.getContainedType()))
return failure();
return success();
}
ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) {
if (idx >= op->getNumOperands())
return failure();
@ -144,6 +171,44 @@ struct OpBinder {
return success();
}
ParseResult optionalResultType(Torch::OptionalType &type0) {
if (op->getNumResults() != 1)
return failure();
auto ot = dyn_cast<Torch::OptionalType>(op->getResult(0).getType());
if (!ot)
return failure();
type0 = ot;
return success();
}
ParseResult optionalTensorResultType(Torch::ValueTensorType &type0) {
if (op->getNumResults() != 1)
return failure();
auto ot = dyn_cast<Torch::OptionalType>(op->getResult(0).getType());
if (!ot)
return failure();
auto t = toValidTensorType(ot.getContainedType());
if (!t)
return failure();
type0 = t;
return success();
}
ParseResult optionalTensorListResultType(Torch::ListType &type0) {
if (op->getNumResults() != 1)
return failure();
auto ot = dyn_cast<Torch::OptionalType>(op->getResult(0).getType());
if (!ot)
return failure();
auto tt = dyn_cast<Torch::ListType>(ot.getContainedType());
if (!tt)
return failure();
if (!toValidTensorType(tt.getContainedType()))
return failure();
type0 = tt;
return success();
}
// The importer imports Onnx.GraphProto attributes as regions attached to the
// op.
ParseResult getRegionAtIndex(mlir::Region *&region, int64_t idx) {

View File

@ -2672,4 +2672,63 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
/*cudnn_enabled=*/cstFalse);
return success();
});
patterns.onOp(
"Optional", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::OptionalType resultType;
Value input;
if (binder.getNumOperands() == 0)
return rewriter.notifyMatchFailure(
binder.op, "unimplemented support for missing input element");
if (binder.tensorListOperand(input))
if (binder.tensorOperand(input))
return failure();
if (binder.optionalResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(binder.op, resultType,
input);
return success();
});
patterns.onOp("OptionalGetElement", 15,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ListType tensorListResultType;
Torch::ValueTensorType tensorResultType;
Value input;
if (binder.tensorListResultType(tensorListResultType)) {
if (binder.tensorResultType(tensorResultType))
return failure();
if (binder.optionalTensorOperand(input)) {
if (binder.tensorOperand(input))
return failure();
// It means the input is a tensor.
rewriter.replaceOp(binder.op, input);
return success();
}
// It means the input is an optional tensor.
rewriter.replaceOpWithNewOp<Torch::PrimUncheckedCastOp>(
binder.op, tensorResultType, input);
return success();
}
if (binder.optionalTensorListOperand(input)) {
if (binder.tensorListOperand(input))
return failure();
// It means the input is a tensor list.
rewriter.replaceOp(binder.op, input);
return success();
}
// It means the input is an optional tensor list.
rewriter.replaceOpWithNewOp<Torch::PrimUncheckedCastOp>(
binder.op, tensorListResultType, input);
return success();
});
}

View File

@ -1495,3 +1495,51 @@ func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
return %0 : !torch.vtensor<[3,4,2,2],f32>
}
// -----
// CHECK-LABEL: @test_optional
func.func @test_optional(%arg0: !torch.list<vtensor<[5],f32>>) -> !torch.optional<list<vtensor<[5],f32>>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64} {
// CHECK: %[[RESULT:.*]] = torch.derefine %arg0 : !torch.list<vtensor<[5],f32>> to !torch.optional<list<vtensor<[5],f32>>>
// CHECK: return %[[RESULT]] : !torch.optional<list<vtensor<[5],f32>>>
%0 = torch.operator "onnx.Optional"(%arg0) : (!torch.list<vtensor<[5],f32>>) -> !torch.optional<list<vtensor<[5],f32>>>
return %0 : !torch.optional<list<vtensor<[5],f32>>>
}
// -----
// CHECK-LABEL: @test_optional_get_element_optional_sequence
func.func @test_optional_get_element_optional_sequence(%arg0: !torch.optional<list<vtensor<[4],si32>>>) -> !torch.list<vtensor<[4],si32>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[RESULT:.*]] = torch.prim.unchecked_cast %arg0 : !torch.optional<list<vtensor<[4],si32>>> -> !torch.list<vtensor<[4],si32>>
// CHECK: return %[[RESULT]] : !torch.list<vtensor<[4],si32>>
%0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.optional<list<vtensor<[4],si32>>>) -> !torch.list<vtensor<[4],si32>>
return %0 : !torch.list<vtensor<[4],si32>>
}
// -----
// CHECK-LABEL: @test_optional_get_element_optional_tensor
func.func @test_optional_get_element_optional_tensor(%arg0: !torch.optional<vtensor<[4],f32>>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[RESULT:.*]] = torch.prim.unchecked_cast %arg0 : !torch.optional<vtensor<[4],f32>> -> !torch.vtensor<[4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[4],f32>
%0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.optional<vtensor<[4],f32>>) -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}
// -----
// CHECK-LABEL: @test_optional_get_element_sequence
func.func @test_optional_get_element_sequence(%arg0: !torch.list<vtensor<[4],si32>>) -> !torch.list<vtensor<[4],si32>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: return %arg0 : !torch.list<vtensor<[4],si32>>
%0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.list<vtensor<[4],si32>>) -> !torch.list<vtensor<[4],si32>>
return %0 : !torch.list<vtensor<[4],si32>>
}
// -----
// CHECK-LABEL: @test_optional_get_element_tensor
func.func @test_optional_get_element_tensor(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: return %arg0 : !torch.vtensor<[4],f32>
%0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}