mirror of https://github.com/llvm/torch-mlir
[ONNX] Add OnnxToTorch lowering for Optional, OptionalGetElement op (#3467)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3470/head
parent
676fa8cc09
commit
822d763308
|
@ -45,6 +45,18 @@ struct OpBinder {
|
|||
return success();
|
||||
}
|
||||
|
||||
// Binds the op's single operand when its type is
// !torch.optional<...> wrapping a valid tensor type.
// Fails if the operand count is not exactly one, the type is not an
// optional, or the contained type is not a usable tensor type.
ParseResult optionalTensorOperand(Value &value0) {
  if (op->getNumOperands() != 1)
    return failure();
  value0 = op->getOperand(0);
  auto optType = dyn_cast<Torch::OptionalType>(value0.getType());
  if (!optType || !toValidTensorType(optType.getContainedType()))
    return failure();
  return success();
}
|
||||
|
||||
ParseResult tensorOperands(Value &value0, Value &value1) {
|
||||
if (op->getNumOperands() != 2)
|
||||
return failure();
|
||||
|
@ -110,6 +122,21 @@ struct OpBinder {
|
|||
return success();
|
||||
}
|
||||
|
||||
// Binds the op's single operand when its type is
// !torch.optional<list<...>> whose list element is a valid tensor type.
// Fails on any other operand count or type shape.
ParseResult optionalTensorListOperand(Value &value0) {
  if (op->getNumOperands() != 1)
    return failure();
  value0 = op->getOperand(0);
  auto optType = dyn_cast<Torch::OptionalType>(value0.getType());
  if (!optType)
    return failure();
  auto listType = dyn_cast<Torch::ListType>(optType.getContainedType());
  if (!listType || !toValidTensorType(listType.getContainedType()))
    return failure();
  return success();
}
|
||||
|
||||
ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) {
|
||||
if (idx >= op->getNumOperands())
|
||||
return failure();
|
||||
|
@ -144,6 +171,44 @@ struct OpBinder {
|
|||
return success();
|
||||
}
|
||||
|
||||
// Binds the op's single result type when it is a !torch.optional type;
// fails if the op does not have exactly one result or the result is not
// an optional.
ParseResult optionalResultType(Torch::OptionalType &type0) {
  if (op->getNumResults() != 1)
    return failure();
  if (auto optType =
          dyn_cast<Torch::OptionalType>(op->getResult(0).getType())) {
    type0 = optType;
    return success();
  }
  return failure();
}
|
||||
|
||||
// Binds the op's single result type when it is !torch.optional<...>
// wrapping a valid tensor type; `type0` receives the contained tensor
// type (not the optional wrapper). Fails otherwise.
ParseResult optionalTensorResultType(Torch::ValueTensorType &type0) {
  if (op->getNumResults() != 1)
    return failure();
  auto optType =
      dyn_cast<Torch::OptionalType>(op->getResult(0).getType());
  if (!optType)
    return failure();
  if (auto tensorType = toValidTensorType(optType.getContainedType())) {
    type0 = tensorType;
    return success();
  }
  return failure();
}
|
||||
|
||||
// Binds the op's single result type when it is
// !torch.optional<list<...>> whose list element is a valid tensor type;
// `type0` receives the contained list type (not the optional wrapper).
// Fails otherwise.
ParseResult optionalTensorListResultType(Torch::ListType &type0) {
  if (op->getNumResults() != 1)
    return failure();
  auto optType =
      dyn_cast<Torch::OptionalType>(op->getResult(0).getType());
  if (!optType)
    return failure();
  auto listType = dyn_cast<Torch::ListType>(optType.getContainedType());
  if (!listType || !toValidTensorType(listType.getContainedType()))
    return failure();
  type0 = listType;
  return success();
}
|
||||
|
||||
// The importer imports Onnx.GraphProto attributes as regions attached to the
|
||||
// op.
|
||||
ParseResult getRegionAtIndex(mlir::Region *®ion, int64_t idx) {
|
||||
|
|
|
@ -2672,4 +2672,63 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
/*cudnn_enabled=*/cstFalse);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
    "Optional", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // onnx.Optional wraps a tensor or a tensor sequence into an
      // optional value; lower it to torch.derefine on the payload.
      Torch::OptionalType optType;
      Value payload;

      // A missing input element (constructing an empty optional from the
      // `type` attribute) is not supported yet.
      if (binder.getNumOperands() == 0)
        return rewriter.notifyMatchFailure(
            binder.op, "unimplemented support for missing input element");

      // The payload may be either a tensor list or a plain tensor; try
      // the list binding first, then fall back to the tensor binding.
      if (binder.tensorListOperand(payload)) {
        if (binder.tensorOperand(payload))
          return failure();
      }

      if (binder.optionalResultType(optType))
        return failure();

      rewriter.replaceOpWithNewOp<Torch::DerefineOp>(binder.op, optType,
                                                     payload);
      return success();
    });
|
||||
patterns.onOp(
    "OptionalGetElement", 15,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // onnx.OptionalGetElement unwraps an optional value. When the
      // input is already a bare tensor or tensor list it is forwarded
      // unchanged; when it is an optional, it is lowered to
      // torch.prim.unchecked_cast.
      Torch::ListType listResultType;
      Torch::ValueTensorType tensorResultType;
      Value payload;

      if (binder.tensorListResultType(listResultType)) {
        // Result is a single tensor.
        if (binder.tensorResultType(tensorResultType))
          return failure();

        if (binder.optionalTensorOperand(payload)) {
          if (binder.tensorOperand(payload))
            return failure();

          // The input is already a bare tensor: forward it.
          rewriter.replaceOp(binder.op, payload);
          return success();
        }

        // The input is an optional tensor: unwrap it.
        rewriter.replaceOpWithNewOp<Torch::PrimUncheckedCastOp>(
            binder.op, tensorResultType, payload);
        return success();
      }

      // Result is a tensor list.
      if (binder.optionalTensorListOperand(payload)) {
        if (binder.tensorListOperand(payload))
          return failure();

        // The input is already a bare tensor list: forward it.
        rewriter.replaceOp(binder.op, payload);
        return success();
      }

      // The input is an optional tensor list: unwrap it.
      rewriter.replaceOpWithNewOp<Torch::PrimUncheckedCastOp>(
          binder.op, listResultType, payload);
      return success();
    });
|
||||
}
|
||||
|
|
|
@ -1495,3 +1495,51 @@ func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>
|
|||
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
|
||||
return %0 : !torch.vtensor<[3,4,2,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// onnx.Optional wrapping a tensor sequence lowers to torch.derefine.
// CHECK-LABEL: @test_optional
func.func @test_optional(%arg0: !torch.list<vtensor<[5],f32>>) -> !torch.optional<list<vtensor<[5],f32>>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64} {
  // CHECK: %[[RESULT:.*]] = torch.derefine %arg0 : !torch.list<vtensor<[5],f32>> to !torch.optional<list<vtensor<[5],f32>>>
  // CHECK: return %[[RESULT]] : !torch.optional<list<vtensor<[5],f32>>>
  %0 = torch.operator "onnx.Optional"(%arg0) : (!torch.list<vtensor<[5],f32>>) -> !torch.optional<list<vtensor<[5],f32>>>
  return %0 : !torch.optional<list<vtensor<[5],f32>>>
}
|
||||
|
||||
// -----
|
||||
|
||||
// onnx.OptionalGetElement on an optional tensor sequence lowers to
// torch.prim.unchecked_cast.
// CHECK-LABEL: @test_optional_get_element_optional_sequence
func.func @test_optional_get_element_optional_sequence(%arg0: !torch.optional<list<vtensor<[4],si32>>>) -> !torch.list<vtensor<[4],si32>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[RESULT:.*]] = torch.prim.unchecked_cast %arg0 : !torch.optional<list<vtensor<[4],si32>>> -> !torch.list<vtensor<[4],si32>>
  // CHECK: return %[[RESULT]] : !torch.list<vtensor<[4],si32>>
  %0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.optional<list<vtensor<[4],si32>>>) -> !torch.list<vtensor<[4],si32>>
  return %0 : !torch.list<vtensor<[4],si32>>
}
|
||||
|
||||
// -----
|
||||
|
||||
// onnx.OptionalGetElement on an optional tensor lowers to
// torch.prim.unchecked_cast.
// CHECK-LABEL: @test_optional_get_element_optional_tensor
func.func @test_optional_get_element_optional_tensor(%arg0: !torch.optional<vtensor<[4],f32>>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[RESULT:.*]] = torch.prim.unchecked_cast %arg0 : !torch.optional<vtensor<[4],f32>> -> !torch.vtensor<[4],f32>
  // CHECK: return %[[RESULT]] : !torch.vtensor<[4],f32>
  %0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.optional<vtensor<[4],f32>>) -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// onnx.OptionalGetElement on a bare (non-optional) tensor sequence is a
// no-op: the input is forwarded directly, with no cast emitted.
// CHECK-LABEL: @test_optional_get_element_sequence
func.func @test_optional_get_element_sequence(%arg0: !torch.list<vtensor<[4],si32>>) -> !torch.list<vtensor<[4],si32>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: return %arg0 : !torch.list<vtensor<[4],si32>>
  %0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.list<vtensor<[4],si32>>) -> !torch.list<vtensor<[4],si32>>
  return %0 : !torch.list<vtensor<[4],si32>>
}
|
||||
|
||||
// -----
|
||||
|
||||
// onnx.OptionalGetElement on a bare (non-optional) tensor is a no-op:
// the input is forwarded directly, with no cast emitted.
// CHECK-LABEL: @test_optional_get_element_tensor
func.func @test_optional_get_element_tensor(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: return %arg0 : !torch.vtensor<[4],f32>
  %0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}
|
||||
|
|
Loading…
Reference in New Issue