[onnx] `onnx.Split` may not have `num_outputs` which can be inferred (#3608)

The attribute does not exist in all variants of the operation. It can be
inferred from the number of results so we should just do that.
pull/3622/head
Rob Suderman 2024-08-08 16:17:38 -07:00 committed by GitHub
parent fd98476f77
commit 880e64bbbb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 1 deletions

View File

@ -1695,7 +1695,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
if (binder.s64IntegerAttr(axis, "axis", 0))
return rewriter.notifyMatchFailure(binder.op,
"Failed to get axis attribute");
if (binder.s64IntegerAttr(numOutputs, "num_outputs", 2))
numOutputs = binder.op->getNumResults();
if (binder.s64IntegerAttr(numOutputs, "num_outputs", numOutputs))
return rewriter.notifyMatchFailure(
binder.op, "Failed to get num_outputs attribute");

View File

@ -1547,6 +1547,30 @@ func.func @test_split_2d_uneven_split_opset18(%arg0: !torch.vtensor<[2,8],f32>)
return %0#0, %0#1, %0#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_split_2d_split_no_num_outputs(
// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[DIM:.+]] = torch.constant.int 1
// CHECK-DAG: %[[SPLITS:.+]] = torch.constant.int 3
// CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[DIM]]
// CHECK-DAG: %[[ADD:.+]] = torch.aten.add.int %[[SZ1]], %[[SPLITS]]
// CHECK-DAG: %[[SUB:.+]] = torch.aten.sub.int %[[ADD]], %[[ONE]]
// CHECK-DAG: %[[SLICESZ:.+]] = torch.aten.floordiv.int %[[SUB]], %[[SPLITS]]
// CHECK-DAG: %[[START1:.+]] = torch.aten.add.int %[[ZERO]], %[[SLICESZ]] : !torch.int, !torch.int -> !torch.int
// CHECK-DAG: %[[SLICE0:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[ZERO]], %[[START1]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
// CHECK-DAG: %[[START2:.+]] = torch.aten.add.int %[[START1]], %[[SLICESZ]] : !torch.int, !torch.int -> !torch.int
// CHECK-DAG: %[[SLICE1:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[START1]], %[[START2]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
// CHECK-DAG: %[[SLICE2:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[START2]], %[[SZ1]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2],f32>
// CHECK: return %[[SLICE0]], %[[SLICE1]], %[[SLICE2]]
func.func @test_split_2d_split_no_num_outputs(%arg0: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0:3 = torch.operator "onnx.Split"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>)
return %0#0, %0#1, %0#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_tan