mirror of https://github.com/llvm/torch-mlir
Implementation of SplitToSequence ops lowering (#3509)
Added support for SplitToSequence op lowering, with a FileCheck test case.
parent b6e4725259
commit a211ccbcff
@@ -4016,4 +4016,83 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
            binder.op, resultType, scatter, constZero, unflattenSizeList);
        return success();
      });
  // split to sequence
  // Arguments:
  // - input: the tensor to split
  // - split (optional): length of each output
  // Attributes:
  // - axis: the axis along which to split the input
  // - keepdims: whether to keep the split dimension. Ignored when 'split'
  //   is specified.
  // Outputs:
  // - outputs: sequence of tensors
  //
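  // For example (a sketch of the mapping implemented below, using shapes
  // from the tests): a [4,6] input with a two-element split tensor lowers
  // to torch.aten.split.sizes, producing a sequence of [2,6] tensors, while
  // a scalar or single-element split lowers to torch.aten.split.Tensor.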

  patterns.onOp(
      "SplitToSequence", 11,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Value self;
        Value split;
        int64_t axis;
        int64_t keepdims;
        Torch::ListType resultType;

        if (binder.op->getNumOperands() == 1)
          return rewriter.notifyMatchFailure(
              binder.op, "expected two operands; handling of the omitted "
                         "'split' input (which involves the 'keepdims' "
                         "attribute) is not yet implemented");

        if (binder.tensorOperandAtIndex(self, 0) ||
            binder.tensorListResultType(resultType) ||
            binder.s64IntegerAttr(keepdims, "keepdims", 1) ||
            binder.tensorOperandAtIndex(split, 1) ||
            binder.s64IntegerAttr(axis, "axis", 0))
          return rewriter.notifyMatchFailure(
              binder.op,
              "failed to bind operands or attributes for SplitToSequence");

        Value axisValue = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getI64IntegerAttr(axis));
        auto splitTy = cast<Torch::ValueTensorType>(split.getType());

        if (!splitTy || !splitTy.hasSizes())
          return failure();

        auto splitSizes = splitTy.getSizes();
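        // Note: despite the name, splitDim is the rank of the 'split'
        // tensor (0 for a scalar, 1 for a 1-D list of sizes), not an axis.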
        unsigned splitDim = splitTy.getSizes().size();

        if (splitDim > 1)
          return rewriter.notifyMatchFailure(
              binder.op, "split should be a scalar or a 1-D tensor");

        if (splitDim == 1) {
          if (splitSizes[0] == Torch::kUnknownSize) {
            return rewriter.notifyMatchFailure(
                binder.op, "dynamic shapes for 'split' are not yet supported");
          } else if (splitSizes[0] <= 1) {
            // 1-D tensor with zero or one element: treat it as a scalar
            // split size.
            Value splitInt = rewriter.create<Torch::AtenItemOp>(
                binder.getLoc(), rewriter.getType<Torch::IntType>(), split);
            rewriter.replaceOpWithNewOp<Torch::AtenSplitTensorOp>(
                binder.op, resultType, self, splitInt, axisValue);
            return success();
          } else {
            // Handle multiple elements in 'split'.
            Value shapeList =
                createConstantIntList(binder, rewriter, splitSizes);
            rewriter.replaceOpWithNewOp<Torch::AtenSplitSizesOp>(
                binder.op, resultType, self, shapeList, axisValue);
            return success();
          }
        } else if (splitDim == 0) { // Handle 0-D (scalar) split tensor.
          Value splitInt = rewriter.create<Torch::AtenItemOp>(
              binder.getLoc(), rewriter.getType<Torch::IntType>(), split);
          rewriter.replaceOpWithNewOp<Torch::AtenSplitTensorOp>(
              binder.op, resultType, self, splitInt, axisValue);
          return success();
        } else {
          return rewriter.notifyMatchFailure(
              binder.op, "unsupported rank for the 'split' operand");
        }
      });
}
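
The lowering above relies on a createConstantIntList helper that is used but
not defined in this hunk. Below is a minimal sketch of what such a helper
could look like, assuming it materializes each static size as a
torch.constant.int and packs the results into a !torch.list<int>; the actual
helper lives elsewhere in the TorchOnnxToTorch sources and its exact
signature is an assumption here.

// Hypothetical sketch (assumption): not the committed definition.
static Value createConstantIntList(OpBinder binder,
                                   ConversionPatternRewriter &rewriter,
                                   ArrayRef<int64_t> cstInput) {
  SmallVector<Value> cstValue;
  for (int64_t i : cstInput) {
    // Materialize each size as a torch.constant.int value.
    cstValue.push_back(rewriter.create<Torch::ConstantIntOp>(
        binder.getLoc(), rewriter.getI64IntegerAttr(i)));
  }
  // Pack the constants into a !torch.list<int>.
  return rewriter.create<Torch::PrimListConstructOp>(
      binder.getLoc(),
      Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
      cstValue);
}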

@@ -3108,3 +3108,56 @@ func.func @test_scatternd_min(%arg0: !torch.vtensor<[4,4,4],f32>, %arg1: !torch.
  %0 = torch.operator "onnx.ScatterND"(%arg0, %arg1, %arg2) {torch.onnx.reduction = "min"} : (!torch.vtensor<[4,4,4],f32>, !torch.vtensor<[2,1],si64>, !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32>
  return %0 : !torch.vtensor<[4,4,4],f32>
}

// -----

// CHECK-LABEL: func.func @test_split_to_sequence_1
func.func @test_split_to_sequence_1(%arg0: !torch.vtensor<[3,6],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.list<vtensor<[3,6],f32>> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[VAL_0:.*]]: !torch.vtensor<[3,6],f32>
  // CHECK: %[[VAL_1:.*]]: !torch.vtensor<[1],si64>) -> !torch.list<vtensor<[3,6],f32>>
  // CHECK: %[[VAL_2:.*]] = torch.constant.none
  // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
  // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_1]] : !torch.vtensor<[1],si64> -> !torch.int
  // CHECK: %[[VAL_5:.*]] = torch.aten.split.Tensor %[[VAL_0]], %[[VAL_4]], %[[VAL_3]] : !torch.vtensor<[3,6],f32>, !torch.int, !torch.int -> !torch.list<vtensor<[3,6],f32>>
  // CHECK: return %[[VAL_5]] : !torch.list<vtensor<[3,6],f32>>
  %none = torch.constant.none
  %int1 = torch.constant.int 1
  %0 = torch.aten.item %arg1 : !torch.vtensor<[1],si64> -> !torch.int
  %1 = torch.aten.split.Tensor %arg0, %0, %int1 : !torch.vtensor<[3,6],f32>, !torch.int, !torch.int -> !torch.list<vtensor<[3,6],f32>>
  return %1 : !torch.list<vtensor<[3,6],f32>>
}

// -----

// CHECK-LABEL: func.func @test_split_to_sequence_2
func.func @test_split_to_sequence_2(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.list<vtensor<[1,6],f32>> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[VAL_0:.*]]: !torch.vtensor<[2,6],f32>
  // CHECK: %[[VAL_1:.*]]: !torch.vtensor<[],si64>) -> !torch.list<vtensor<[1,6],f32>>
  // CHECK: %[[VAL_2:.*]] = torch.constant.none
  // CHECK: %[[VAL_3:.*]] = torch.constant.int 0
  // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_1]] : !torch.vtensor<[],si64> -> !torch.int
  // CHECK: %[[VAL_5:.*]] = torch.aten.split.Tensor %[[VAL_0]], %[[VAL_4]], %[[VAL_3]] : !torch.vtensor<[2,6],f32>, !torch.int, !torch.int -> !torch.list<vtensor<[1,6],f32>>
  // CHECK: return %[[VAL_5]] : !torch.list<vtensor<[1,6],f32>>
  %none = torch.constant.none
  %int0 = torch.constant.int 0
  %0 = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
  %1 = torch.aten.split.Tensor %arg0, %0, %int0 : !torch.vtensor<[2,6],f32>, !torch.int, !torch.int -> !torch.list<vtensor<[1,6],f32>>
  return %1 : !torch.list<vtensor<[1,6],f32>>
}
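
// The 0-D case above exercises the splitDim == 0 branch of the lowering:
// torch.aten.item extracts the split size from the scalar tensor and
// torch.aten.split.Tensor produces the [1,6] chunks of the [2,6] input.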

// -----

// CHECK-LABEL: func.func @test_split_to_sequence_with_list(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,6],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2],si64>) -> !torch.list<vtensor<[2,6],f32>> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = torch.aten.split.sizes %[[VAL_0]], %[[VAL_5]], %[[VAL_3]] : !torch.vtensor<[4,6],f32>, !torch.list<int>, !torch.int -> !torch.list<vtensor<[2,6],f32>>
// CHECK: return %[[VAL_6]] : !torch.list<vtensor<[2,6],f32>>
func.func @test_split_to_sequence_with_list(%arg0: !torch.vtensor<[4,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.list<vtensor<[2,6],f32>> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %none = torch.constant.none
  %0 = torch.operator "onnx.SplitToSequence"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4,6],f32>, !torch.vtensor<[2],si64>) -> !torch.list<vtensor<[2,6],f32>>
  return %0 : !torch.list<vtensor<[2,6],f32>>
}
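
// Note: lit tests in this file are driven by a RUN line at the top of the
// file, outside this hunk; in torch-mlir it typically reads as follows
// (the exact flags are an assumption here, not part of this diff):
// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch -split-input-file -verify-diagnostics | FileCheck %s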