mirror of https://github.com/llvm/torch-mlir
[ONNX] Add OnnxToTorch Lowering for Sequence Ops (#3425)
This commit adds the lowering for SequenceAt, SequenceEmpty, SequenceInsert, SequenceErase op Signed-Off By: Vivek Khandelwal<vivekkhandelwal1424@gmail.com>pull/3437/head
parent
689efc8917
commit
d35b6b412a
|
@ -110,6 +110,18 @@ struct OpBinder {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) {
|
||||||
|
if (idx >= op->getNumOperands())
|
||||||
|
return failure();
|
||||||
|
valueIdx = op->getOperand(idx);
|
||||||
|
auto tt = dyn_cast<Torch::ListType>(valueIdx.getType());
|
||||||
|
if (!tt)
|
||||||
|
return failure();
|
||||||
|
if (!toValidTensorType(tt.getContainedType()))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
ParseResult tensorListResultType(Torch::ListType &type0) {
|
ParseResult tensorListResultType(Torch::ListType &type0) {
|
||||||
if (op->getNumResults() != 1)
|
if (op->getNumResults() != 1)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
|
@ -3120,7 +3120,145 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
|
rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(
|
||||||
binder.op, resultType, inputLTNegLambd, inputPlusBias,
|
binder.op, resultType, inputLTNegLambd, inputPlusBias,
|
||||||
inputSubBiasOrZero);
|
inputSubBiasOrZero);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
patterns.onOp("SequenceAt", 11,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value inputSequence, position;
|
||||||
|
if (binder.tensorListOperandAtIndex(inputSequence, 0) ||
|
||||||
|
binder.tensorOperandAtIndex(position, 1) ||
|
||||||
|
binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value index = rewriter.create<Torch::AtenItemOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
position);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::Aten__Getitem__TOp>(
|
||||||
|
binder.op, resultType, inputSequence, index);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"SequenceEmpty", 11,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ListType resultType;
|
||||||
|
int64_t dtypeIntOnnx;
|
||||||
|
if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
|
||||||
|
binder.tensorListResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
std::optional<int64_t> dtypeIntTorch =
|
||||||
|
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
|
||||||
|
if (!dtypeIntTorch.has_value()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op,
|
||||||
|
"unimplemented support for the given dtype conversion");
|
||||||
|
}
|
||||||
|
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
|
||||||
|
|
||||||
|
Value shapeList = createConstantIntList(binder, rewriter, {});
|
||||||
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||||
|
|
||||||
|
Value self = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
|
||||||
|
binder.op->getLoc(), resultType.getContainedType(), shapeList,
|
||||||
|
/*dtype=*/constDtype,
|
||||||
|
/*layout=*/cstNone,
|
||||||
|
/*device=*/cstNone, /*pinMemory=*/cstNone,
|
||||||
|
/*memoryFormat=*/cstNone);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(
|
||||||
|
binder.op, resultType, llvm::SmallVector<Value>{self});
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"SequenceErase", 11,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ListType resultType;
|
||||||
|
Value inputSequence, position;
|
||||||
|
if (binder.tensorListOperandAtIndex(inputSequence, 0) ||
|
||||||
|
binder.tensorListResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value length = rewriter.create<Torch::AtenLenTOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(), inputSequence);
|
||||||
|
|
||||||
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||||
|
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||||
|
if (binder.op->getNumOperands() == 1) {
|
||||||
|
// If True, it means that the `position` arg is missing and
|
||||||
|
// the last tensor from the list has to be erased.
|
||||||
|
Value lengthMinusOne = rewriter.create<Torch::AtenSubIntOp>(
|
||||||
|
binder.getLoc(), length, cstOne);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenSliceTOp>(
|
||||||
|
binder.op, resultType, inputSequence, /*start=*/cstNone,
|
||||||
|
/*end=*/lengthMinusOne, /*step=*/cstOne);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (binder.tensorOperandAtIndex(position, 1))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value positionInt = rewriter.create<Torch::AtenItemOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(), position);
|
||||||
|
// Handling negative position value.
|
||||||
|
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
||||||
|
Value isPositionNegative = rewriter.create<Torch::AtenLtIntOp>(
|
||||||
|
binder.getLoc(), positionInt, cstZero);
|
||||||
|
isPositionNegative = rewriter.create<Torch::AtenIntBoolOp>(
|
||||||
|
binder.getLoc(), isPositionNegative);
|
||||||
|
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
|
||||||
|
binder.getLoc(), isPositionNegative, length);
|
||||||
|
positionInt = rewriter.create<Torch::AtenAddIntOp>(
|
||||||
|
binder.getLoc(), positionInt, finalOffset);
|
||||||
|
|
||||||
|
Value listBeforePosition = rewriter.create<Torch::AtenSliceTOp>(
|
||||||
|
binder.getLoc(), resultType, inputSequence, /*start=*/cstNone,
|
||||||
|
/*end=*/positionInt, /*step=*/cstOne);
|
||||||
|
Value positionPlusOne = rewriter.create<Torch::AtenAddIntOp>(
|
||||||
|
binder.getLoc(), positionInt, cstOne);
|
||||||
|
Value listAfterPosition = rewriter.create<Torch::AtenSliceTOp>(
|
||||||
|
binder.getLoc(), resultType, inputSequence,
|
||||||
|
/*start=*/positionPlusOne,
|
||||||
|
/*end=*/length, /*step=*/cstOne);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenAddTOp>(
|
||||||
|
binder.op, resultType, listBeforePosition, listAfterPosition);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"SequenceInsert", 11,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ListType resultType;
|
||||||
|
Value inputSequence, position, insertValue;
|
||||||
|
if (binder.tensorListOperandAtIndex(inputSequence, 0) ||
|
||||||
|
binder.tensorOperandAtIndex(insertValue, 1) ||
|
||||||
|
binder.tensorListResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (binder.op->getNumOperands() == 1) {
|
||||||
|
// If True, it means that the `position` arg is missing and
|
||||||
|
// the tensor has to be inserted at the end of the list.
|
||||||
|
Value length = rewriter.create<Torch::AtenLenTOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
inputSequence);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenInsertTOp>(
|
||||||
|
binder.op, inputSequence, /*idx=*/length,
|
||||||
|
/*el=*/insertValue);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (binder.tensorOperandAtIndex(position, 2))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value positionInt = rewriter.create<Torch::AtenItemOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(), position);
|
||||||
|
rewriter.create<Torch::AtenInsertTOp>(binder.getLoc(), inputSequence,
|
||||||
|
/*idx=*/positionInt,
|
||||||
|
/*el=*/insertValue);
|
||||||
|
rewriter.replaceOp(binder.op, inputSequence);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -2339,3 +2339,192 @@ func.func @test_shrink_hard(%arg0: !torch.vtensor<[5],f32>) -> !torch.vtensor<[5
|
||||||
%0 = torch.operator "onnx.Shrink"(%arg0) {torch.onnx.lambd = 1.500000e+00 : f32} : (!torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32>
|
%0 = torch.operator "onnx.Shrink"(%arg0) {torch.onnx.lambd = 1.500000e+00 : f32} : (!torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32>
|
||||||
return %0 : !torch.vtensor<[5],f32>
|
return %0 : !torch.vtensor<[5],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_sequence_at
|
||||||
|
func.func @test_sequence_at(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list<vtensor<[2,3,4],f32>> -> !torch.int
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int
|
||||||
|
// CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.none, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.list<vtensor<[2,3,4],f32>> -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], %[[ITEM_0]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int -> !torch.vtensor<[2,3,4],f32>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32>
|
||||||
|
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||||
|
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||||
|
%2 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
%3 = torch.operator "onnx.SequenceErase"(%2, %0) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
%4 = torch.operator "onnx.SequenceAt"(%3, %1) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32>
|
||||||
|
return %4 : !torch.vtensor<[2,3,4],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_sequence_insert
|
||||||
|
func.func @test_sequence_insert(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-3> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[VTENSOR_2:.*]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list<vtensor<[2,3,4],f32>> -> !torch.int
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int
|
||||||
|
// CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.none, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.list<vtensor<[2,3,4],f32>> -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: torch.aten.insert.t %[[CONCAT_LIST]], %[[ITEM_0]], %arg0 : !torch.list<vtensor<[2,3,4],f32>>, !torch.int, !torch.vtensor<[2,3,4],f32>
|
||||||
|
// CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[VTENSOR_2]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], %[[ITEM_1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int -> !torch.vtensor<[2,3,4],f32>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32>
|
||||||
|
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-3> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||||
|
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||||
|
%2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||||
|
%3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
%4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
%5 = torch.operator "onnx.SequenceInsert"(%4, %arg0, %1) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
%6 = torch.operator "onnx.SequenceAt"(%5, %2) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32>
|
||||||
|
return %6 : !torch.vtensor<[2,3,4],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_sequence_erase_at_beginning
|
||||||
|
func.func @test_sequence_erase_at_beginning(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list<vtensor<[2,3,4],f32>> -> !torch.int
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int
|
||||||
|
// CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.none, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.list<vtensor<[2,3,4],f32>> -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||||
|
%3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
%4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
return %4 : !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_sequence_erase_at_end
|
||||||
|
func.func @test_sequence_erase_at_end(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list<vtensor<[2,3,4],f32>> -> !torch.int
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int
|
||||||
|
// CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.none, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.list<vtensor<[2,3,4],f32>> -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||||
|
%3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
%4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
return %4 : !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_sequence_erase_negative_idx
|
||||||
|
func.func @test_sequence_erase_negative_idx(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-2> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list<vtensor<[2,3,4],f32>> -> !torch.int
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int
|
||||||
|
// CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.none, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.int, !torch.int, !torch.int -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list<vtensor<[2,3,4],f32>>, !torch.list<vtensor<[2,3,4],f32>> -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-2> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||||
|
%3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
%4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list<vtensor<[2,3,4],f32>>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
return %4 : !torch.list<vtensor<[2,3,4],f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_sequence_erase_empty
|
||||||
|
func.func @test_sequence_erase_empty() -> !torch.list<vtensor<[],f32>> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[SHAPE_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[EMPTY_TENSOR:.*]] = torch.aten.empty.memory_format %[[SHAPE_LIST]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %[[EMPTY_TENSOR]] : (!torch.vtensor<[],f32>) -> !torch.list<vtensor<[],f32>>
|
||||||
|
// CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list<vtensor<[],f32>> -> !torch.int
|
||||||
|
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool
|
||||||
|
// CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int
|
||||||
|
// CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE_0]], %[[POSITION]], %[[INT1]] : !torch.list<vtensor<[],f32>>, !torch.none, !torch.int, !torch.int -> !torch.list<vtensor<[],f32>>
|
||||||
|
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list<vtensor<[],f32>>, !torch.int, !torch.int, !torch.int -> !torch.list<vtensor<[],f32>>
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list<vtensor<[],f32>>, !torch.list<vtensor<[],f32>> -> !torch.list<vtensor<[],f32>>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.list<vtensor<[],f32>>
|
||||||
|
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si64>} : () -> !torch.vtensor<[],si64>
|
||||||
|
%1 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list<vtensor<[],f32>>
|
||||||
|
%4 = torch.operator "onnx.SequenceErase"(%1, %0) : (!torch.list<vtensor<[],f32>>, !torch.vtensor<[],si64>) -> !torch.list<vtensor<[],f32>>
|
||||||
|
return %4 : !torch.list<vtensor<[],f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_sequence_empty
|
||||||
|
func.func @test_sequence_empty() -> !torch.list<vtensor<[],f32>> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[SHAPE_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[EMPTY_TENSOR:.*]] = torch.aten.empty.memory_format %[[SHAPE_LIST]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[EMPTY_TENSOR]] : (!torch.vtensor<[],f32>) -> !torch.list<vtensor<[],f32>>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.list<vtensor<[],f32>>
|
||||||
|
%0 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list<vtensor<[],f32>>
|
||||||
|
return %0 : !torch.list<vtensor<[],f32>>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue