[torch] Add OnnxToTorch lowering for Onnx.STFT op (#3492)

Adds OnnxToTorch lowering for `Onnx.STFT` op.
pull/3495/head
Vinayak Dev 2024-06-25 19:00:45 +05:30 committed by GitHub
parent 3c3fbe4680
commit 02340408b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 564 additions and 0 deletions

View File

@ -12533,6 +12533,36 @@ def Torch_AtenKthvalueOp : Torch_Op<"aten.kthvalue", [
let hasVerifier = 1;
}
def Torch_AtenStftOp : Torch_Op<"aten.stft", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$n_fft,
AnyTorchOptionalIntType:$hop_length,
AnyTorchOptionalIntType:$win_length,
AnyTorchOptionalTensorType:$window,
Torch_BoolType:$normalized,
AnyTorchOptionalBoolType:$onesided,
AnyTorchOptionalBoolType:$return_complex
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenStftOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void AtenStftOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -3300,4 +3300,170 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*Torch_BoolType:$antialias*/ cstFalse);
return success();
});
patterns.onOp(
"STFT", 17, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
// operands in order ->(signal, frameStep, window, frameLength*)
SmallVector<Value> operands;
int64_t onesided;
Torch::ValueTensorType resultType;
if (binder.tensorOperandsList(operands) ||
binder.s64IntegerAttr(onesided, "onesided", 1) ||
binder.tensorResultType(resultType))
return failure();
Value signal = operands[0];
Value frameStep = operands[1];
auto signalTy = cast<Torch::ValueTensorType>(signal.getType());
auto signalShape = signalTy.getSizes();
auto resultShape = resultType.getSizes();
// There are two possible cases for optional inputs frameLength and
// window, which are that either 4 operands will be passed with window
// being !torch.none, or three operands will be passed, with window
// present and frameLength absent. In the former case, we simply create
// a rectangular window consisting of ones, and in the latter, we set
// frameLength equal to the the inputShape[-2] or windowShape[0]
// depending upon whether window was present or not. Note that it is
// possible that both window and frameLength can be none, which would
// mean that either only two operands were passed, or, in case of three
// operands, window was passed in as none, and frameLength was absent.
Value window = nullptr, frameLength = nullptr;
bool windowIsNone = true, frameLengthIsNone = true;
if (operands.size() == 3) {
window = operands[2];
windowIsNone = isa<Torch::NoneType>(window.getType());
}
if (operands.size() == 4) {
window = operands[2];
frameLength = operands[3];
windowIsNone = isa<Torch::NoneType>(window.getType());
frameLengthIsNone = isa<Torch::NoneType>(frameLength.getType());
}
ArrayRef<int64_t> windowShape;
if (frameLengthIsNone) {
if (windowIsNone) {
frameLength = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
signalShape[signalShape.size() - 2]));
} else {
windowShape =
cast<Torch::ValueTensorType>(window.getType()).getSizes();
frameLength = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0]));
}
}
Value frameLengthItem;
if (!frameLengthIsNone || windowIsNone) {
frameLengthItem =
getItemOp<Torch::IntType>(binder, rewriter, frameLength);
} else {
frameLengthItem = frameLength;
}
Value frameStepItem =
getItemOp<Torch::IntType>(binder, rewriter, frameStep);
if (windowIsNone) {
auto onesResultTy = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>({-1}), signalTy.getDtype());
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value sizes = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
SmallVector<Value>{frameLengthItem});
window = rewriter.create<Torch::AtenOnesOp>(
binder.getLoc(), onesResultTy, sizes, none, none, none, none);
}
FailureOr<Type> complexDtype;
if (signalTy.getDtype().isBF16()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support for bfloat16 type is unimplemented.");
}
if (signalTy.getDtype().isF16()) {
complexDtype = Torch::getTypeForScalarType(
binder.op->getContext(),
torch::torch_upstream::ScalarType::ComplexHalf);
} else if (signalTy.getDtype().isF32()) {
complexDtype = Torch::getTypeForScalarType(
binder.op->getContext(),
torch::torch_upstream::ScalarType::ComplexFloat);
} else {
complexDtype = Torch::getTypeForScalarType(
binder.op->getContext(),
torch::torch_upstream::ScalarType::ComplexDouble);
}
auto complexSignalTy = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>({signalShape[0], signalShape[1]}),
complexDtype.value());
// The onnx STFT op always passes in a float input, and if the input
// is intended to be complex, its shape will be [batch][length][2],
// where [...][0] is the real component, and [...][1] is the complex
// component. This complex input has to be made torch compatible before
// being passed into torch.stft, so it is necessary to call
// AtenViewAsComplexOp. In case of real input, the shape of the signal
// will be [batch][length][1], and therefore it will have to be squeezed
// at dim=2, before being passed into torch.stft.
if (signalShape[2] == 2) {
signal = rewriter.create<Torch::AtenViewAsComplexOp>(
binder.getLoc(), complexSignalTy, signal);
} else {
Value two = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2));
auto newSignalTy = signalTy.getWithSizesAndDtype(
ArrayRef<int64_t>({signalShape[0], signalShape[1]}),
signalTy.getDtype());
signal = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(), newSignalTy, signal, two);
}
// In case the window is not given, we use frameLength
// as the length of the window.
Value windowLen;
if (!windowIsNone) {
windowLen = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0]));
} else {
windowLen = frameLengthItem;
}
Value falseVal =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value trueVal =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
auto stftTy = complexSignalTy.getWithSizesAndDtype(
ArrayRef<int64_t>({resultShape[0], resultShape[2], resultShape[1]}),
complexSignalTy.getDtype());
// After torch.stft is called and the result is stored into the value
// stft, there is one thing to note: The resultType for the onnx op
// will have shape [batch][num_frames][length][2], while the shape of
// stft will be [batch][length][num_frames]. Before the value is
// converted to real through torch.view_as_real, we must permute the
// shape of stft to match the shape of resultType. Also, it is
// immaterial whether torch.view_as_real is called after or before the
// permutation; both outputs will be equivalent.
Value stft = rewriter.create<Torch::AtenStftOp>(
binder.getLoc(), stftTy, signal, frameLengthItem, frameStepItem,
windowLen, window, falseVal, onesided ? trueVal : falseVal,
trueVal);
auto permuteStftTy = complexSignalTy.getWithSizesAndDtype(
ArrayRef<int64_t>({resultShape[0], resultShape[1], resultShape[2]}),
complexSignalTy.getDtype());
Value permuteDims = createConstantIntList(binder, rewriter, {0, 2, 1});
Value permutedStft = rewriter.create<Torch::AtenPermuteOp>(
binder.getLoc(), permuteStftTy, stft, permuteDims);
rewriter.replaceOpWithNewOp<Torch::AtenViewAsRealOp>(
binder.op, resultType, permutedStft);
return success();
});
}

View File

@ -10143,6 +10143,125 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.stft\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.optional<bool>, %arg7: !torch.optional<bool>) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: Expected hop_length to be greater than 0\"\n"
" %str_0 = torch.constant.str \"AssertionError: Expected that 0 < n_fft <= len\"\n"
" %false = torch.constant.bool false\n"
" %none = torch.constant.none\n"
" %str_1 = torch.constant.str \"AssertionError: Expected input tensor to be of shape (B?,L), where B is an optional batch dimension\"\n"
" %true = torch.constant.bool true\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %int0 = torch.constant.int 0\n"
" %int4 = torch.constant.int 4\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %24 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %25 = torch.aten.eq.int %24, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %25 : !torch.bool\n"
" }\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.optional<int>) {\n"
" %24 = torch.derefine %none : !torch.none to !torch.optional<int>\n"
" torch.prim.If.yield %24 : !torch.optional<int>\n"
" } else {\n"
" %24 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %25 = torch.derefine %24 : !torch.int to !torch.optional<int>\n"
" torch.prim.If.yield %25 : !torch.optional<int>\n"
" }\n"
" %6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %7 = torch.aten.eq.int %6, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %8 = torch.prim.If %7 -> (!torch.int) {\n"
" %24 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" torch.prim.If.yield %24 : !torch.int\n"
" } else {\n"
" %24 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" torch.prim.If.yield %24 : !torch.int\n"
" }\n"
" %9 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %10 = torch.prim.If %9 -> (!torch.int) {\n"
" %24 = torch.aten.floordiv.int %arg1, %int4 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %24 : !torch.int\n"
" } else {\n"
" %24 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %24 : !torch.int\n"
" }\n"
" %11 = torch.aten.gt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %12 = torch.prim.If %11 -> (!torch.bool) {\n"
" %24 = torch.aten.le.int %arg1, %8 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %24 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %12 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %13 = torch.aten.gt.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %13 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %14 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %15 = torch.aten.__isnot__ %5, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" torch.prim.If %15 -> () {\n"
" %24 = torch.prim.unchecked_cast %5 : !torch.optional<int> -> !torch.int\n"
" %25 = torch.aten.append.t %14, %24 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" %16 = torch.aten.__is__ %arg6, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
" %17 = torch.prim.If %16 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %24 = torch.prim.unchecked_cast %arg6 : !torch.optional<bool> -> !torch.bool\n"
" %25 = torch.operator \"aten.eq.bool\"(%24, %true) : (!torch.bool, !torch.bool) -> !torch.bool \n"
" torch.prim.If.yield %25 : !torch.bool\n"
" }\n"
" torch.prim.If %17 -> () {\n"
" %24 = torch.aten.floordiv.int %arg1, %int2 : !torch.int, !torch.int -> !torch.int\n"
" %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %26 = torch.aten.append.t %14, %25 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield\n"
" } else {\n"
" %24 = torch.aten.append.t %14, %arg1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield\n"
" }\n"
" %18 = torch.aten.sub.int %8, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %19 = torch.aten.floordiv.int %18, %10 : !torch.int, !torch.int -> !torch.int\n"
" %20 = torch.aten.add.int %int1, %19 : !torch.int, !torch.int -> !torch.int\n"
" %21 = torch.aten.append.t %14, %20 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %22 = torch.aten.__isnot__ %arg7, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
" %23 = torch.prim.If %22 -> (!torch.bool) {\n"
" %24 = torch.prim.unchecked_cast %arg7 : !torch.optional<bool> -> !torch.bool\n"
" %25 = torch.operator \"aten.eq.bool\"(%24, %false) : (!torch.bool, !torch.bool) -> !torch.bool \n"
" torch.prim.If.yield %25 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %23 -> () {\n"
" %24 = torch.aten.append.t %14, %int2 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" return %14 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
@ -11607,6 +11726,143 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.stft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.optional<bool>, %arg7: !torch.optional<bool>) -> !torch.int {\n"
" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
" %int7 = torch.constant.int 7\n"
" %int10 = torch.constant.int 10\n"
" %int6 = torch.constant.int 6\n"
" %int9 = torch.constant.int 9\n"
" %int5 = torch.constant.int 5\n"
" %int8 = torch.constant.int 8\n"
" %none = torch.constant.none\n"
" %false = torch.constant.bool false\n"
" %true = torch.constant.bool true\n"
" %0 = torch.prim.Uninitialized : !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" %7 = torch.aten.__isnot__ %arg7, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" %7 = torch.prim.unchecked_cast %arg7 : !torch.optional<bool> -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n"
" torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n"
" } else {\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %8 = torch.prim.If %7 -> (!torch.bool) {\n"
" %11 = torch.aten.__isnot__ %arg7, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %11 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
" %11 = torch.prim.unchecked_cast %arg7 : !torch.optional<bool> -> !torch.bool\n"
" %12 = torch.aten.ne.bool %11, %true : !torch.bool, !torch.bool -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %10:2 = torch.prim.If %9 -> (!torch.bool, !torch.int) {\n"
" %11 = torch.aten.eq.int %1#1, %int8 : !torch.int, !torch.int -> !torch.bool\n"
" %12:2 = torch.prim.If %11 -> (!torch.bool, !torch.int) {\n"
" torch.prim.If.yield %true, %int5 : !torch.bool, !torch.int\n"
" } else {\n"
" %13 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
" %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n"
" torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n"
" } else {\n"
" %15 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
" %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n"
" torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n"
" } else {\n"
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
" }\n"
" torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n"
" }\n"
" torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n"
" }\n"
" torch.prim.If.yield %12#0, %12#1 : !torch.bool, !torch.int\n"
" } else {\n"
" %11 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %12 = torch.prim.If %11 -> (!torch.bool) {\n"
" %15 = torch.aten.__isnot__ %arg7, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
" %15 = torch.prim.unchecked_cast %arg7 : !torch.optional<bool> -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n"
" %15 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
" %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n"
" torch.prim.If.yield %true, %int8 : !torch.bool, !torch.int\n"
" } else {\n"
" %17 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
" %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n"
" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n"
" } else {\n"
" %19 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
" %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n"
" torch.prim.If.yield %true, %int10 : !torch.bool, !torch.int\n"
" } else {\n"
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
" }\n"
" torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n"
" }\n"
" torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n"
" }\n"
" torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n"
" } else {\n"
" %15 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %16 = torch.prim.If %15 -> (!torch.bool) {\n"
" %19 = torch.aten.__isnot__ %arg7, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %19 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %17 = torch.prim.If %16 -> (!torch.bool) {\n"
" %19 = torch.prim.unchecked_cast %arg7 : !torch.optional<bool> -> !torch.bool\n"
" %20 = torch.aten.ne.bool %19, %true : !torch.bool, !torch.bool -> !torch.bool\n"
" torch.prim.If.yield %20 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n"
" torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n"
" } else {\n"
" %19 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n"
" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n"
" } else {\n"
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
" }\n"
" torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n"
" }\n"
" torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n"
" }\n"
" torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n"
" }\n"
" torch.prim.If.yield %10#0, %10#1 : !torch.bool, !torch.int\n"
" }\n"
" %6 = torch.prim.If %5#0 -> (!torch.int) {\n"
" torch.prim.If.yield %5#1 : !torch.int\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield %0 : !torch.int\n"
" }\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

View File

@ -1976,6 +1976,35 @@ def atenstack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]:
def atenfft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]:
return self
@check_shape_function([
Invocation(TensorOfShape(1, 128), 16, None, 16, TensorOfShape(16), False, None, True) # With an explicit 1-D window.
])
def atenstft〡shape(self: List[int], n_fft: int, hop_length: Optional[int] = None, win_length: Optional[int] = None, window: Optional[List[int]] = None, normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None) -> List[int]:
assert len(self) == 1 or len(self) == 2, "Expected input tensor to be of shape (B?,L), where B is an optional batch dimension"
batch = None if len(self) == 1 else self[0]
length = self[0] if len(self) == 1 else self[1]
hop_length = (n_fft // 4) if hop_length is None else hop_length
assert n_fft > 0 and n_fft <= length, "Expected that 0 < n_fft <= len"
assert hop_length > 0, "Expected hop_length to be greater than 0"
out: List[int] = []
if batch is not None:
out.append(batch) # (B?,)
if onesided is None or onesided == True:
out.append(n_fft//2 + 1)
else:
out.append(n_fft) # (B?,N,)
# For this operator, center=False by default
out.append(1 + (length - n_fft)//hop_length) #(B?,N,T,)
if return_complex is not None and bool(return_complex) == False:
out.append(2) # a length-2 dimension of real and imaginary components. This gives output shape (B?,N,T,C?).
return out
class DummyClassType:
def __init__(self):
pass
@ -3307,6 +3336,37 @@ def atenfft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] =
else:
assert False, "Unsupported dtype"
@check_dtype_function([
Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=False), # output dtype = torch.float32
Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=True), # output dtype = torch.complex64
Invocation(TensorOfShape(1,128, dtype=torch.float32), n_fft=16, return_complex=True), # output dtype = torch.complex64
Invocation(TensorOfShape(1,128, dtype=torch.float32), n_fft=16, return_complex=False), # output dtype = torch.float32
])
def atenstft〡dtype(self_rank_dtype: Tuple[int, int], n_fft: int, hop_length: Optional[int] = None, win_length: Optional[int] = None, window_rank_dtype: Optional[Tuple[int, int]] = None, normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None) -> int:
self_rank, self_dtype = self_rank_dtype
if is_complex_dtype(self_dtype) and return_complex is not None and return_complex:
return self_dtype
elif is_complex_dtype(self_dtype) and return_complex is not None and return_complex != True:
if self_dtype == torch.complex32:
return torch.float16
elif self_dtype == torch.complex64:
return torch.float32
elif self_dtype == torch.complex128:
return torch.float64
elif is_float_dtype(self_dtype) and return_complex is not None and return_complex:
if self_dtype == torch.float16:
return torch.complex32
elif self_dtype == torch.float32:
return torch.complex64
elif self_dtype == torch.float64:
return torch.complex128
elif is_float_dtype(self_dtype) and return_complex is not None and return_complex != True:
return self_dtype
elif is_integer_dtype(self_dtype):
return torch.complex64
assert False, "Unsupported dtype"
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))

View File

@ -921,6 +921,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)",
has_verifier=True,
)
emit(
"aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)"
)
# Functionalization ops
emit("aten::alias_copy : (Tensor) -> (Tensor)")

View File

@ -2583,3 +2583,52 @@ func.func @test_upsample_bilinear(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !
%0 = torch.operator "onnx.Upsample"(%arg0, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32>
return %0 : !torch.vtensor<[1,1,4,6],f32>
}
// -----
// CHECK-LABEL: func.func @test_stft
func.func @test_stft(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[FRAMELEN:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[NONEVAL:.*]] = torch.constant.none
// CHECK: %[[ONESSHAPE:.*]] = torch.prim.ListConstruct %[[FRAMELEN]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[ONESLIST:.*]] = torch.aten.ones %[[ONESSHAPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],f32>
// CHECK: %[[INT2_0:.*]] = torch.constant.int 2
// CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
// CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false
// CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
// CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[FRAMELEN]], %[[ONESLIST]], %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex<f32>>
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT2_1:.*]] = torch.constant.int 2
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[PERMUTELIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTELIST]] : !torch.vtensor<[1,9,15],complex<f32>>, !torch.list<int> -> !torch.vtensor<[1,15,9],complex<f32>>
// CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex<f32>> -> !torch.vtensor<[1,15,9,2],f32>
// CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32>
%none = torch.constant.none
%0 = torch.operator "onnx.STFT"(%arg0, %arg1, %none, %arg2) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32>
return %0 : !torch.vtensor<[1,15,9,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_stft_with_window
func.func @test_stft_with_window(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[FRAMELEN:.*]] = torch.constant.int 16
// CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT2_0:.*]] = torch.constant.int 2
// CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
// CHECK: %[[WINDOWLEN:.*]] = torch.constant.int 16
// CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false
// CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
// CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[WINDOWLEN]], %arg2, %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[16],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex<f32>>
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT2_1:.*]] = torch.constant.int 2
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[PERMUTEDIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTEDIMS]] : !torch.vtensor<[1,9,15],complex<f32>>, !torch.list<int> -> !torch.vtensor<[1,15,9],complex<f32>>
// CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex<f32>> -> !torch.vtensor<[1,15,9,2],f32>
// CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32>
%0 = torch.operator "onnx.STFT"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32>
return %0 : !torch.vtensor<[1,15,9,2],f32>
}