From 880e64bbbb84be0c9a674462a7897bafddef9adb Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 8 Aug 2024 16:17:38 -0700 Subject: [PATCH] [onnx] `onnx.Split` may not have `num_outputs` which can be inferred (#3608) The attribute does not exist in all variants of the operation. It can be inferred from the number of results so we should just do that. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 4 +++- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 24 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 09f923a42..e4f0e4bc0 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1695,7 +1695,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (binder.s64IntegerAttr(axis, "axis", 0)) return rewriter.notifyMatchFailure(binder.op, "Failed to get axis attribute"); - if (binder.s64IntegerAttr(numOutputs, "num_outputs", 2)) + + numOutputs = binder.op->getNumResults(); + if (binder.s64IntegerAttr(numOutputs, "num_outputs", numOutputs)) return rewriter.notifyMatchFailure( binder.op, "Failed to get num_outputs attribute"); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index ce80527dc..80a754dae 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1547,6 +1547,30 @@ func.func @test_split_2d_uneven_split_opset18(%arg0: !torch.vtensor<[2,8],f32>) return %0#0, %0#1, %0#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_split_2d_split_no_num_outputs( +// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK-DAG: %[[DIM:.+]] = torch.constant.int 1 +// CHECK-DAG: %[[SPLITS:.+]] = torch.constant.int 3 +// CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1 +// CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0 +// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[DIM]] +// CHECK-DAG: %[[ADD:.+]] = torch.aten.add.int %[[SZ1]], %[[SPLITS]] +// CHECK-DAG: %[[SUB:.+]] = torch.aten.sub.int %[[ADD]], %[[ONE]] +// CHECK-DAG: %[[SLICESZ:.+]] = torch.aten.floordiv.int %[[SUB]], %[[SPLITS]] +// CHECK-DAG: %[[START1:.+]] = torch.aten.add.int %[[ZERO]], %[[SLICESZ]] : !torch.int, !torch.int -> !torch.int +// CHECK-DAG: %[[SLICE0:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[ZERO]], %[[START1]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> +// CHECK-DAG: %[[START2:.+]] = torch.aten.add.int %[[START1]], %[[SLICESZ]] : !torch.int, !torch.int -> !torch.int +// CHECK-DAG: %[[SLICE1:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[START1]], %[[START2]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> +// CHECK-DAG: %[[SLICE2:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[START2]], %[[SZ1]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2],f32> +// CHECK: return %[[SLICE0]], %[[SLICE1]], %[[SLICE2]] +func.func @test_split_2d_split_no_num_outputs(%arg0: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0:3 = torch.operator "onnx.Split"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32> +} + // ----- // CHECK-LABEL: func.func @test_tan