[onnx] Relax constraints on input tensors in `onnx.STFT` conversion to torch dialect (#3676)

- When the signal tensor is real, onnx allows its shape to be
`[batch][length]` as well as `[batch][length][1]`.
- Onnx also allows to specify `frame_length` together with `window` (not
empty), given that it matches the window size.
- Adding checks on signal and result shapes.
pull/3644/merge
giacs-epic 2024-09-23 06:39:29 +00:00 committed by GitHub
parent 3f79a2982a
commit 99848265c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 92 additions and 18 deletions

View File

@ -3591,15 +3591,34 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value signal = operands[0];
Value frameStep = operands[1];
auto signalTy = cast<Torch::ValueTensorType>(signal.getType());
if (!signalTy || !signalTy.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected signal type having sizes");
}
auto signalShape = signalTy.getSizes();
// The infrastructure of ONNX and onnxruntime supports a rank-2.
// For reference:
// https://github.com/onnx/onnx/blob/060589cb81dfb081ed912c9e722b15fe1dbc1a14/onnx/defs/math/defs.cc#L3475-L3477
if (signalShape.size() != 2 && signalShape.size() != 3) {
return rewriter.notifyMatchFailure(binder.op,
"signal has invalid shape.");
}
if (!resultType || !resultType.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected result type having sizes");
}
auto resultShape = resultType.getSizes();
if (resultShape.size() != 4) {
return rewriter.notifyMatchFailure(binder.op,
"result has invalid shape.");
}
// There are two possible cases for optional inputs frameLength and
// window, which are that either 4 operands will be passed with window
// being !torch.none, or three operands will be passed, with window
// present and frameLength absent. In the former case, we simply create
// a rectangular window consisting of ones, and in the latter, we set
// frameLength equal to the the inputShape[-2] or windowShape[0]
// frameLength equal to the the inputShape[1] or windowShape[0]
// depending upon whether window was present or not. Note that it is
// possible that both window and frameLength can be none, which would
// mean that either only two operands were passed, or, in case of three
@ -3618,14 +3637,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
}
ArrayRef<int64_t> windowShape;
if (!windowIsNone) {
windowShape =
cast<Torch::ValueTensorType>(window.getType()).getSizes();
if (windowShape.size() != 1) {
return rewriter.notifyMatchFailure(binder.op,
"window has invalid shape.");
}
}
if (frameLengthIsNone) {
if (windowIsNone) {
frameLength = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
signalShape[signalShape.size() - 2]));
binder.getLoc(), rewriter.getI64IntegerAttr(signalShape[1]));
} else {
windowShape =
cast<Torch::ValueTensorType>(window.getType()).getSizes();
frameLength = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0]));
}
@ -3685,19 +3709,22 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
// component. This complex input has to be made torch compatible before
// being passed into torch.stft, so it is necessary to call
// AtenViewAsComplexOp. In case of real input, the shape of the signal
// will be [batch][length][1], and therefore it will have to be squeezed
// at dim=2, before being passed into torch.stft.
if (signalShape[2] == 2) {
signal = rewriter.create<Torch::AtenViewAsComplexOp>(
binder.getLoc(), complexSignalTy, signal);
} else {
Value two = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2));
auto newSignalTy = signalTy.getWithSizesAndDtype(
ArrayRef<int64_t>({signalShape[0], signalShape[1]}),
signalTy.getDtype());
signal = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(), newSignalTy, signal, two);
// will be [batch][length] or [batch][length][1], and therefore it will
// have to be squeezed at dim=2 in the latter case, before being passed
// into torch.stft.
if (signalShape.size() == 3) {
if (signalShape[2] == 2) {
signal = rewriter.create<Torch::AtenViewAsComplexOp>(
binder.getLoc(), complexSignalTy, signal);
} else {
Value two = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2));
auto newSignalTy = signalTy.getWithSizesAndDtype(
ArrayRef<int64_t>({signalShape[0], signalShape[1]}),
signalTy.getDtype());
signal = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(), newSignalTy, signal, two);
}
}
// In case the window is not given, we use frameLength

View File

@ -2904,6 +2904,30 @@ func.func @test_stft(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor
// -----
// CHECK-LABEL: func.func @test_stft_real_rank2
func.func @test_stft_real_rank2(%arg0: !torch.vtensor<[1,128],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[FRAMELEN:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[NONEVAL:.*]] = torch.constant.none
// CHECK: %[[ONESSHAPE:.*]] = torch.prim.ListConstruct %[[FRAMELEN]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[ONESLIST:.*]] = torch.aten.ones %[[ONESSHAPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],f32>
// CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false
// CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
// CHECK: %[[STFT:.*]] = torch.aten.stft %arg0, %[[FRAMELEN]], %[[FRAMESTEP]], %[[FRAMELEN]], %[[ONESLIST]], %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex<f32>>
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT2_1:.*]] = torch.constant.int 2
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[PERMUTELIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTELIST]] : !torch.vtensor<[1,9,15],complex<f32>>, !torch.list<int> -> !torch.vtensor<[1,15,9],complex<f32>>
// CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex<f32>> -> !torch.vtensor<[1,15,9,2],f32>
// CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32>
%none = torch.constant.none
%0 = torch.operator "onnx.STFT"(%arg0, %arg1, %none, %arg2) : (!torch.vtensor<[1,128],f32>, !torch.vtensor<[],si64>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32>
return %0 : !torch.vtensor<[1,15,9,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_stft_with_window
func.func @test_stft_with_window(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[FRAMELEN:.*]] = torch.constant.int 16
@ -2927,6 +2951,29 @@ func.func @test_stft_with_window(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !t
// -----
// CHECK-LABEL: func.func @test_stft_with_window_and_framelen
func.func @test_stft_with_window_and_framelen(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[16],f32>, %arg3: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[FRAMELEN:.*]] = torch.aten.item %arg3 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT2_0:.*]] = torch.constant.int 2
// CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
// CHECK: %[[WINDOWLEN:.*]] = torch.constant.int 16
// CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false
// CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
// CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[WINDOWLEN]], %arg2, %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[16],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex<f32>>
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT2_1:.*]] = torch.constant.int 2
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[PERMUTEDIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTEDIMS]] : !torch.vtensor<[1,9,15],complex<f32>>, !torch.list<int> -> !torch.vtensor<[1,15,9],complex<f32>>
// CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex<f32>> -> !torch.vtensor<[1,15,9,2],f32>
// CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32>
%0 = torch.operator "onnx.STFT"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.vtensor<[16],f32>, !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32>
return %0 : !torch.vtensor<[1,15,9,2],f32>
}
// -----
// CHECK-LABEL: @test_reversesequence_batch
func.func @test_reversesequence_batch(%arg0: !torch.vtensor<[4,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C0:.*]] = torch.constant.int 0