mirror of https://github.com/llvm/torch-mlir
[ONNX] Add OnnxToTorch support for SequenceMap (#3535)
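Lowers onnx.SequenceMap (opset 17) to a torch.prim.Loop over the input sequence: an initial result list built from torch.aten.empty.memory_format is threaded through the loop as the iteration argument, the op's body region is cloned into the loop body, sequence-typed inputs are indexed with torch.aten.__getitem__.t each iteration while tensor-typed additional inputs are captured from outside the region, and the body's terminator is replaced by torch.aten.append.t feeding torch.prim.Loop.condition.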
parent fde286f491
commit f0ce1e94ce
@@ -3359,6 +3359,87 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        rewriter.replaceOp(binder.op, inputSequence);
        return success();
      });
  patterns.onOp(
      "SequenceMap", 17,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        llvm::SmallVector<Value> operands;
        Torch::ListType resultType;
        if (binder.tensorOperandsList(operands) || operands.size() == 0 ||
            binder.tensorListResultType(resultType)) {
          return failure();
        }

        Region *bodyRegion;
        if (binder.getRegionAtIndex(bodyRegion, 0)) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Failed getting Body Region");
        }

        // Construct the initial result list; the loop below appends one
        // mapped element per iteration.
        auto resultTensorType =
            dyn_cast<Torch::ValueTensorType>(resultType.getContainedType());
        Value shapeList = createConstantIntList(binder, rewriter,
                                                resultTensorType.getSizes());
        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value self = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
            binder.op->getLoc(), resultType.getContainedType(), shapeList,
            /*dtype=*/cstNone, /*layout=*/cstNone, /*device=*/cstNone,
            /*pinMemory=*/cstNone, /*memoryFormat=*/cstNone);
        Value result = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(), resultType, llvm::SmallVector<Value>{self});

        // Create a for-like torch.prim.Loop, using the sequence length as the
        // maximum trip count.
        Value len = rewriter.create<Torch::AtenLenTOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[0]);
        auto cstTrue = rewriter.create<Torch::ConstantBoolOp>(
            binder.getLoc(), rewriter.getBoolAttr(true));
        mlir::ImplicitLocOpBuilder b(binder.getLoc(), rewriter);
        auto loop =
            b.create<Torch::PrimLoopOp>(resultType, len, cstTrue, result);
        rewriter.cloneRegionBefore(*bodyRegion, loop.getRegion(),
                                   loop.getRegion().begin());

        // The prim.Loop body expects a torch.int iteration index as its first
        // argument, followed by the loop-carried sequence. Prepend both, so
        // the cloned region's original inputs shift behind them; those inputs
        // are rewired to values from outside the region below and then erased.
        loop.getRegion().front().insertArgument(0U, resultType,
                                                binder.getLoc());
        Value sequenceArg = loop.getRegion().front().getArgument(0);
        loop.getRegion().front().insertArgument(
            0U, rewriter.getType<Torch::IntType>(), binder.getLoc());
        Value indexArg = loop.getRegion().front().getArgument(0);

        // Fetch sequence[i] (and additionalInput[i]) in each iteration:
        // sequence-typed operands are indexed per iteration, while
        // tensor-typed additional inputs are used directly.
        rewriter.setInsertionPointToStart(&loop.getRegion().front());
        for (size_t i = 0; i < operands.size(); i++) {
          Value argInput = loop.getRegion().front().getArgument(2);
          if (isa<Torch::ListType>(operands[i].getType())) {
            auto tensorType = dyn_cast<Torch::ValueTensorType>(
                dyn_cast<Torch::ListType>(operands[i].getType())
                    .getContainedType());
            Value item = rewriter.create<Torch::Aten__Getitem__TOp>(
                binder.getLoc(), tensorType, operands[i], indexArg);
            argInput.replaceAllUsesWith(item);
          } else {
            argInput.replaceAllUsesWith(operands[i]);
          }
          loop.getRegion().eraseArgument(2);
        }

        // Replace the terminator: append the body's result to the running
        // sequence and yield it through torch.prim.Loop.condition.
        PatternRewriter::InsertionGuard guard(rewriter);
        Operation *terminator = loop.getRegion().front().getTerminator();
        rewriter.setInsertionPoint(terminator);
        auto terminatorOperands = terminator->getOperands();
        Value append = rewriter.create<Torch::AtenAppendTOp>(
            binder.getLoc(), resultType, sequenceArg, terminatorOperands[0]);
        rewriter.replaceOpWithNewOp<Torch::PrimLoopConditionOp>(
            terminator, cstTrue, append);

        rewriter.replaceOp(binder.op, loop);
        return success();
      });
  patterns.onOp(
      "Upsample", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
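For context, the pattern above relies on the createConstantIntList helper to materialize the result shape. Judging from the CHECK lines in the tests below (one torch.constant.int per dimension, gathered into a !torch.list<int>), a minimal sketch of such a helper could look as follows; the exact signature and its location in the TorchOnnxToTorch utilities are assumptions here, not part of this diff:

    // Sketch only (assumed helper shape, inferred from the emitted IR):
    // build one torch.constant.int per element and collect them into a
    // !torch.list<int> via torch.prim.ListConstruct.
    Value createConstantIntList(OpBinder binder,
                                ConversionPatternRewriter &rewriter,
                                ArrayRef<int64_t> cstInput) {
      SmallVector<Value> cstValue;
      for (int64_t i : cstInput)
        cstValue.push_back(rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(i)));
      return rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(),
          Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
          cstValue);
    }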
@@ -2575,6 +2575,124 @@ func.func @test_sequence_empty() -> !torch.list<vtensor<[],f32>> attributes {tor

// -----

// CHECK-LABEL: func.func @test_sequence_map_add
func.func @test_sequence_map_add(%arg0: !torch.list<vtensor<[2,3],f32>>, %arg1: !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor<[2,3],f32>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[C2:.*]] = torch.constant.int 2
  // CHECK: %[[C3:.*]] = torch.constant.int 3
  // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C3]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
  // CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ALLOC]] : (!torch.vtensor<[2,3],f32>) -> !torch.list<vtensor<[2,3],f32>>
  // CHECK: %[[LEN:.*]] = torch.aten.len.t %arg0 : !torch.list<vtensor<[2,3],f32>> -> !torch.int
  // CHECK: %[[TRUE:.*]] = torch.constant.bool true
  // CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[LEN]], %[[TRUE]], init(%[[RESULT]]) {
  // CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[SEQ:.*]]: !torch.list<vtensor<[2,3],f32>>):
  // CHECK: %[[SAMPLE:.*]] = torch.aten.__getitem__.t %arg0, %[[ITER_NUM]] : !torch.list<vtensor<[2,3],f32>>, !torch.int -> !torch.vtensor<[2,3],f32>
  // CHECK: %[[C1:.*]] = torch.constant.int 1
  // CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[SAMPLE]], %arg1, %[[C1]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
  // CHECK: %[[APPEND:.*]] = torch.aten.append.t %[[SEQ]], %[[ADD]] : !torch.list<vtensor<[2,3],f32>>, !torch.vtensor<[2,3],f32> -> !torch.list<vtensor<[2,3],f32>>
  // CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[APPEND]] : !torch.list<vtensor<[2,3],f32>>)
  // CHECK: } : (!torch.int, !torch.bool, !torch.list<vtensor<[2,3],f32>>) -> !torch.list<vtensor<[2,3],f32>>
  // CHECK: return %[[LOOP]] : !torch.list<vtensor<[2,3],f32>>
  %0 = torch.operator "onnx.SequenceMap"(%arg0, %arg1) : (!torch.list<vtensor<[2,3],f32>>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor<[2,3],f32>> {
  ^bb0(%arg2: !torch.vtensor<[2,3],f32>, %arg3: !torch.vtensor<[2,3],f32>):
    %1 = torch.operator "onnx.Add"(%arg2, %arg3) : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32>
    torch.operator_terminator %1 : !torch.vtensor<[2,3],f32>
  }
  return %0 : !torch.list<vtensor<[2,3],f32>>
}

// -----

// CHECK-LABEL: func.func @test_sequence_map_add_sequence_variadic
func.func @test_sequence_map_add_sequence_variadic(%arg0: !torch.list<vtensor<[?],f32>>, %arg1: !torch.list<vtensor<[?],f32>>, %arg2: !torch.vtensor<[?],f32>) -> !torch.list<vtensor<[?],f32>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[NEG1:.*]] = torch.constant.int -1
  // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[NEG1]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],f32>
  // CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ALLOC]] : (!torch.vtensor<[?],f32>) -> !torch.list<vtensor<[?],f32>>
  // CHECK: %[[LEN:.*]] = torch.aten.len.t %arg0 : !torch.list<vtensor<[?],f32>> -> !torch.int
  // CHECK: %[[TRUE:.*]] = torch.constant.bool true
  // CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[LEN]], %[[TRUE]], init(%[[RESULT]]) {
  // CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[SEQ:.*]]: !torch.list<vtensor<[?],f32>>):
  // CHECK: %[[SAMPLE:.*]] = torch.aten.__getitem__.t %arg0, %[[ITER_NUM]] : !torch.list<vtensor<[?],f32>>, !torch.int -> !torch.vtensor<[?],f32>
  // CHECK: %[[ADDITION_INPUT:.*]] = torch.aten.__getitem__.t %arg1, %[[ITER_NUM]] : !torch.list<vtensor<[?],f32>>, !torch.int -> !torch.vtensor<[?],f32>
  // CHECK: %[[C1:.*]] = torch.constant.int 1
  // CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[SAMPLE]], %[[ADDITION_INPUT]], %[[C1]] : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
  // CHECK: %[[C1_0:.*]] = torch.constant.int 1
  // CHECK: %[[ADD_0:.*]] = torch.aten.add.Tensor %[[ADD]], %arg2, %[[C1_0]] : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
  // CHECK: %[[APPEND:.*]] = torch.aten.append.t %[[SEQ]], %[[ADD_0]] : !torch.list<vtensor<[?],f32>>, !torch.vtensor<[?],f32> -> !torch.list<vtensor<[?],f32>>
  // CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[APPEND]] : !torch.list<vtensor<[?],f32>>)
  // CHECK: } : (!torch.int, !torch.bool, !torch.list<vtensor<[?],f32>>) -> !torch.list<vtensor<[?],f32>>
  // CHECK: return %[[LOOP]] : !torch.list<vtensor<[?],f32>>
  %0 = torch.operator "onnx.SequenceMap"(%arg0, %arg1, %arg2) : (!torch.list<vtensor<[?],f32>>, !torch.list<vtensor<[?],f32>>, !torch.vtensor<[?],f32>) -> !torch.list<vtensor<[?],f32>> {
  ^bb0(%arg3: !torch.vtensor<[?],f32>, %arg4: !torch.vtensor<[?],f32>, %arg5: !torch.vtensor<[?],f32>):
    %1 = torch.operator "onnx.Add"(%arg3, %arg4) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32>
    %2 = torch.operator "onnx.Add"(%1, %arg5) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32>
    torch.operator_terminator %2 : !torch.vtensor<[?],f32>
  }
  return %0 : !torch.list<vtensor<[?],f32>>
}

// -----

// CHECK-LABEL: func.func @test_sequence_map_identity
func.func @test_sequence_map_identity(%arg0: !torch.list<vtensor<[?,?,?],f32>>) -> !torch.list<vtensor<[?,?,?],f32>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[NEG1:.*]] = torch.constant.int -1
  // CHECK: %[[NEG1_0:.*]] = torch.constant.int -1
  // CHECK: %[[NEG1_1:.*]] = torch.constant.int -1
  // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[NEG1]], %[[NEG1_0]], %[[NEG1_1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32>
  // CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ALLOC]] : (!torch.vtensor<[?,?,?],f32>) -> !torch.list<vtensor<[?,?,?],f32>>
  // CHECK: %[[LEN:.*]] = torch.aten.len.t %arg0 : !torch.list<vtensor<[?,?,?],f32>> -> !torch.int
  // CHECK: %[[TRUE:.*]] = torch.constant.bool true
  // CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[LEN]], %[[TRUE]], init(%[[RESULT]]) {
  // CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[SEQ:.*]]: !torch.list<vtensor<[?,?,?],f32>>):
  // CHECK: %[[SAMPLE:.*]] = torch.aten.__getitem__.t %arg0, %[[ITER_NUM]] : !torch.list<vtensor<[?,?,?],f32>>, !torch.int -> !torch.vtensor<[?,?,?],f32>
  // CHECK: %[[NONE_0:.*]] = torch.constant.none
  // CHECK: %[[CLONE:.*]] = torch.aten.clone %[[SAMPLE]], %[[NONE_0]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[?,?,?],f32>
  // CHECK: %[[APPEND:.*]] = torch.aten.append.t %[[SEQ]], %[[CLONE]] : !torch.list<vtensor<[?,?,?],f32>>, !torch.vtensor<[?,?,?],f32> -> !torch.list<vtensor<[?,?,?],f32>>
  // CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[APPEND]] : !torch.list<vtensor<[?,?,?],f32>>)
  // CHECK: } : (!torch.int, !torch.bool, !torch.list<vtensor<[?,?,?],f32>>) -> !torch.list<vtensor<[?,?,?],f32>>
  // CHECK: return %[[LOOP]] : !torch.list<vtensor<[?,?,?],f32>>
  %0 = torch.operator "onnx.SequenceMap"(%arg0) : (!torch.list<vtensor<[?,?,?],f32>>) -> !torch.list<vtensor<[?,?,?],f32>> {
  ^bb0(%arg1: !torch.vtensor<[?,?,?],f32>):
    %1 = torch.operator "onnx.Identity"(%arg1) : (!torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32>
    torch.operator_terminator %1 : !torch.vtensor<[?,?,?],f32>
  }
  return %0 : !torch.list<vtensor<[?,?,?],f32>>
}

// -----

// CHECK-LABEL: func.func @test_sequence_map_extract_shapes
func.func @test_sequence_map_extract_shapes(%arg0: !torch.list<vtensor<[?,?,?],f32>>) -> !torch.list<vtensor<[3],si64>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[C3:.*]] = torch.constant.int 3
  // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[C3]] : (!torch.int) -> !torch.list<int>
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
  // CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ALLOC]] : (!torch.vtensor<[3],si64>) -> !torch.list<vtensor<[3],si64>>
  // CHECK: %[[LEN:.*]] = torch.aten.len.t %arg0 : !torch.list<vtensor<[?,?,?],f32>> -> !torch.int
  // CHECK: %[[TRUE:.*]] = torch.constant.bool true
  // CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[LEN]], %[[TRUE]], init(%[[RESULT]]) {
  // CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[SEQ:.*]]: !torch.list<vtensor<[3],si64>>):
  // CHECK: %[[SAMPLE:.*]] = torch.aten.__getitem__.t %arg0, %[[ITER_NUM]] : !torch.list<vtensor<[?,?,?],f32>>, !torch.int -> !torch.vtensor<[?,?,?],f32>
  // CHECK: %[[SHAPE_0:.*]] = torch.aten._shape_as_tensor %[[SAMPLE]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[3],si64>
  // CHECK: %[[APPEND:.*]] = torch.aten.append.t %[[SEQ]], %[[SHAPE_0]] : !torch.list<vtensor<[3],si64>>, !torch.vtensor<[3],si64> -> !torch.list<vtensor<[3],si64>>
  // CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[APPEND]] : !torch.list<vtensor<[3],si64>>)
  // CHECK: } : (!torch.int, !torch.bool, !torch.list<vtensor<[3],si64>>) -> !torch.list<vtensor<[3],si64>>
  // CHECK: return %[[LOOP]] : !torch.list<vtensor<[3],si64>>
  %0 = torch.operator "onnx.SequenceMap"(%arg0) : (!torch.list<vtensor<[?,?,?],f32>>) -> !torch.list<vtensor<[3],si64>> {
  ^bb0(%arg1: !torch.vtensor<[?,?,?],f32>):
    %1 = torch.operator "onnx.Shape"(%arg1) : (!torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[3],si64>
    torch.operator_terminator %1 : !torch.vtensor<[3],si64>
  }
  return %0 : !torch.list<vtensor<[3],si64>>
}

// -----

// CHECK-LABEL: func.func @test_upsample_nearest
func.func @test_upsample_nearest(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[INT0:.*]] = torch.constant.int 0
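Note: these FileCheck tests presumably execute under the test file's existing RUN line, which by torch-mlir convention is something like "torch-mlir-opt <%s -convert-torch-onnx-to-torch -split-input-file -verify-diagnostics | FileCheck %s". The RUN line itself is outside this hunk, so the exact flags are an assumption.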