From 99848265c388099f500de9eac235bf0e2c9ccc0d Mon Sep 17 00:00:00 2001 From: giacs-epic <179146510+giacs-epic@users.noreply.github.com> Date: Mon, 23 Sep 2024 06:39:29 +0000 Subject: [PATCH] [onnx] Relax constraints on input tensors in `onnx.STFT` conversion to torch dialect (#3676) - When the signal tensor is real, onnx allows its shape to be `[batch][length]` as well as `[batch][length][1]`. - Onnx also allows to specify `frame_length` together with `window` (not empty), given that it matches the window size. - Adding checks on signal and result shapes. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 63 +++++++++++++------ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 47 ++++++++++++++ 2 files changed, 92 insertions(+), 18 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 68868e95c..36c26f26c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -3591,15 +3591,34 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value signal = operands[0]; Value frameStep = operands[1]; auto signalTy = cast(signal.getType()); + if (!signalTy || !signalTy.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected signal type having sizes"); + } auto signalShape = signalTy.getSizes(); + // The infrastructure of ONNX and onnxruntime supports a rank-2. + // For reference: + // https://github.com/onnx/onnx/blob/060589cb81dfb081ed912c9e722b15fe1dbc1a14/onnx/defs/math/defs.cc#L3475-L3477 + if (signalShape.size() != 2 && signalShape.size() != 3) { + return rewriter.notifyMatchFailure(binder.op, + "signal has invalid shape."); + } + if (!resultType || !resultType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected result type having sizes"); + } auto resultShape = resultType.getSizes(); + if (resultShape.size() != 4) { + return rewriter.notifyMatchFailure(binder.op, + "result has invalid shape."); + } // There are two possible cases for optional inputs frameLength and // window, which are that either 4 operands will be passed with window // being !torch.none, or three operands will be passed, with window // present and frameLength absent. In the former case, we simply create // a rectangular window consisting of ones, and in the latter, we set - // frameLength equal to the the inputShape[-2] or windowShape[0] + // frameLength equal to the the inputShape[1] or windowShape[0] // depending upon whether window was present or not. Note that it is // possible that both window and frameLength can be none, which would // mean that either only two operands were passed, or, in case of three @@ -3618,14 +3637,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } ArrayRef windowShape; + if (!windowIsNone) { + windowShape = + cast(window.getType()).getSizes(); + if (windowShape.size() != 1) { + return rewriter.notifyMatchFailure(binder.op, + "window has invalid shape."); + } + } if (frameLengthIsNone) { if (windowIsNone) { frameLength = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr( - signalShape[signalShape.size() - 2])); + binder.getLoc(), rewriter.getI64IntegerAttr(signalShape[1])); } else { - windowShape = - cast(window.getType()).getSizes(); frameLength = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0])); } @@ -3685,19 +3709,22 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // component. This complex input has to be made torch compatible before // being passed into torch.stft, so it is necessary to call // AtenViewAsComplexOp. In case of real input, the shape of the signal - // will be [batch][length][1], and therefore it will have to be squeezed - // at dim=2, before being passed into torch.stft. - if (signalShape[2] == 2) { - signal = rewriter.create( - binder.getLoc(), complexSignalTy, signal); - } else { - Value two = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2)); - auto newSignalTy = signalTy.getWithSizesAndDtype( - ArrayRef({signalShape[0], signalShape[1]}), - signalTy.getDtype()); - signal = rewriter.create( - binder.getLoc(), newSignalTy, signal, two); + // will be [batch][length] or [batch][length][1], and therefore it will + // have to be squeezed at dim=2 in the latter case, before being passed + // into torch.stft. + if (signalShape.size() == 3) { + if (signalShape[2] == 2) { + signal = rewriter.create( + binder.getLoc(), complexSignalTy, signal); + } else { + Value two = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2)); + auto newSignalTy = signalTy.getWithSizesAndDtype( + ArrayRef({signalShape[0], signalShape[1]}), + signalTy.getDtype()); + signal = rewriter.create( + binder.getLoc(), newSignalTy, signal, two); + } } // In case the window is not given, we use frameLength diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index be14dccd4..af2a1e002 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2904,6 +2904,30 @@ func.func @test_stft(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor // ----- +// CHECK-LABEL: func.func @test_stft_real_rank2 +func.func @test_stft_real_rank2(%arg0: !torch.vtensor<[1,128],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FRAMELEN:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[NONEVAL:.*]] = torch.constant.none + // CHECK: %[[ONESSHAPE:.*]] = torch.prim.ListConstruct %[[FRAMELEN]] : (!torch.int) -> !torch.list + // CHECK: %[[ONESLIST:.*]] = torch.aten.ones %[[ONESSHAPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],f32> + // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false + // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true + // CHECK: %[[STFT:.*]] = torch.aten.stft %arg0, %[[FRAMELEN]], %[[FRAMESTEP]], %[[FRAMELEN]], %[[ONESLIST]], %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[PERMUTELIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTELIST]] : !torch.vtensor<[1,9,15],complex>, !torch.list -> !torch.vtensor<[1,15,9],complex> + // CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex> -> !torch.vtensor<[1,15,9,2],f32> + // CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.STFT"(%arg0, %arg1, %none, %arg2) : (!torch.vtensor<[1,128],f32>, !torch.vtensor<[],si64>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> + return %0 : !torch.vtensor<[1,15,9,2],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_stft_with_window func.func @test_stft_with_window(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[FRAMELEN:.*]] = torch.constant.int 16 @@ -2927,6 +2951,29 @@ func.func @test_stft_with_window(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !t // ----- +// CHECK-LABEL: func.func @test_stft_with_window_and_framelen +func.func @test_stft_with_window_and_framelen(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[16],f32>, %arg3: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FRAMELEN:.*]] = torch.aten.item %arg3 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT2_0:.*]] = torch.constant.int 2 + // CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + // CHECK: %[[WINDOWLEN:.*]] = torch.constant.int 16 + // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false + // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true + // CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[WINDOWLEN]], %arg2, %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[16],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[PERMUTEDIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTEDIMS]] : !torch.vtensor<[1,9,15],complex>, !torch.list -> !torch.vtensor<[1,15,9],complex> + // CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex> -> !torch.vtensor<[1,15,9,2],f32> + // CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32> + %0 = torch.operator "onnx.STFT"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.vtensor<[16],f32>, !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> + return %0 : !torch.vtensor<[1,15,9,2],f32> +} + +// ----- + // CHECK-LABEL: @test_reversesequence_batch func.func @test_reversesequence_batch(%arg0: !torch.vtensor<[4,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0