mirror of https://github.com/llvm/torch-mlir
[torch] Add OnnxToTorch lowering for Onnx.STFT op (#3492)
Adds OnnxToTorch lowering for `Onnx.STFT` op.pull/3495/head
parent
3c3fbe4680
commit
02340408b7
|
@ -12533,6 +12533,36 @@ def Torch_AtenKthvalueOp : Torch_Op<"aten.kthvalue", [
|
|||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenStftOp : Torch_Op<"aten.stft", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$n_fft,
|
||||
AnyTorchOptionalIntType:$hop_length,
|
||||
AnyTorchOptionalIntType:$win_length,
|
||||
AnyTorchOptionalTensorType:$window,
|
||||
Torch_BoolType:$normalized,
|
||||
AnyTorchOptionalBoolType:$onesided,
|
||||
AnyTorchOptionalBoolType:$return_complex
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenStftOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 8, 1);
|
||||
}
|
||||
void AtenStftOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 8, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -3300,4 +3300,170 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
/*Torch_BoolType:$antialias*/ cstFalse);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"STFT", 17, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
// operands in order ->(signal, frameStep, window, frameLength*)
|
||||
SmallVector<Value> operands;
|
||||
int64_t onesided;
|
||||
Torch::ValueTensorType resultType;
|
||||
|
||||
if (binder.tensorOperandsList(operands) ||
|
||||
binder.s64IntegerAttr(onesided, "onesided", 1) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
Value signal = operands[0];
|
||||
Value frameStep = operands[1];
|
||||
auto signalTy = cast<Torch::ValueTensorType>(signal.getType());
|
||||
auto signalShape = signalTy.getSizes();
|
||||
auto resultShape = resultType.getSizes();
|
||||
|
||||
// There are two possible cases for optional inputs frameLength and
|
||||
// window, which are that either 4 operands will be passed with window
|
||||
// being !torch.none, or three operands will be passed, with window
|
||||
// present and frameLength absent. In the former case, we simply create
|
||||
// a rectangular window consisting of ones, and in the latter, we set
|
||||
// frameLength equal to the the inputShape[-2] or windowShape[0]
|
||||
// depending upon whether window was present or not. Note that it is
|
||||
// possible that both window and frameLength can be none, which would
|
||||
// mean that either only two operands were passed, or, in case of three
|
||||
// operands, window was passed in as none, and frameLength was absent.
|
||||
Value window = nullptr, frameLength = nullptr;
|
||||
bool windowIsNone = true, frameLengthIsNone = true;
|
||||
if (operands.size() == 3) {
|
||||
window = operands[2];
|
||||
windowIsNone = isa<Torch::NoneType>(window.getType());
|
||||
}
|
||||
if (operands.size() == 4) {
|
||||
window = operands[2];
|
||||
frameLength = operands[3];
|
||||
windowIsNone = isa<Torch::NoneType>(window.getType());
|
||||
frameLengthIsNone = isa<Torch::NoneType>(frameLength.getType());
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> windowShape;
|
||||
if (frameLengthIsNone) {
|
||||
if (windowIsNone) {
|
||||
frameLength = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(
|
||||
signalShape[signalShape.size() - 2]));
|
||||
} else {
|
||||
windowShape =
|
||||
cast<Torch::ValueTensorType>(window.getType()).getSizes();
|
||||
frameLength = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0]));
|
||||
}
|
||||
}
|
||||
|
||||
Value frameLengthItem;
|
||||
if (!frameLengthIsNone || windowIsNone) {
|
||||
frameLengthItem =
|
||||
getItemOp<Torch::IntType>(binder, rewriter, frameLength);
|
||||
} else {
|
||||
frameLengthItem = frameLength;
|
||||
}
|
||||
Value frameStepItem =
|
||||
getItemOp<Torch::IntType>(binder, rewriter, frameStep);
|
||||
|
||||
if (windowIsNone) {
|
||||
auto onesResultTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
ArrayRef<int64_t>({-1}), signalTy.getDtype());
|
||||
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
Value sizes = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(
|
||||
Torch::IntType::get(binder.op->getContext())),
|
||||
SmallVector<Value>{frameLengthItem});
|
||||
window = rewriter.create<Torch::AtenOnesOp>(
|
||||
binder.getLoc(), onesResultTy, sizes, none, none, none, none);
|
||||
}
|
||||
|
||||
FailureOr<Type> complexDtype;
|
||||
if (signalTy.getDtype().isBF16()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"unimplemented: support for bfloat16 type is unimplemented.");
|
||||
}
|
||||
if (signalTy.getDtype().isF16()) {
|
||||
complexDtype = Torch::getTypeForScalarType(
|
||||
binder.op->getContext(),
|
||||
torch::torch_upstream::ScalarType::ComplexHalf);
|
||||
} else if (signalTy.getDtype().isF32()) {
|
||||
complexDtype = Torch::getTypeForScalarType(
|
||||
binder.op->getContext(),
|
||||
torch::torch_upstream::ScalarType::ComplexFloat);
|
||||
} else {
|
||||
complexDtype = Torch::getTypeForScalarType(
|
||||
binder.op->getContext(),
|
||||
torch::torch_upstream::ScalarType::ComplexDouble);
|
||||
}
|
||||
|
||||
auto complexSignalTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
ArrayRef<int64_t>({signalShape[0], signalShape[1]}),
|
||||
complexDtype.value());
|
||||
|
||||
// The onnx STFT op always passes in a float input, and if the input
|
||||
// is intended to be complex, its shape will be [batch][length][2],
|
||||
// where [...][0] is the real component, and [...][1] is the complex
|
||||
// component. This complex input has to be made torch compatible before
|
||||
// being passed into torch.stft, so it is necessary to call
|
||||
// AtenViewAsComplexOp. In case of real input, the shape of the signal
|
||||
// will be [batch][length][1], and therefore it will have to be squeezed
|
||||
// at dim=2, before being passed into torch.stft.
|
||||
if (signalShape[2] == 2) {
|
||||
signal = rewriter.create<Torch::AtenViewAsComplexOp>(
|
||||
binder.getLoc(), complexSignalTy, signal);
|
||||
} else {
|
||||
Value two = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(2));
|
||||
auto newSignalTy = signalTy.getWithSizesAndDtype(
|
||||
ArrayRef<int64_t>({signalShape[0], signalShape[1]}),
|
||||
signalTy.getDtype());
|
||||
signal = rewriter.create<Torch::AtenSqueezeDimOp>(
|
||||
binder.getLoc(), newSignalTy, signal, two);
|
||||
}
|
||||
|
||||
// In case the window is not given, we use frameLength
|
||||
// as the length of the window.
|
||||
Value windowLen;
|
||||
if (!windowIsNone) {
|
||||
windowLen = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0]));
|
||||
} else {
|
||||
windowLen = frameLengthItem;
|
||||
}
|
||||
|
||||
Value falseVal =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||
Value trueVal =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
|
||||
auto stftTy = complexSignalTy.getWithSizesAndDtype(
|
||||
ArrayRef<int64_t>({resultShape[0], resultShape[2], resultShape[1]}),
|
||||
complexSignalTy.getDtype());
|
||||
|
||||
// After torch.stft is called and the result is stored into the value
|
||||
// stft, there is one thing to note: The resultType for the onnx op
|
||||
// will have shape [batch][num_frames][length][2], while the shape of
|
||||
// stft will be [batch][length][num_frames]. Before the value is
|
||||
// converted to real through torch.view_as_real, we must permute the
|
||||
// shape of stft to match the shape of resultType. Also, it is
|
||||
// immaterial whether torch.view_as_real is called after or before the
|
||||
// permutation; both outputs will be equivalent.
|
||||
Value stft = rewriter.create<Torch::AtenStftOp>(
|
||||
binder.getLoc(), stftTy, signal, frameLengthItem, frameStepItem,
|
||||
windowLen, window, falseVal, onesided ? trueVal : falseVal,
|
||||
trueVal);
|
||||
|
||||
auto permuteStftTy = complexSignalTy.getWithSizesAndDtype(
|
||||
ArrayRef<int64_t>({resultShape[0], resultShape[1], resultShape[2]}),
|
||||
complexSignalTy.getDtype());
|
||||
Value permuteDims = createConstantIntList(binder, rewriter, {0, 2, 1});
|
||||
Value permutedStft = rewriter.create<Torch::AtenPermuteOp>(
|
||||
binder.getLoc(), permuteStftTy, stft, permuteDims);
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenViewAsRealOp>(
|
||||
binder.op, resultType, permutedStft);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
|
|
@ -10143,6 +10143,125 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.stft\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.optional<bool>, %arg7: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" %str = torch.constant.str \"AssertionError: Expected hop_length to be greater than 0\"\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: Expected that 0 < n_fft <= len\"\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str_1 = torch.constant.str \"AssertionError: Expected input tensor to be of shape (B?,L), where B is an optional batch dimension\"\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %24 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %25 = torch.aten.eq.int %24, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %25 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %5 = torch.prim.If %4 -> (!torch.optional<int>) {\n"
|
||||
" %24 = torch.derefine %none : !torch.none to !torch.optional<int>\n"
|
||||
" torch.prim.If.yield %24 : !torch.optional<int>\n"
|
||||
" } else {\n"
|
||||
" %24 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %25 = torch.derefine %24 : !torch.int to !torch.optional<int>\n"
|
||||
" torch.prim.If.yield %25 : !torch.optional<int>\n"
|
||||
" }\n"
|
||||
" %6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %7 = torch.aten.eq.int %6, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %8 = torch.prim.If %7 -> (!torch.int) {\n"
|
||||
" %24 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" torch.prim.If.yield %24 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %24 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" torch.prim.If.yield %24 : !torch.int\n"
|
||||
" }\n"
|
||||
" %9 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %10 = torch.prim.If %9 -> (!torch.int) {\n"
|
||||
" %24 = torch.aten.floordiv.int %arg1, %int4 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" torch.prim.If.yield %24 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %24 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %24 : !torch.int\n"
|
||||
" }\n"
|
||||
" %11 = torch.aten.gt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %12 = torch.prim.If %11 -> (!torch.bool) {\n"
|
||||
" %24 = torch.aten.le.int %arg1, %8 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %24 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %12 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %13 = torch.aten.gt.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %13 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %14 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" %15 = torch.aten.__isnot__ %5, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If %15 -> () {\n"
|
||||
" %24 = torch.prim.unchecked_cast %5 : !torch.optional<int> -> !torch.int\n"
|
||||
" %25 = torch.aten.append.t %14, %24 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %16 = torch.aten.__is__ %arg6, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
|
||||
" %17 = torch.prim.If %16 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %24 = torch.prim.unchecked_cast %arg6 : !torch.optional<bool> -> !torch.bool\n"
|
||||
" %25 = torch.operator \"aten.eq.bool\"(%24, %true) : (!torch.bool, !torch.bool) -> !torch.bool \n"
|
||||
" torch.prim.If.yield %25 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %17 -> () {\n"
|
||||
" %24 = torch.aten.floordiv.int %arg1, %int2 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %26 = torch.aten.append.t %14, %25 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %24 = torch.aten.append.t %14, %arg1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %18 = torch.aten.sub.int %8, %arg1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %19 = torch.aten.floordiv.int %18, %10 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %20 = torch.aten.add.int %int1, %19 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %21 = torch.aten.append.t %14, %20 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" %22 = torch.aten.__isnot__ %arg7, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
|
||||
" %23 = torch.prim.If %22 -> (!torch.bool) {\n"
|
||||
" %24 = torch.prim.unchecked_cast %arg7 : !torch.optional<bool> -> !torch.bool\n"
|
||||
" %25 = torch.operator \"aten.eq.bool\"(%24, %false) : (!torch.bool, !torch.bool) -> !torch.bool \n"
|
||||
" torch.prim.If.yield %25 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %23 -> () {\n"
|
||||
" %24 = torch.aten.append.t %14, %int2 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %14 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
|
||||
|
@ -11607,6 +11726,143 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.stft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.optional<bool>, %arg7: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
|
||||
" %int7 = torch.constant.int 7\n"
|
||||
" %int10 = torch.constant.int 10\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %int9 = torch.constant.int 9\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
" %int8 = torch.constant.int 8\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %0 = torch.prim.Uninitialized : !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
|
||||
" %7 = torch.aten.__isnot__ %arg7, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %7 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
|
||||
" %7 = torch.prim.unchecked_cast %arg7 : !torch.optional<bool> -> !torch.bool\n"
|
||||
" torch.prim.If.yield %7 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n"
|
||||
" torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n"
|
||||
" } else {\n"
|
||||
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %8 = torch.prim.If %7 -> (!torch.bool) {\n"
|
||||
" %11 = torch.aten.__isnot__ %arg7, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %11 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
|
||||
" %11 = torch.prim.unchecked_cast %arg7 : !torch.optional<bool> -> !torch.bool\n"
|
||||
" %12 = torch.aten.ne.bool %11, %true : !torch.bool, !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If.yield %12 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %10:2 = torch.prim.If %9 -> (!torch.bool, !torch.int) {\n"
|
||||
" %11 = torch.aten.eq.int %1#1, %int8 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %12:2 = torch.prim.If %11 -> (!torch.bool, !torch.int) {\n"
|
||||
" torch.prim.If.yield %true, %int5 : !torch.bool, !torch.int\n"
|
||||
" } else {\n"
|
||||
" %13 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n"
|
||||
" torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n"
|
||||
" } else {\n"
|
||||
" %15 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n"
|
||||
" torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %12#0, %12#1 : !torch.bool, !torch.int\n"
|
||||
" } else {\n"
|
||||
" %11 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %12 = torch.prim.If %11 -> (!torch.bool) {\n"
|
||||
" %15 = torch.aten.__isnot__ %arg7, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
|
||||
" %15 = torch.prim.unchecked_cast %arg7 : !torch.optional<bool> -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n"
|
||||
" %15 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n"
|
||||
" torch.prim.If.yield %true, %int8 : !torch.bool, !torch.int\n"
|
||||
" } else {\n"
|
||||
" %17 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n"
|
||||
" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n"
|
||||
" } else {\n"
|
||||
" %19 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n"
|
||||
" torch.prim.If.yield %true, %int10 : !torch.bool, !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n"
|
||||
" } else {\n"
|
||||
" %15 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %16 = torch.prim.If %15 -> (!torch.bool) {\n"
|
||||
" %19 = torch.aten.__isnot__ %arg7, %none : !torch.optional<bool>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %19 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %17 = torch.prim.If %16 -> (!torch.bool) {\n"
|
||||
" %19 = torch.prim.unchecked_cast %arg7 : !torch.optional<bool> -> !torch.bool\n"
|
||||
" %20 = torch.aten.ne.bool %19, %true : !torch.bool, !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If.yield %20 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n"
|
||||
" torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n"
|
||||
" } else {\n"
|
||||
" %19 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n"
|
||||
" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %10#0, %10#1 : !torch.bool, !torch.int\n"
|
||||
" }\n"
|
||||
" %6 = torch.prim.If %5#0 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %5#1 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %6 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
|
|
@ -1976,6 +1976,35 @@ def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]:
|
|||
def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]:
|
||||
return self
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(1, 128), 16, None, 16, TensorOfShape(16), False, None, True) # With an explicit 1-D window.
|
||||
])
|
||||
def aten〇stft〡shape(self: List[int], n_fft: int, hop_length: Optional[int] = None, win_length: Optional[int] = None, window: Optional[List[int]] = None, normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None) -> List[int]:
|
||||
assert len(self) == 1 or len(self) == 2, "Expected input tensor to be of shape (B?,L), where B is an optional batch dimension"
|
||||
|
||||
batch = None if len(self) == 1 else self[0]
|
||||
length = self[0] if len(self) == 1 else self[1]
|
||||
hop_length = (n_fft // 4) if hop_length is None else hop_length
|
||||
assert n_fft > 0 and n_fft <= length, "Expected that 0 < n_fft <= len"
|
||||
assert hop_length > 0, "Expected hop_length to be greater than 0"
|
||||
|
||||
out: List[int] = []
|
||||
if batch is not None:
|
||||
out.append(batch) # (B?,)
|
||||
|
||||
if onesided is None or onesided == True:
|
||||
out.append(n_fft//2 + 1)
|
||||
else:
|
||||
out.append(n_fft) # (B?,N,)
|
||||
|
||||
# For this operator, center=False by default
|
||||
out.append(1 + (length - n_fft)//hop_length) #(B?,N,T,)
|
||||
|
||||
if return_complex is not None and bool(return_complex) == False:
|
||||
out.append(2) # a length-2 dimension of real and imaginary components. This gives output shape (B?,N,T,C?).
|
||||
|
||||
return out
|
||||
|
||||
class DummyClassType:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
@ -3307,6 +3336,37 @@ def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] =
|
|||
else:
|
||||
assert False, "Unsupported dtype"
|
||||
|
||||
@check_dtype_function([
|
||||
Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=False), # output dtype = torch.float32
|
||||
Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=True), # output dtype = torch.complex64
|
||||
Invocation(TensorOfShape(1,128, dtype=torch.float32), n_fft=16, return_complex=True), # output dtype = torch.complex64
|
||||
Invocation(TensorOfShape(1,128, dtype=torch.float32), n_fft=16, return_complex=False), # output dtype = torch.float32
|
||||
])
|
||||
def aten〇stft〡dtype(self_rank_dtype: Tuple[int, int], n_fft: int, hop_length: Optional[int] = None, win_length: Optional[int] = None, window_rank_dtype: Optional[Tuple[int, int]] = None, normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
if is_complex_dtype(self_dtype) and return_complex is not None and return_complex:
|
||||
return self_dtype
|
||||
elif is_complex_dtype(self_dtype) and return_complex is not None and return_complex != True:
|
||||
if self_dtype == torch.complex32:
|
||||
return torch.float16
|
||||
elif self_dtype == torch.complex64:
|
||||
return torch.float32
|
||||
elif self_dtype == torch.complex128:
|
||||
return torch.float64
|
||||
elif is_float_dtype(self_dtype) and return_complex is not None and return_complex:
|
||||
if self_dtype == torch.float16:
|
||||
return torch.complex32
|
||||
elif self_dtype == torch.float32:
|
||||
return torch.complex64
|
||||
elif self_dtype == torch.float64:
|
||||
return torch.complex128
|
||||
elif is_float_dtype(self_dtype) and return_complex is not None and return_complex != True:
|
||||
return self_dtype
|
||||
elif is_integer_dtype(self_dtype):
|
||||
return torch.complex64
|
||||
|
||||
assert False, "Unsupported dtype"
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
|
||||
|
|
|
@ -921,6 +921,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)",
|
||||
has_verifier=True,
|
||||
)
|
||||
emit(
|
||||
"aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)"
|
||||
)
|
||||
|
||||
# Functionalization ops
|
||||
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
||||
|
|
|
@ -2583,3 +2583,52 @@ func.func @test_upsample_bilinear(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !
|
|||
%0 = torch.operator "onnx.Upsample"(%arg0, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32>
|
||||
return %0 : !torch.vtensor<[1,1,4,6],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_stft
|
||||
func.func @test_stft(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[FRAMELEN:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[NONEVAL:.*]] = torch.constant.none
|
||||
// CHECK: %[[ONESSHAPE:.*]] = torch.prim.ListConstruct %[[FRAMELEN]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[ONESLIST:.*]] = torch.aten.ones %[[ONESSHAPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],f32>
|
||||
// CHECK: %[[INT2_0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
|
||||
// CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[FRAMELEN]], %[[ONESLIST]], %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex<f32>>
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT2_1:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[PERMUTELIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTELIST]] : !torch.vtensor<[1,9,15],complex<f32>>, !torch.list<int> -> !torch.vtensor<[1,15,9],complex<f32>>
|
||||
// CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex<f32>> -> !torch.vtensor<[1,15,9,2],f32>
|
||||
// CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32>
|
||||
%none = torch.constant.none
|
||||
%0 = torch.operator "onnx.STFT"(%arg0, %arg1, %none, %arg2) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32>
|
||||
return %0 : !torch.vtensor<[1,15,9,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_stft_with_window
|
||||
func.func @test_stft_with_window(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[FRAMELEN:.*]] = torch.constant.int 16
|
||||
// CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[INT2_0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
|
||||
// CHECK: %[[WINDOWLEN:.*]] = torch.constant.int 16
|
||||
// CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[WINDOWLEN]], %arg2, %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[16],f32>, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex<f32>>
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT2_1:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[PERMUTEDIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTEDIMS]] : !torch.vtensor<[1,9,15],complex<f32>>, !torch.list<int> -> !torch.vtensor<[1,15,9],complex<f32>>
|
||||
// CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex<f32>> -> !torch.vtensor<[1,15,9,2],f32>
|
||||
// CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32>
|
||||
%0 = torch.operator "onnx.STFT"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32>
|
||||
return %0 : !torch.vtensor<[1,15,9,2],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue