Add support for the onnx.SequenceLength op. (#3362)

pull/3268/merge
Andrew Woloszyn 2024-05-17 15:17:43 -04:00 committed by GitHub
parent 2937753070
commit 513d89c16d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 52 additions and 0 deletions

View File

@ -97,6 +97,19 @@ struct OpBinder {
return success();
}
// Operand matches of different arities.
ParseResult tensorListOperand(Value &value0) {
if (op->getNumOperands() != 1)
return failure();
value0 = op->getOperand(0);
auto tt = dyn_cast<Torch::ListType>(value0.getType());
if (!tt)
return failure();
if (!toValidTensorType(tt.getContainedType()))
return failure();
return success();
}
ParseResult tensorListResultType(Torch::ListType &type0) {
if (op->getNumResults() != 1)
return failure();

View File

@ -532,6 +532,30 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, resultType, operands);
return success();
});
patterns.onOp(
"SequenceLength", 11,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
// onnx.SequenceLength takes a sequence(list) of tensors, and returns
// a zero rank tensor with the length.
Torch::ValueTensorType resultType;
Value x;
if (binder.tensorListOperand(x) || binder.tensorResultType(resultType))
return failure();
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value len = rewriter.create<Torch::AtenLenTOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), x);
// AtenLenTOp returns a torch.int, so we have to
// put that in a tensor.
rewriter.replaceOpWithNewOp<Torch::AtenTensorIntOp>(
binder.op, resultType, len, none, none, cstFalse);
return success();
});
patterns.onOp(
"Sigmoid", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;

View File

@ -2099,6 +2099,21 @@ module {
// -----
// CHECK-LABEL: func.func @test_sequence_length
module {
func.func @test_sequence_length(%arg0: !torch.list<vtensor<[?,?,?],f32>>) -> !torch.vtensor<[],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[LEN:.+]] = torch.aten.len.t %arg0 : !torch.list<vtensor<[?,?,?],f32>> -> !torch.int
// CHECK: %[[LEN_AS_TEN:.+]] = torch.aten.tensor.int %[[LEN]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
// CHECK: return %[[LEN_AS_TEN]] : !torch.vtensor<[],si64>
%0 = torch.operator "onnx.SequenceLength"(%arg0) : (!torch.list<vtensor<[?,?,?],f32>>) -> !torch.vtensor<[],si64>
return %0 : !torch.vtensor<[],si64>
}
}
// -----
// CHECK-LABEL: func.func @test_sce_mean_3d
func.func @test_sce_mean_3d(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.+]] = torch.constant.none