[ONNX] Add OnnxToTorch Lowering for Sequence Ops (#3425)

This commit adds the lowering for SequenceAt, SequenceEmpty,
SequenceInsert, SequenceErase op

Signed-Off By: Vivek Khandelwal<vivekkhandelwal1424@gmail.com>
pull/3437/head
Vivek Khandelwal 2024-06-08 09:58:11 +05:30 committed by GitHub
parent 689efc8917
commit d35b6b412a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 339 additions and 0 deletions

View File

@ -110,6 +110,18 @@ struct OpBinder {
return success();
}
ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) {
if (idx >= op->getNumOperands())
return failure();
valueIdx = op->getOperand(idx);
auto tt = dyn_cast<Torch::ListType>(valueIdx.getType());
if (!tt)
return failure();
if (!toValidTensorType(tt.getContainedType()))
return failure();
return success();
}
ParseResult tensorListResultType(Torch::ListType &type0) {
if (op->getNumResults() != 1)
return failure();

View File

@ -3120,7 +3120,145 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
binder.op, resultType, inputLTNegLambd, inputPlusBias,
inputSubBiasOrZero);
return success();
});
patterns.onOp("SequenceAt", 11,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value inputSequence, position;
if (binder.tensorListOperandAtIndex(inputSequence, 0) ||
binder.tensorOperandAtIndex(position, 1) ||
binder.tensorResultType(resultType))
return failure();
Value index = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
position);
rewriter.replaceOpWithNewOp<Torch::Aten__Getitem__TOp>(
binder.op, resultType, inputSequence, index);
return success();
});
patterns.onOp(
"SequenceEmpty", 11,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ListType resultType;
int64_t dtypeIntOnnx;
if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
binder.tensorListResultType(resultType))
return failure();
std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
Value shapeList = createConstantIntList(binder, rewriter, {});
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value self = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
binder.op->getLoc(), resultType.getContainedType(), shapeList,
/*dtype=*/constDtype,
/*layout=*/cstNone,
/*device=*/cstNone, /*pinMemory=*/cstNone,
/*memoryFormat=*/cstNone);
rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(
binder.op, resultType, llvm::SmallVector<Value>{self});
return success();
});
patterns.onOp(
"SequenceErase", 11,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ListType resultType;
Value inputSequence, position;
if (binder.tensorListOperandAtIndex(inputSequence, 0) ||
binder.tensorListResultType(resultType))
return failure();
Value length = rewriter.create<Torch::AtenLenTOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), inputSequence);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
if (binder.op->getNumOperands() == 1) {
// If True, it means that the `position` arg is missing and
// the last tensor from the list has to be erased.
Value lengthMinusOne = rewriter.create<Torch::AtenSubIntOp>(
binder.getLoc(), length, cstOne);
rewriter.replaceOpWithNewOp<Torch::AtenSliceTOp>(
binder.op, resultType, inputSequence, /*start=*/cstNone,
/*end=*/lengthMinusOne, /*step=*/cstOne);
return success();
}
if (binder.tensorOperandAtIndex(position, 1))
return failure();
Value positionInt = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), position);
// Handling negative position value.
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
Value isPositionNegative = rewriter.create<Torch::AtenLtIntOp>(
binder.getLoc(), positionInt, cstZero);
isPositionNegative = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), isPositionNegative);
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isPositionNegative, length);
positionInt = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), positionInt, finalOffset);
Value listBeforePosition = rewriter.create<Torch::AtenSliceTOp>(
binder.getLoc(), resultType, inputSequence, /*start=*/cstNone,
/*end=*/positionInt, /*step=*/cstOne);
Value positionPlusOne = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), positionInt, cstOne);
Value listAfterPosition = rewriter.create<Torch::AtenSliceTOp>(
binder.getLoc(), resultType, inputSequence,
/*start=*/positionPlusOne,
/*end=*/length, /*step=*/cstOne);
rewriter.replaceOpWithNewOp<Torch::AtenAddTOp>(
binder.op, resultType, listBeforePosition, listAfterPosition);
return success();
});
patterns.onOp(
"SequenceInsert", 11,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ListType resultType;
Value inputSequence, position, insertValue;
if (binder.tensorListOperandAtIndex(inputSequence, 0) ||
binder.tensorOperandAtIndex(insertValue, 1) ||
binder.tensorListResultType(resultType))
return failure();
if (binder.op->getNumOperands() == 1) {
// If True, it means that the `position` arg is missing and
// the tensor has to be inserted at the end of the list.
Value length = rewriter.create<Torch::AtenLenTOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
inputSequence);
rewriter.replaceOpWithNewOp<Torch::AtenInsertTOp>(
binder.op, inputSequence, /*idx=*/length,
/*el=*/insertValue);
return success();
}
if (binder.tensorOperandAtIndex(position, 2))
return failure();
Value positionInt = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), position);
rewriter.create<Torch::AtenInsertTOp>(binder.getLoc(), inputSequence,
/*idx=*/positionInt,
/*el=*/insertValue);
rewriter.replaceOp(binder.op, inputSequence);
return success();
});
}

View File

@ -2339,3 +2339,192 @@ func.func @test_shrink_hard(%arg0: !torch.vtensor<[5],f32>) -> !torch.vtensor<[5
%0 = torch.operator "onnx.Shrink"(%arg0) {torch.onnx.lambd = 1.500000e+00 : f32} : (!torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32>
return %0 : !torch.vtensor<[5],f32>
}
// -----
// CHECK-LABEL: func.func @test_sequence_at
func.func @test_sequence_at(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list<vtensor<[2,3,4],f32>> -> !torch.int
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int
// CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.none, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.list<vtensor<[2,3,4],f32>> -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], %[[ITEM_0]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int -> !torch.vtensor<[2,3,4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32>
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<si64>} : () -> !torch.vtensor<[],si64>
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<si64>} : () -> !torch.vtensor<[],si64>
%2 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
%3 = torch.operator "onnx.SequenceErase"(%2, %0) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[2,3,4],f32>>
%4 = torch.operator "onnx.SequenceAt"(%3, %1) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32>
return %4 : !torch.vtensor<[2,3,4],f32>
}
// -----
// CHECK-LABEL: func.func @test_sequence_insert
func.func @test_sequence_insert(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-3> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[VTENSOR_2:.*]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list<vtensor<[2,3,4],f32>> -> !torch.int
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int
// CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.none, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.list<vtensor<[2,3,4],f32>> -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: torch.aten.insert.t %[[CONCAT_LIST]], %[[ITEM_0]], %arg0 : !torch.list<vtensor<[2,3,4],f32>>, !torch.int, !torch.vtensor<[2,3,4],f32>
// CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[VTENSOR_2]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], %[[ITEM_1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int -> !torch.vtensor<[2,3,4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32>
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-3> : tensor<si64>} : () -> !torch.vtensor<[],si64>
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor<si64>} : () -> !torch.vtensor<[],si64>
%2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor<si64>} : () -> !torch.vtensor<[],si64>
%3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
%4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[2,3,4],f32>>
%5 = torch.operator "onnx.SequenceInsert"(%4, %arg0, %1) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[2,3,4],f32>>
%6 = torch.operator "onnx.SequenceAt"(%5, %2) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32>
return %6 : !torch.vtensor<[2,3,4],f32>
}
// -----
// CHECK-LABEL: func.func @test_sequence_erase_at_beginning
func.func @test_sequence_erase_at_beginning(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list<vtensor<[2,3,4],f32>> -> !torch.int
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int
// CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.none, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.list<vtensor<[2,3,4],f32>> -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: return %[[RESULT]] : !torch.list<vtensor<[2,3,4],f32>>
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si64>} : () -> !torch.vtensor<[],si64>
%3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
%4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[2,3,4],f32>>
return %4 : !torch.list<vtensor<[2,3,4],f32>>
}
// -----
// CHECK-LABEL: func.func @test_sequence_erase_at_end
func.func @test_sequence_erase_at_end(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list<vtensor<[2,3,4],f32>> -> !torch.int
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int
// CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.none, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.list<vtensor<[2,3,4],f32>> -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: return %[[RESULT]] : !torch.list<vtensor<[2,3,4],f32>>
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2> : tensor<si64>} : () -> !torch.vtensor<[],si64>
%3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
%4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[2,3,4],f32>>
return %4 : !torch.list<vtensor<[2,3,4],f32>>
}
// -----
// CHECK-LABEL: func.func @test_sequence_erase_negative_idx
func.func @test_sequence_erase_negative_idx(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-2> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list<vtensor<[2,3,4],f32>> -> !torch.int
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int
// CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.none, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.list<vtensor<[2,3,4],f32>> -> !torch.list<vtensor<[2,3,4],f32>>
// CHECK: return %[[RESULT]] : !torch.list<vtensor<[2,3,4],f32>>
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-2> : tensor<si64>} : () -> !torch.vtensor<[],si64>
%3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
%4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[2,3,4],f32>>
return %4 : !torch.list<vtensor<[2,3,4],f32>>
}
// -----
// CHECK-LABEL: func.func @test_sequence_erase_empty
func.func @test_sequence_erase_empty() -> !torch.list<vtensor<[],f32>> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[SHAPE_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[EMPTY_TENSOR:.*]] = torch.aten.empty.memory_format %[[SHAPE_LIST]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %[[EMPTY_TENSOR]] : (!torch.vtensor<[],f32>) -> !torch.list<vtensor<[],f32>>
// CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list<vtensor<[],f32>> -> !torch.int
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int
// CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE_0]], %[[POSITION]], %[[INT1]] : !torch.list<vtensor<[],f32>>, !torch.none, !torch.int, !torch.int -> !torch.list<vtensor<[],f32>>
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list<vtensor<[],f32>>, !torch.int, !torch.int, !torch.int -> !torch.list<vtensor<[],f32>>
// CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list<vtensor<[],f32>>, !torch.list<vtensor<[],f32>> -> !torch.list<vtensor<[],f32>>
// CHECK: return %[[RESULT]] : !torch.list<vtensor<[],f32>>
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si64>} : () -> !torch.vtensor<[],si64>
%1 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list<vtensor<[],f32>>
%4 = torch.operator "onnx.SequenceErase"(%1, %0) : (!torch.list<vtensor<[],f32>>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[],f32>>
return %4 : !torch.list<vtensor<[],f32>>
}
// -----
// CHECK-LABEL: func.func @test_sequence_empty
func.func @test_sequence_empty() -> !torch.list<vtensor<[],f32>> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[SHAPE_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[EMPTY_TENSOR:.*]] = torch.aten.empty.memory_format %[[SHAPE_LIST]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[EMPTY_TENSOR]] : (!torch.vtensor<[],f32>) -> !torch.list<vtensor<[],f32>>
// CHECK: return %[[RESULT]] : !torch.list<vtensor<[],f32>>
%0 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list<vtensor<[],f32>>
return %0 : !torch.list<vtensor<[],f32>>
}