diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 86f2455ca..7d7d588ad 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -4016,4 +4016,83 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, scatter, constZero, unflattenSizeList); return success(); }); + // split to sequence + // Arguments: + // - input: the tensor to split + // -Split(optional): Length of each output + // Attributes: + // - axis: the axis along which to split the input + // - keepdims: to keep the split dimension or not. Ignored when 'split' is + // specified Outputs: + // - outputs: sequence of tensor + // + + patterns.onOp( + "SplitToSequence", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value self; + Value split; + int64_t axis; + int64_t keepdims; + Torch::ListType resultType; + + if (binder.op->getNumOperands() == 1) + return rewriter.notifyMatchFailure( + binder.op, "No of operands should be two.Keepdims attribute is " + "not yet implemented"); + + if (binder.tensorOperandAtIndex(self, 0) || + binder.tensorListResultType(resultType) || + binder.s64IntegerAttr(keepdims, "keepdims", 1) || + binder.tensorOperandAtIndex(split, 1) || + binder.s64IntegerAttr(axis, "axis", 0)) + return rewriter.notifyMatchFailure( + binder.op, + "Not converting to AtenSplitToSequenceOp due to inputs "); + + Value axisValue = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(axis)); + auto splitTy = cast(split.getType()); + + if (!splitTy || !splitTy.hasSizes()) + return failure(); + + auto splitSizes = splitTy.getSizes(); + unsigned splitDim = splitTy.getSizes().size(); + + if (splitDim > 1) + return rewriter.notifyMatchFailure( + binder.op, "Split should be scalar or 1-D Tensor "); + + if (splitDim == 1) { + if (splitSizes[0] == Torch::kUnknownSize) { + return rewriter.notifyMatchFailure( + binder.op, "Dynamic shapes for Split is not yet supported"); + } else if (splitSizes[0] <= + 1) { // dealing with 1/0 element in 1-D tensor + Value splitInt = rewriter.create( + binder.getLoc(), rewriter.getType(), split); + rewriter.replaceOpWithNewOp( + binder.op, resultType, self, splitInt, axisValue); + return success(); + } else { + // Handling multiple elment in split + Value shapeList = + createConstantIntList(binder, rewriter, splitSizes); + rewriter.replaceOpWithNewOp( + binder.op, resultType, self, shapeList, axisValue); + return success(); + } + } else if (splitDim == 0) { // Handle 0-D tensor + Value splitInt = rewriter.create( + binder.getLoc(), rewriter.getType(), split); + rewriter.replaceOpWithNewOp( + binder.op, resultType, self, splitInt, axisValue); + return success(); + } else { + return rewriter.notifyMatchFailure( + binder.op, "Handling of this kind of inputs is not there"); + } + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index a9b6b7c66..6541f6f55 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -3108,3 +3108,56 @@ func.func @test_scatternd_min(%arg0: !torch.vtensor<[4,4,4],f32>, %arg1: !torch. %0 = torch.operator "onnx.ScatterND"(%arg0, %arg1, %arg2) {torch.onnx.reduction = "min"} : (!torch.vtensor<[4,4,4],f32>, !torch.vtensor<[2,1],si64>, !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> return %0 : !torch.vtensor<[4,4,4],f32> } + +// ---- + +// CHECK-LABEL: func.func @test_split_to_sequence_1 +func.func @test_split_to_sequence_1(%arg0: !torch.vtensor<[3,6],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_0:.*]]: !torch.vtensor<[3,6],f32> + // CHECK: %[[VAL_1:.*]]: !torch.vtensor<[1],si64>) -> !torch.list> + // CHECK: %[[VAL_2:.*]] = torch.constant.none + // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[VAL_5:.*]] = torch.aten.split.Tensor %[[VAL_0]], %[[VAL_4]], %[[VAL_3]] : !torch.vtensor<[3,6],f32>, !torch.int, !torch.int -> !torch.list> + // CHECK: return %[[VAL_5]] : !torch.list> + %none = torch.constant.none + %int1 = torch.constant.int 1 + %0 = torch.aten.item %arg1 : !torch.vtensor<[1],si64> -> !torch.int + %1 = torch.aten.split.Tensor %arg0, %0, %int1 : !torch.vtensor<[3,6],f32>, !torch.int, !torch.int -> !torch.list> + return %1 : !torch.list> +} + +// ---- + +// CHECK-LABEL: func.func @test_split_to_sequence_2 +func.func @test_split_to_sequence_2(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_0:.*]]: !torch.vtensor<[2,6],f32> + // CHECK: %[[VAL_1:.*]]: !torch.vtensor<[],si64>) -> !torch.list> + // CHECK: %[[VAL_2:.*]] = torch.constant.none + // CHECK: %[[VAL_3:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[VAL_5:.*]] = torch.aten.split.Tensor %[[VAL_0]], %[[VAL_4]], %[[VAL_3]] : !torch.vtensor<[2,6],f32>, !torch.int, !torch.int -> !torch.list> + // CHECK: return %[[VAL_5]] : !torch.list> + %none = torch.constant.none + %int0 = torch.constant.int 0 + %0 = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + %1 = torch.aten.split.Tensor %arg0, %0, %int0 : !torch.vtensor<[2,6],f32>, !torch.int, !torch.int -> !torch.list> + return %1 : !torch.list> +} + +// ---- + +// CHECK-LABEL: func.func @test_split_to_sequence_with_list( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,6],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2],si64>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = torch.aten.split.sizes %[[VAL_0]], %[[VAL_5]], %[[VAL_3]] : !torch.vtensor<[4,6],f32>, !torch.list, !torch.int -> !torch.list> +// CHECK: return %[[VAL_6]] : !torch.list> + func.func @test_split_to_sequence_with_list(%arg0: !torch.vtensor<[4,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + %0 = torch.operator "onnx.SplitToSequence"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4,6],f32>, !torch.vtensor<[2],si64>) -> !torch.list> + return %0 : !torch.list> + }