mirror of https://github.com/llvm/torch-mlir
[onnx] Relax constraints on input tensors in `onnx.STFT` conversion to torch dialect (#3676)
- When the signal tensor is real, ONNX allows its shape to be `[batch][length]` as well as `[batch][length][1]`.
- ONNX also allows `frame_length` to be specified together with a non-empty `window`, provided it matches the window size.
- Add checks on the signal and result shapes (summarized in the standalone sketch below the commit metadata).
parent 3f79a2982a
commit 99848265c3
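For orientation, the shape rules this commit relaxes can be modeled outside of MLIR. The following is a minimal standalone C++ sketch; isValidStftSignalShape is a hypothetical helper for illustration, not the pass's API (the conversion checks ranks inline and reports failures via rewriter.notifyMatchFailure). It is also slightly stricter than the inline rank check, folding in the trailing-dimension convention the lowering uses later.

#include <cassert>
#include <cstdint>
#include <vector>

// A real signal may be rank-2 [batch][length] or rank-3 [batch][length][1];
// a complex signal is rank-3 [batch][length][2]. Anything else is rejected.
bool isValidStftSignalShape(const std::vector<int64_t> &shape) {
  if (shape.size() == 2)
    return true; // real signal without the explicit trailing dimension
  if (shape.size() == 3)
    return shape[2] == 1 || shape[2] == 2; // trailing 1 = real, 2 = complex
  return false;
}

int main() {
  assert(isValidStftSignalShape({1, 128}));    // newly accepted rank-2 form
  assert(isValidStftSignalShape({1, 128, 1})); // real, explicit last dim
  assert(isValidStftSignalShape({1, 128, 2})); // complex signal
  assert(!isValidStftSignalShape({1, 128, 3}));
}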
@@ -3591,15 +3591,34 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         Value signal = operands[0];
         Value frameStep = operands[1];
         auto signalTy = cast<Torch::ValueTensorType>(signal.getType());
+        if (!signalTy || !signalTy.hasSizes()) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected signal type having sizes");
+        }
         auto signalShape = signalTy.getSizes();
+        // The infrastructure of ONNX and onnxruntime supports a rank-2 signal.
+        // For reference:
+        // https://github.com/onnx/onnx/blob/060589cb81dfb081ed912c9e722b15fe1dbc1a14/onnx/defs/math/defs.cc#L3475-L3477
+        if (signalShape.size() != 2 && signalShape.size() != 3) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "signal has invalid shape.");
+        }
+        if (!resultType || !resultType.hasSizes()) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected result type having sizes");
+        }
         auto resultShape = resultType.getSizes();
+        if (resultShape.size() != 4) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "result has invalid shape.");
+        }

         // There are two possible cases for optional inputs frameLength and
         // window, which are that either 4 operands will be passed with window
         // being !torch.none, or three operands will be passed, with window
         // present and frameLength absent. In the former case, we simply create
         // a rectangular window consisting of ones, and in the latter, we set
-        // frameLength equal to the the inputShape[-2] or windowShape[0]
+        // frameLength equal to the inputShape[1] or windowShape[0]
         // depending upon whether window was present or not. Note that it is
         // possible that both window and frameLength can be none, which would
         // mean that either only two operands were passed, or, in case of three
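The comment above (cut off by the diff context) enumerates how the optional window and frameLength operands combine. A hedged standalone sketch of the resulting defaulting rule, using a hypothetical effectiveFrameLength helper rather than the pass's rewriter-based code:

#include <cassert>
#include <cstdint>
#include <optional>

// Mirrors the defaulting logic: an explicit frame_length wins; otherwise the
// window size is used; if both are absent, one frame spans the whole signal.
int64_t effectiveFrameLength(std::optional<int64_t> frameLength,
                             std::optional<int64_t> windowSize,
                             int64_t signalLength) {
  if (frameLength)
    return *frameLength;
  if (windowSize)
    return *windowSize;
  return signalLength;
}

int main() {
  assert(effectiveFrameLength(std::nullopt, std::nullopt, 128) == 128);
  assert(effectiveFrameLength(std::nullopt, 16, 128) == 16);
  assert(effectiveFrameLength(32, 16, 128) == 32);
}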
@@ -3618,14 +3637,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         }

         ArrayRef<int64_t> windowShape;
+        if (!windowIsNone) {
+          windowShape =
+              cast<Torch::ValueTensorType>(window.getType()).getSizes();
+          if (windowShape.size() != 1) {
+            return rewriter.notifyMatchFailure(binder.op,
+                                               "window has invalid shape.");
+          }
+        }
         if (frameLengthIsNone) {
           if (windowIsNone) {
             frameLength = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(
-                                     signalShape[signalShape.size() - 2]));
+                binder.getLoc(), rewriter.getI64IntegerAttr(signalShape[1]));
           } else {
-            windowShape =
-                cast<Torch::ValueTensorType>(window.getType()).getSizes();
             frameLength = rewriter.create<Torch::ConstantIntOp>(
                 binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0]));
           }
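This hunk hoists the windowShape lookup so the window can be validated up front. The commit message additionally says an explicit frame_length is only legal alongside a non-empty window when the two agree; that check is not visible in this hunk, so the following is only a hedged sketch of both rules with hypothetical helper names:

#include <cassert>
#include <cstdint>
#include <vector>

// The window must be a rank-1 tensor.
bool isValidWindowShape(const std::vector<int64_t> &windowShape) {
  return windowShape.size() == 1;
}

// When both window and frame_length are supplied, they must agree in size.
bool frameLengthMatchesWindow(int64_t frameLength,
                              const std::vector<int64_t> &windowShape) {
  return isValidWindowShape(windowShape) && frameLength == windowShape[0];
}

int main() {
  assert(frameLengthMatchesWindow(16, {16})); // as in the new test below
  assert(!frameLengthMatchesWindow(8, {16}));
}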
@@ -3685,19 +3709,22 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           // component. This complex input has to be made torch compatible before
           // being passed into torch.stft, so it is necessary to call
           // AtenViewAsComplexOp. In case of real input, the shape of the signal
-          // will be [batch][length][1], and therefore it will have to be squeezed
-          // at dim=2, before being passed into torch.stft.
-          if (signalShape[2] == 2) {
-            signal = rewriter.create<Torch::AtenViewAsComplexOp>(
-                binder.getLoc(), complexSignalTy, signal);
-          } else {
-            Value two = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(2));
-            auto newSignalTy = signalTy.getWithSizesAndDtype(
-                ArrayRef<int64_t>({signalShape[0], signalShape[1]}),
-                signalTy.getDtype());
-            signal = rewriter.create<Torch::AtenSqueezeDimOp>(
-                binder.getLoc(), newSignalTy, signal, two);
+          // will be [batch][length] or [batch][length][1], and therefore it will
+          // have to be squeezed at dim=2 in the latter case, before being passed
+          // into torch.stft.
+          if (signalShape.size() == 3) {
+            if (signalShape[2] == 2) {
+              signal = rewriter.create<Torch::AtenViewAsComplexOp>(
+                  binder.getLoc(), complexSignalTy, signal);
+            } else {
+              Value two = rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), rewriter.getI64IntegerAttr(2));
+              auto newSignalTy = signalTy.getWithSizesAndDtype(
+                  ArrayRef<int64_t>({signalShape[0], signalShape[1]}),
+                  signalTy.getDtype());
+              signal = rewriter.create<Torch::AtenSqueezeDimOp>(
+                  binder.getLoc(), newSignalTy, signal, two);
+            }
           }

           // In case the window is not given, we use frameLength
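The branch above decides how the signal is normalized before torch.stft. A standalone C++ model of the shape effect follows; SignalKind and normalizeSignal are illustrative names only, since the real lowering builds AtenViewAsComplexOp/AtenSqueezeDimOp through the rewriter:

#include <cassert>
#include <cstdint>
#include <vector>

enum class SignalKind { Real, Complex };

// Rank-3 with trailing 2: reinterpret as complex (view_as_complex drops the
// last dimension). Rank-3 with trailing 1: squeeze dim=2. Rank-2: already in
// the [batch][length] form torch.stft expects, so leave it untouched.
SignalKind normalizeSignal(std::vector<int64_t> &shape) {
  if (shape.size() == 3) {
    bool isComplex = (shape[2] == 2);
    shape.pop_back(); // both branches end with a rank-2 [batch][length] shape
    return isComplex ? SignalKind::Complex : SignalKind::Real;
  }
  return SignalKind::Real; // rank-2 input, the newly accepted case
}

int main() {
  std::vector<int64_t> real3{1, 128, 1}, complex3{1, 128, 2}, real2{1, 128};
  assert(normalizeSignal(real3) == SignalKind::Real && real3.size() == 2);
  assert(normalizeSignal(complex3) == SignalKind::Complex);
  assert(normalizeSignal(real2) == SignalKind::Real);
}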
@@ -2904,6 +2904,30 @@ func.func @test_stft(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor

 // -----

+// CHECK-LABEL: func.func @test_stft_real_rank2
+func.func @test_stft_real_rank2(%arg0: !torch.vtensor<[1,128],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[FRAMELEN:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[NONEVAL:.*]] = torch.constant.none
+  // CHECK: %[[ONESSHAPE:.*]] = torch.prim.ListConstruct %[[FRAMELEN]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[ONESLIST:.*]] = torch.aten.ones %[[ONESSHAPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],f32>
+  // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false
+  // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
+  // CHECK: %[[STFT:.*]] = torch.aten.stft %arg0, %[[FRAMELEN]], %[[FRAMESTEP]], %[[FRAMELEN]], %[[ONESLIST]], %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex<f32>>
+  // CHECK: %[[INT0:.*]] = torch.constant.int 0
+  // CHECK: %[[INT2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[INT1:.*]] = torch.constant.int 1
+  // CHECK: %[[PERMUTELIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTELIST]] : !torch.vtensor<[1,9,15],complex<f32>>, !torch.list<int> -> !torch.vtensor<[1,15,9],complex<f32>>
+  // CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex<f32>> -> !torch.vtensor<[1,15,9,2],f32>
+  // CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32>
+  %none = torch.constant.none
+  %0 = torch.operator "onnx.STFT"(%arg0, %arg1, %none, %arg2) : (!torch.vtensor<[1,128],f32>, !torch.vtensor<[],si64>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32>
+  return %0 : !torch.vtensor<[1,15,9,2],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_stft_with_window
 func.func @test_stft_with_window(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[FRAMELEN:.*]] = torch.constant.int 16
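The [1,15,9,2] result type in the new test follows from standard STFT bookkeeping, assuming the runtime scalars are frame_length = 16 and frame_step = 8 (they are dynamic in the IR; only the result type pins them down). A quick worked check:

#include <cassert>
#include <cstdint>

int main() {
  const int64_t length = 128, frameLength = 16, frameStep = 8;
  const int64_t frames = (length - frameLength) / frameStep + 1; // 15 frames
  const int64_t bins = frameLength / 2 + 1; // 9 bins for a onesided transform
  assert(frames == 15 && bins == 9);
  // Result layout is [batch][frames][bins][2]; the trailing 2 holds the real
  // and imaginary parts produced by torch.aten.view_as_real.
}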
@@ -2927,6 +2951,29 @@ func.func @test_stft_with_window(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !t

 // -----

+// CHECK-LABEL: func.func @test_stft_with_window_and_framelen
+func.func @test_stft_with_window_and_framelen(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[16],f32>, %arg3: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[FRAMELEN:.*]] = torch.aten.item %arg3 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[INT2_0:.*]] = torch.constant.int 2
+  // CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+  // CHECK: %[[WINDOWLEN:.*]] = torch.constant.int 16
+  // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false
+  // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
+  // CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[WINDOWLEN]], %arg2, %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[16],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex<f32>>
+  // CHECK: %[[INT0:.*]] = torch.constant.int 0
+  // CHECK: %[[INT2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[INT1:.*]] = torch.constant.int 1
+  // CHECK: %[[PERMUTEDIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTEDIMS]] : !torch.vtensor<[1,9,15],complex<f32>>, !torch.list<int> -> !torch.vtensor<[1,15,9],complex<f32>>
+  // CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex<f32>> -> !torch.vtensor<[1,15,9,2],f32>
+  // CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32>
+  %0 = torch.operator "onnx.STFT"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.vtensor<[16],f32>, !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32>
+  return %0 : !torch.vtensor<[1,15,9,2],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_reversesequence_batch
 func.func @test_reversesequence_batch(%arg0: !torch.vtensor<[4,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[C0:.*]] = torch.constant.int 0