mirror of https://github.com/llvm/torch-mlir
[ONNX] Add OnnxToTorch lowering for Optional, OptionalGetElement op (#3467)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3470/head
parent
676fa8cc09
commit
822d763308
|
@ -45,6 +45,18 @@ struct OpBinder {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ParseResult optionalTensorOperand(Value &value0) {
|
||||||
|
if (op->getNumOperands() != 1)
|
||||||
|
return failure();
|
||||||
|
value0 = op->getOperand(0);
|
||||||
|
auto ot = dyn_cast<Torch::OptionalType>(value0.getType());
|
||||||
|
if (!ot)
|
||||||
|
return failure();
|
||||||
|
if (!toValidTensorType(ot.getContainedType()))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
ParseResult tensorOperands(Value &value0, Value &value1) {
|
ParseResult tensorOperands(Value &value0, Value &value1) {
|
||||||
if (op->getNumOperands() != 2)
|
if (op->getNumOperands() != 2)
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -110,6 +122,21 @@ struct OpBinder {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ParseResult optionalTensorListOperand(Value &value0) {
|
||||||
|
if (op->getNumOperands() != 1)
|
||||||
|
return failure();
|
||||||
|
value0 = op->getOperand(0);
|
||||||
|
auto ot = dyn_cast<Torch::OptionalType>(value0.getType());
|
||||||
|
if (!ot)
|
||||||
|
return failure();
|
||||||
|
auto tt = dyn_cast<Torch::ListType>(ot.getContainedType());
|
||||||
|
if (!tt)
|
||||||
|
return failure();
|
||||||
|
if (!toValidTensorType(tt.getContainedType()))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) {
|
ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) {
|
||||||
if (idx >= op->getNumOperands())
|
if (idx >= op->getNumOperands())
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -144,6 +171,44 @@ struct OpBinder {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ParseResult optionalResultType(Torch::OptionalType &type0) {
|
||||||
|
if (op->getNumResults() != 1)
|
||||||
|
return failure();
|
||||||
|
auto ot = dyn_cast<Torch::OptionalType>(op->getResult(0).getType());
|
||||||
|
if (!ot)
|
||||||
|
return failure();
|
||||||
|
type0 = ot;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult optionalTensorResultType(Torch::ValueTensorType &type0) {
|
||||||
|
if (op->getNumResults() != 1)
|
||||||
|
return failure();
|
||||||
|
auto ot = dyn_cast<Torch::OptionalType>(op->getResult(0).getType());
|
||||||
|
if (!ot)
|
||||||
|
return failure();
|
||||||
|
auto t = toValidTensorType(ot.getContainedType());
|
||||||
|
if (!t)
|
||||||
|
return failure();
|
||||||
|
type0 = t;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult optionalTensorListResultType(Torch::ListType &type0) {
|
||||||
|
if (op->getNumResults() != 1)
|
||||||
|
return failure();
|
||||||
|
auto ot = dyn_cast<Torch::OptionalType>(op->getResult(0).getType());
|
||||||
|
if (!ot)
|
||||||
|
return failure();
|
||||||
|
auto tt = dyn_cast<Torch::ListType>(ot.getContainedType());
|
||||||
|
if (!tt)
|
||||||
|
return failure();
|
||||||
|
if (!toValidTensorType(tt.getContainedType()))
|
||||||
|
return failure();
|
||||||
|
type0 = tt;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// The importer imports Onnx.GraphProto attributes as regions attached to the
|
// The importer imports Onnx.GraphProto attributes as regions attached to the
|
||||||
// op.
|
// op.
|
||||||
ParseResult getRegionAtIndex(mlir::Region *®ion, int64_t idx) {
|
ParseResult getRegionAtIndex(mlir::Region *®ion, int64_t idx) {
|
||||||
|
|
|
@ -2672,4 +2672,63 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
/*cudnn_enabled=*/cstFalse);
|
/*cudnn_enabled=*/cstFalse);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"Optional", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::OptionalType resultType;
|
||||||
|
Value input;
|
||||||
|
|
||||||
|
if (binder.getNumOperands() == 0)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "unimplemented support for missing input element");
|
||||||
|
|
||||||
|
if (binder.tensorListOperand(input))
|
||||||
|
if (binder.tensorOperand(input))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (binder.optionalResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(binder.op, resultType,
|
||||||
|
input);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
patterns.onOp("OptionalGetElement", 15,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ListType tensorListResultType;
|
||||||
|
Torch::ValueTensorType tensorResultType;
|
||||||
|
Value input;
|
||||||
|
|
||||||
|
if (binder.tensorListResultType(tensorListResultType)) {
|
||||||
|
if (binder.tensorResultType(tensorResultType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (binder.optionalTensorOperand(input)) {
|
||||||
|
if (binder.tensorOperand(input))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// It means the input is a tensor.
|
||||||
|
rewriter.replaceOp(binder.op, input);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// It means the input is an optional tensor.
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::PrimUncheckedCastOp>(
|
||||||
|
binder.op, tensorResultType, input);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (binder.optionalTensorListOperand(input)) {
|
||||||
|
if (binder.tensorListOperand(input))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// It means the input is a tensor list.
|
||||||
|
rewriter.replaceOp(binder.op, input);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// It means the input is an optional tensor list.
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::PrimUncheckedCastOp>(
|
||||||
|
binder.op, tensorListResultType, input);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -1495,3 +1495,51 @@ func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>
|
||||||
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
|
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
|
||||||
return %0 : !torch.vtensor<[3,4,2,2],f32>
|
return %0 : !torch.vtensor<[3,4,2,2],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_optional
|
||||||
|
func.func @test_optional(%arg0: !torch.list<vtensor<[5],f32>>) -> !torch.optional<list<vtensor<[5],f32>>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64} {
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.derefine %arg0 : !torch.list<vtensor<[5],f32>> to !torch.optional<list<vtensor<[5],f32>>>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.optional<list<vtensor<[5],f32>>>
|
||||||
|
%0 = torch.operator "onnx.Optional"(%arg0) : (!torch.list<vtensor<[5],f32>>) -> !torch.optional<list<vtensor<[5],f32>>>
|
||||||
|
return %0 : !torch.optional<list<vtensor<[5],f32>>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_optional_get_element_optional_sequence
|
||||||
|
func.func @test_optional_get_element_optional_sequence(%arg0: !torch.optional<list<vtensor<[4],si32>>>) -> !torch.list<vtensor<[4],si32>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.prim.unchecked_cast %arg0 : !torch.optional<list<vtensor<[4],si32>>> -> !torch.list<vtensor<[4],si32>>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.list<vtensor<[4],si32>>
|
||||||
|
%0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.optional<list<vtensor<[4],si32>>>) -> !torch.list<vtensor<[4],si32>>
|
||||||
|
return %0 : !torch.list<vtensor<[4],si32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_optional_get_element_optional_tensor
|
||||||
|
func.func @test_optional_get_element_optional_tensor(%arg0: !torch.optional<vtensor<[4],f32>>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.prim.unchecked_cast %arg0 : !torch.optional<vtensor<[4],f32>> -> !torch.vtensor<[4],f32>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[4],f32>
|
||||||
|
%0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.optional<vtensor<[4],f32>>) -> !torch.vtensor<[4],f32>
|
||||||
|
return %0 : !torch.vtensor<[4],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_optional_get_element_sequence
|
||||||
|
func.func @test_optional_get_element_sequence(%arg0: !torch.list<vtensor<[4],si32>>) -> !torch.list<vtensor<[4],si32>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: return %arg0 : !torch.list<vtensor<[4],si32>>
|
||||||
|
%0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.list<vtensor<[4],si32>>) -> !torch.list<vtensor<[4],si32>>
|
||||||
|
return %0 : !torch.list<vtensor<[4],si32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_optional_get_element_tensor
|
||||||
|
func.func @test_optional_get_element_tensor(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: return %arg0 : !torch.vtensor<[4],f32>
|
||||||
|
%0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32>
|
||||||
|
return %0 : !torch.vtensor<[4],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue