From d35b6b412aa7252eb377967f4feb2a753ec1a7fb Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Sat, 8 Jun 2024 09:58:11 +0530 Subject: [PATCH] [ONNX] Add OnnxToTorch Lowering for Sequence Ops (#3425) This commit adds the lowering for SequenceAt, SequenceEmpty, SequenceInsert, SequenceErase op Signed-Off By: Vivek Khandelwal --- .../Conversion/TorchOnnxToTorch/Patterns.h | 12 ++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 138 +++++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 189 ++++++++++++++++++ 3 files changed, 339 insertions(+) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 0de85f4ee..f296b6dfe 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -110,6 +110,18 @@ struct OpBinder { return success(); } + ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) { + if (idx >= op->getNumOperands()) + return failure(); + valueIdx = op->getOperand(idx); + auto tt = dyn_cast(valueIdx.getType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + return success(); + } + ParseResult tensorListResultType(Torch::ListType &type0) { if (op->getNumResults() != 1) return failure(); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 1eb5bcc1c..18399aa2d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -3120,7 +3120,145 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.replaceOpWithNewOp( binder.op, resultType, inputLTNegLambd, inputPlusBias, inputSubBiasOrZero); + return success(); + }); + patterns.onOp("SequenceAt", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value inputSequence, position; + if (binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorOperandAtIndex(position, 1) || + binder.tensorResultType(resultType)) + return failure(); + Value index = rewriter.create( + binder.getLoc(), rewriter.getType(), + position); + rewriter.replaceOpWithNewOp( + binder.op, resultType, inputSequence, index); + return success(); + }); + patterns.onOp( + "SequenceEmpty", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + int64_t dtypeIntOnnx; + if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || + binder.tensorListResultType(resultType)) + return failure(); + + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + + Value shapeList = createConstantIntList(binder, rewriter, {}); + Value cstNone = rewriter.create(binder.getLoc()); + + Value self = rewriter.create( + binder.op->getLoc(), resultType.getContainedType(), shapeList, + /*dtype=*/constDtype, + /*layout=*/cstNone, + /*device=*/cstNone, /*pinMemory=*/cstNone, + /*memoryFormat=*/cstNone); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, llvm::SmallVector{self}); + return success(); + }); + patterns.onOp( + "SequenceErase", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + Value inputSequence, position; + if (binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorListResultType(resultType)) + return failure(); + + Value length = rewriter.create( + binder.getLoc(), rewriter.getType(), inputSequence); + + Value cstNone = rewriter.create(binder.getLoc()); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + if (binder.op->getNumOperands() == 1) { + // If True, it means that the `position` arg is missing and + // the last tensor from the list has to be erased. + Value lengthMinusOne = rewriter.create( + binder.getLoc(), length, cstOne); + rewriter.replaceOpWithNewOp( + binder.op, resultType, inputSequence, /*start=*/cstNone, + /*end=*/lengthMinusOne, /*step=*/cstOne); + return success(); + } + + if (binder.tensorOperandAtIndex(position, 1)) + return failure(); + + Value positionInt = rewriter.create( + binder.getLoc(), rewriter.getType(), position); + // Handling negative position value. + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value isPositionNegative = rewriter.create( + binder.getLoc(), positionInt, cstZero); + isPositionNegative = rewriter.create( + binder.getLoc(), isPositionNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isPositionNegative, length); + positionInt = rewriter.create( + binder.getLoc(), positionInt, finalOffset); + + Value listBeforePosition = rewriter.create( + binder.getLoc(), resultType, inputSequence, /*start=*/cstNone, + /*end=*/positionInt, /*step=*/cstOne); + Value positionPlusOne = rewriter.create( + binder.getLoc(), positionInt, cstOne); + Value listAfterPosition = rewriter.create( + binder.getLoc(), resultType, inputSequence, + /*start=*/positionPlusOne, + /*end=*/length, /*step=*/cstOne); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, listBeforePosition, listAfterPosition); + return success(); + }); + patterns.onOp( + "SequenceInsert", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + Value inputSequence, position, insertValue; + if (binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorOperandAtIndex(insertValue, 1) || + binder.tensorListResultType(resultType)) + return failure(); + + if (binder.op->getNumOperands() == 1) { + // If True, it means that the `position` arg is missing and + // the tensor has to be inserted at the end of the list. + Value length = rewriter.create( + binder.getLoc(), rewriter.getType(), + inputSequence); + rewriter.replaceOpWithNewOp( + binder.op, inputSequence, /*idx=*/length, + /*el=*/insertValue); + return success(); + } + + if (binder.tensorOperandAtIndex(position, 2)) + return failure(); + + Value positionInt = rewriter.create( + binder.getLoc(), rewriter.getType(), position); + rewriter.create(binder.getLoc(), inputSequence, + /*idx=*/positionInt, + /*el=*/insertValue); + rewriter.replaceOp(binder.op, inputSequence); return success(); }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index eb5a9f7ca..317a3aeb1 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2339,3 +2339,192 @@ func.func @test_shrink_hard(%arg0: !torch.vtensor<[5],f32>) -> !torch.vtensor<[5 %0 = torch.operator "onnx.Shrink"(%arg0) {torch.onnx.lambd = 1.500000e+00 : f32} : (!torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32> return %0 : !torch.vtensor<[5],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_sequence_at +func.func @test_sequence_at(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], %[[ITEM_0]] : !torch.list>, !torch.int -> !torch.vtensor<[2,3,4],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor} : () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor} : () -> !torch.vtensor<[],si64> + %2 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %3 = torch.operator "onnx.SequenceErase"(%2, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + %4 = torch.operator "onnx.SequenceAt"(%3, %1) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32> + return %4 : !torch.vtensor<[2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_insert +func.func @test_sequence_insert(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-3> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_2:.*]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: torch.aten.insert.t %[[CONCAT_LIST]], %[[ITEM_0]], %arg0 : !torch.list>, !torch.int, !torch.vtensor<[2,3,4],f32> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[VTENSOR_2]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], %[[ITEM_1]] : !torch.list>, !torch.int -> !torch.vtensor<[2,3,4],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-3> : tensor} : () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor} : () -> !torch.vtensor<[],si64> + %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + %5 = torch.operator "onnx.SequenceInsert"(%4, %arg0, %1) : (!torch.list>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[],si64>) -> !torch.list> + %6 = torch.operator "onnx.SequenceAt"(%5, %2) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32> + return %6 : !torch.vtensor<[2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_at_beginning +func.func @test_sequence_erase_at_beginning(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_at_end +func.func @test_sequence_erase_at_end(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_negative_idx +func.func @test_sequence_erase_negative_idx(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-2> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-2> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_empty +func.func @test_sequence_erase_empty() -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[INT6:.*]] = torch.constant.int 6 + // CHECK: %[[SHAPE_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[EMPTY_TENSOR:.*]] = torch.aten.empty.memory_format %[[SHAPE_LIST]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %[[EMPTY_TENSOR]] : (!torch.vtensor<[],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE_0:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE_0]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor} : () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%1, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_empty +func.func @test_sequence_empty() -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT6:.*]] = torch.constant.int 6 + // CHECK: %[[SHAPE_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[EMPTY_TENSOR:.*]] = torch.aten.empty.memory_format %[[SHAPE_LIST]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[EMPTY_TENSOR]] : (!torch.vtensor<[],f32>) -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list> + return %0 : !torch.list> +}