mirror of https://github.com/llvm/torch-mlir
[ONNX] Add OnnxToTorch support for ReverseSequence (#3495)
parent
39d1332008
commit
6d0ca499e6
|
@ -3564,4 +3564,82 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
binder.op, resultType, permutedStft);
|
binder.op, resultType, permutedStft);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"ReverseSequence", 10,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value input, sequenceLens;
|
||||||
|
int64_t batchAxis, timeAxis;
|
||||||
|
if (binder.tensorOperandAtIndex(input, 0) ||
|
||||||
|
binder.tensorOperandAtIndex(sequenceLens, 1) ||
|
||||||
|
binder.s64IntegerAttr(batchAxis, "batch_axis", 1) ||
|
||||||
|
binder.s64IntegerAttr(timeAxis, "time_axis", 0) ||
|
||||||
|
binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto inputTy = cast<Torch::ValueTensorType>(input.getType());
|
||||||
|
SmallVector<int64_t> inputShape(inputTy.getSizes());
|
||||||
|
auto dtype = resultType.getDtype();
|
||||||
|
|
||||||
|
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
||||||
|
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||||
|
Value batchAxisVal = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(batchAxis));
|
||||||
|
Value timeAxisVal = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(timeAxis));
|
||||||
|
|
||||||
|
SmallVector<int64_t> sliceShape(inputShape);
|
||||||
|
sliceShape[batchAxis] = 1;
|
||||||
|
auto sliceType =
|
||||||
|
rewriter.getType<Torch::ValueTensorType>(sliceShape, dtype);
|
||||||
|
SmallVector<int64_t> flipShape(sliceShape);
|
||||||
|
flipShape[timeAxis] = Torch::kUnknownSize;
|
||||||
|
auto flipType =
|
||||||
|
rewriter.getType<Torch::ValueTensorType>(flipShape, dtype);
|
||||||
|
auto scalarTensorType = rewriter.getType<Torch::ValueTensorType>(
|
||||||
|
ArrayRef<int64_t>{1}, rewriter.getIntegerType(64, /*signed*/ 1));
|
||||||
|
|
||||||
|
for (int i = 0; i < inputShape[batchAxis]; i++) {
|
||||||
|
// slice i iterating on batch axis
|
||||||
|
Value k = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(i));
|
||||||
|
Value end =
|
||||||
|
rewriter.create<Torch::AtenAddIntOp>(binder.getLoc(), k, cstOne);
|
||||||
|
Value sliceBatch = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||||
|
binder.getLoc(), sliceType, input, batchAxisVal, k, end, cstOne);
|
||||||
|
|
||||||
|
// get sequence length and slice the reversing part
|
||||||
|
Value kTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
|
||||||
|
binder.getLoc(), scalarTensorType, k);
|
||||||
|
Value sel = rewriter.create<Torch::AtenIndexSelectOp>(
|
||||||
|
binder.getLoc(), scalarTensorType, sequenceLens, cstZero,
|
||||||
|
kTensor);
|
||||||
|
Value len = rewriter.create<Torch::AtenItemOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::IntType>(), sel);
|
||||||
|
Value sliceTime = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||||
|
binder.getLoc(), flipType, sliceBatch, timeAxisVal, cstZero, len,
|
||||||
|
cstOne);
|
||||||
|
// flip the sliced reversing tensor
|
||||||
|
Value dims = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
rewriter.getType<Torch::ListType>(
|
||||||
|
rewriter.getType<Torch::IntType>()),
|
||||||
|
SmallVector<Value>{timeAxisVal});
|
||||||
|
Value flip = rewriter.create<Torch::AtenFlipOp>(
|
||||||
|
binder.getLoc(), flipType, sliceTime, dims);
|
||||||
|
|
||||||
|
// embeds the reversed tensor to the input
|
||||||
|
Value embedTime = rewriter.create<Torch::AtenSliceScatterOp>(
|
||||||
|
binder.getLoc(), sliceType, sliceBatch, flip, timeAxisVal,
|
||||||
|
/*start=*/cstZero, /*end=*/len, /*step=*/cstOne);
|
||||||
|
input = rewriter.create<Torch::AtenSliceScatterOp>(
|
||||||
|
binder.getLoc(), resultType, input, embedTime, batchAxisVal,
|
||||||
|
/*start=*/k, /*end=*/end, /*step=*/cstOne);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(binder.op, input);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -2663,3 +2663,115 @@ func.func @test_stft_with_window(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !t
|
||||||
%0 = torch.operator "onnx.STFT"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32>
|
%0 = torch.operator "onnx.STFT"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32>
|
||||||
return %0 : !torch.vtensor<[1,15,9,2],f32>
|
return %0 : !torch.vtensor<[1,15,9,2],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reversesequence_batch
|
||||||
|
func.func @test_reversesequence_batch(%arg0: !torch.vtensor<[4,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[C0_1:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[C0_1]], %[[C1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C0_0]], %[[C0_1]], %[[ADD]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32>
|
||||||
|
// CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %[[SLICE]], %[[C1_0]], %[[C0]], %[[ITEM]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],f32>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FLIP:.*]] = torch.aten.flip %[[SLICE_0]], %[[DIM]] : !torch.vtensor<[1,?],f32>, !torch.list<int> -> !torch.vtensor<[1,?],f32>
|
||||||
|
// CHECK: %[[EMBED:.*]] = torch.aten.slice_scatter %[[SLICE]], %[[FLIP]], %[[C1_0]], %[[C0]], %[[ITEM]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32>
|
||||||
|
// CHECK: %[[EMBED_0:.*]] = torch.aten.slice_scatter %arg0, %[[EMBED]], %[[C0_0]], %[[C0_1]], %[[ADD]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32>
|
||||||
|
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[C1_1]], %[[C1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.Tensor %[[EMBED_0]], %[[C0_0]], %[[C1_1]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32>
|
||||||
|
// CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_2:.*]] = torch.aten.slice.Tensor %[[SLICE_1]], %[[C1_0]], %[[C0]], %[[ITEM_0]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],f32>
|
||||||
|
// CHECK: %[[DIM_0:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FLIP_0:.*]] = torch.aten.flip %[[SLICE_2]], %[[DIM_0]] : !torch.vtensor<[1,?],f32>, !torch.list<int> -> !torch.vtensor<[1,?],f32>
|
||||||
|
// CHECK: %[[EMBED_1:.*]] = torch.aten.slice_scatter %[[SLICE_1]], %[[FLIP_0]], %[[C1_0]], %[[C0]], %[[ITEM_0]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32>
|
||||||
|
// CHECK: %[[EMBED_2:.*]] = torch.aten.slice_scatter %[[EMBED_0]], %[[EMBED_1]], %[[C0_0]], %[[C1_1]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32>
|
||||||
|
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[ADD_1:.*]] = torch.aten.add.int %[[C2]], %[[C1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_3:.*]] = torch.aten.slice.Tensor %[[EMBED_2]], %[[C0_0]], %[[C2]], %[[ADD_1]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32>
|
||||||
|
// CHECK: %[[INDEX_1:.*]] = torch.prim.NumToTensor.Scalar %[[C2]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[SELECT_1:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_4:.*]] = torch.aten.slice.Tensor %[[SLICE_3]], %[[C1_0]], %[[C0]], %[[ITEM_1]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],f32>
|
||||||
|
// CHECK: %[[DIM_1:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FLIP_1:.*]] = torch.aten.flip %[[SLICE_4]], %[[DIM_1]] : !torch.vtensor<[1,?],f32>, !torch.list<int> -> !torch.vtensor<[1,?],f32>
|
||||||
|
// CHECK: %[[EMBED_3:.*]] = torch.aten.slice_scatter %[[SLICE_3]], %[[FLIP_1]], %[[C1_0]], %[[C0]], %[[ITEM_1]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32>
|
||||||
|
// CHECK: %[[EMBED_4:.*]] = torch.aten.slice_scatter %[[EMBED_2]], %[[EMBED_3]], %[[C0_0]], %[[C2]], %[[ADD_1]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32>
|
||||||
|
// CHECK: %[[C3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[ADD_2:.*]] = torch.aten.add.int %[[C3]], %[[C1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_5:.*]] = torch.aten.slice.Tensor %[[EMBED_4]], %[[C0_0]], %[[C3]], %[[ADD_2]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32>
|
||||||
|
// CHECK: %[[INDEX_2:.*]] = torch.prim.NumToTensor.Scalar %[[C3]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[SELECT_2:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_6:.*]] = torch.aten.slice.Tensor %[[SLICE_5]], %[[C1_0]], %[[C0]], %[[ITEM_2]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],f32>
|
||||||
|
// CHECK: %[[DIM_2:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FLIP_2:.*]] = torch.aten.flip %[[SLICE_6]], %[[DIM_2]] : !torch.vtensor<[1,?],f32>, !torch.list<int> -> !torch.vtensor<[1,?],f32>
|
||||||
|
// CHECK: %[[EMBED_5:.*]] = torch.aten.slice_scatter %[[SLICE_5]], %[[FLIP_2]], %[[C1_0]], %[[C0]], %[[ITEM_2]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32>
|
||||||
|
// CHECK: torch.aten.slice_scatter %[[EMBED_4]], %[[EMBED_5]], %[[C0_0]], %[[C3]], %[[ADD_2]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32>
|
||||||
|
%0 = torch.operator "onnx.ReverseSequence"(%arg0, %arg1) {torch.onnx.batch_axis = 0 : si64, torch.onnx.time_axis = 1 : si64} : (!torch.vtensor<[4,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32>
|
||||||
|
return %0 : !torch.vtensor<[4,4],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_reversesequence_time
|
||||||
|
func.func @test_reversesequence_time(%arg0: !torch.vtensor<[4,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[C0_1:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[C0_1]], %[[C1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C1_0]], %[[C0_1]], %[[ADD]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32>
|
||||||
|
// CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %[[SLICE]], %[[C0_0]], %[[C0]], %[[ITEM]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1],f32>
|
||||||
|
// CHECK: %[[DIM:.*]] = torch.prim.ListConstruct %[[C0_0]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FLIP:.*]] = torch.aten.flip %[[SLICE_0]], %[[DIM]] : !torch.vtensor<[?,1],f32>, !torch.list<int> -> !torch.vtensor<[?,1],f32>
|
||||||
|
// CHECK: %[[EMBED:.*]] = torch.aten.slice_scatter %[[SLICE]], %[[FLIP]], %[[C0_0]], %[[C0]], %[[ITEM]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.vtensor<[?,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32>
|
||||||
|
// CHECK: %[[EMBED_0:.*]] = torch.aten.slice_scatter %arg0, %[[EMBED]], %[[C1_0]], %[[C0_1]], %[[ADD]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32>
|
||||||
|
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[C1_1]], %[[C1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_1:.*]] = torch.aten.slice.Tensor %[[EMBED_0]], %[[C1_0]], %[[C1_1]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32>
|
||||||
|
// CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_2:.*]] = torch.aten.slice.Tensor %[[SLICE_1]], %[[C0_0]], %[[C0]], %[[ITEM_0]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1],f32>
|
||||||
|
// CHECK: %[[DIM_0:.*]] = torch.prim.ListConstruct %[[C0_0]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FLIP_0:.*]] = torch.aten.flip %[[SLICE_2]], %[[DIM_0]] : !torch.vtensor<[?,1],f32>, !torch.list<int> -> !torch.vtensor<[?,1],f32>
|
||||||
|
// CHECK: %[[EMBED_1:.*]] = torch.aten.slice_scatter %[[SLICE_1]], %[[FLIP_0]], %[[C0_0]], %[[C0]], %[[ITEM_0]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.vtensor<[?,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32>
|
||||||
|
// CHECK: %[[EMBED_2:.*]] = torch.aten.slice_scatter %[[EMBED_0]], %[[EMBED_1]], %[[C1_0]], %[[C1_1]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32>
|
||||||
|
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[ADD_1:.*]] = torch.aten.add.int %[[C2]], %[[C1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_3:.*]] = torch.aten.slice.Tensor %[[EMBED_2]], %[[C1_0]], %[[C2]], %[[ADD_1]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32>
|
||||||
|
// CHECK: %[[INDEX_1:.*]] = torch.prim.NumToTensor.Scalar %[[C2]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[SELECT_1:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_4:.*]] = torch.aten.slice.Tensor %[[SLICE_3]], %[[C0_0]], %[[C0]], %[[ITEM_1]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1],f32>
|
||||||
|
// CHECK: %[[DIM_1:.*]] = torch.prim.ListConstruct %[[C0_0]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FLIP_1:.*]] = torch.aten.flip %[[SLICE_4]], %[[DIM_1]] : !torch.vtensor<[?,1],f32>, !torch.list<int> -> !torch.vtensor<[?,1],f32>
|
||||||
|
// CHECK: %[[EMBED_3:.*]] = torch.aten.slice_scatter %[[SLICE_3]], %[[FLIP_1]], %[[C0_0]], %[[C0]], %[[ITEM_1]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.vtensor<[?,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32>
|
||||||
|
// CHECK: %[[EMBED_4:.*]] = torch.aten.slice_scatter %[[EMBED_2]], %[[EMBED_3]], %[[C1_0]], %[[C2]], %[[ADD_1]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32>
|
||||||
|
// CHECK: %[[C3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[ADD_2:.*]] = torch.aten.add.int %[[C3]], %[[C1]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_5:.*]] = torch.aten.slice.Tensor %[[EMBED_4]], %[[C1_0]], %[[C3]], %[[ADD_2]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32>
|
||||||
|
// CHECK: %[[INDEX_2:.*]] = torch.prim.NumToTensor.Scalar %[[C3]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[SELECT_2:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
// CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
// CHECK: %[[SLICE_6:.*]] = torch.aten.slice.Tensor %[[SLICE_5]], %[[C0_0]], %[[C0]], %[[ITEM_2]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1],f32>
|
||||||
|
// CHECK: %[[DIM_2:.*]] = torch.prim.ListConstruct %[[C0_0]] : (!torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[FLIP_2:.*]] = torch.aten.flip %[[SLICE_6]], %[[DIM_2]] : !torch.vtensor<[?,1],f32>, !torch.list<int> -> !torch.vtensor<[?,1],f32>
|
||||||
|
// CHECK: %[[EMBED_5:.*]] = torch.aten.slice_scatter %[[SLICE_5]], %[[FLIP_2]], %[[C0_0]], %[[C0]], %[[ITEM_2]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.vtensor<[?,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32>
|
||||||
|
// CHECK: torch.aten.slice_scatter %[[EMBED_4]], %[[EMBED_5]], %[[C1_0]], %[[C3]], %[[ADD_2]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32>
|
||||||
|
%0 = torch.operator "onnx.ReverseSequence"(%arg0, %arg1) {torch.onnx.batch_axis = 1 : si64, torch.onnx.time_axis = 0 : si64} : (!torch.vtensor<[4,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32>
|
||||||
|
return %0 : !torch.vtensor<[4,4],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue