diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index baac6d963..0ca182d3c 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -640,8 +640,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         Value numMelBinsItem =
             getItemOp<Torch::IntType>(binder, rewriter, operands[0]);
-        Value dftLengthItem =
-            getItemOp<Torch::IntType>(binder, rewriter, operands[1]);
         Value sampleRateItem =
             getItemOp<Torch::IntType>(binder, rewriter, operands[2]);
         Value lowerEdgeHzItem =
@@ -656,9 +654,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         // Recurring shapes
         SmallVector<int64_t> unranked({});
         SmallVector<int64_t> shapeNMB({numMelBinsInt});
-        SmallVector<int64_t> shapeNMBp2({numMelBinsInt + 2});
         SmallVector<int64_t> shape1xNMB({1, numMelBinsInt});
         SmallVector<int64_t> shapeNSB({numSpectrogramBinsInt});
+        SmallVector<int64_t> shapeNSBx1({numSpectrogramBinsInt, 1});
         SmallVector<int64_t> shapeNSBxNMB(
             {numSpectrogramBinsInt, numMelBinsInt});

@@ -671,37 +669,30 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         // Value constants
         Value noneConst = b.create<Torch::ConstantNoneOp>();
-        Value negTwoConst =
-            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(-2));
-        Value negOneConst =
-            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(-1));
         Value zeroConst =
             b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(0));
         Value oneConst =
             b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(1));
         Value twoConst =
             b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(2));
+        Value int32DTypeConst =
+            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(3));
         Value float32DTypeConst =
             b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(6));

         Torch::ValueTensorType dftLenType =
             Torch::ValueTensorType::get(ctx, unranked, inpIntDType);
         Type freqBinsIntType =
-            Torch::ValueTensorType::get(ctx, shapeNMBp2, si32Ty);
+            Torch::ValueTensorType::get(ctx, shapeNMB, si32Ty);
         Type freqBinsFltType =
-            Torch::ValueTensorType::get(ctx, shapeNMBp2, f32Ty);
+            Torch::ValueTensorType::get(ctx, shapeNMB, f32Ty);

-        Value dftLengthDivTwoFlt =
-            b.create<Torch::AtenDivIntOp>(dftLengthItem, twoConst);
-        Value dftLengthDivTwo =
-            b.create<Torch::AtenIntFloatOp>(dftLengthDivTwoFlt);
-        Value numSpectrogramBins =
-            b.create<Torch::AtenAddIntOp>(dftLengthDivTwo, oneConst);
-        Value numSpectrogramBinsItem = numSpectrogramBins;
-        Value freqBinsInit = b.create<Torch::AtenArangeOp>(
-            freqBinsIntType, numMelBinsItem, /*dtype=*/float32DTypeConst,
-            /*layout=*/noneConst, /*device=*/noneConst,
-            /*pin_memory=*/noneConst);
+        Value dftLengthDivTwoTensor = b.create<Torch::AtenFloorDivideScalarOp>(
+            dftLenType, operands[1], twoConst);
+        Value numSpectrogramBinsTensor = b.create<Torch::AtenAddScalarOp>(
+            dftLenType, dftLengthDivTwoTensor, oneConst, /*alpha=*/oneConst);
+        Value numSpectrogramBinsItem = getItemOp<Torch::IntType>(
+            binder, rewriter, numSpectrogramBinsTensor);

         // From Ref Impl of Onnx.MelWeightMatrix:
         // https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_mel_weight_matrix.py#L25-L32
@@ -712,6 +703,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(700));
         Value tenConst =
             b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(10));
+        Value oneFltConst =
+            b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(1));
+        Value LnToLog10Const = b.create<Torch::ConstantFloatOp>(
+            rewriter.getF64FloatAttr(M_LOG10E));

         Value lfDiv7Hfloat =
             b.create<Torch::AtenDivFloatOp>(lowerEdgeHzItem, sevenHConst);
         Value lfDiv7H =
             b.create<Torch::PrimNumToTensorScalarOp>(freqType, lfDiv7Hfloat);
         Value lfDiv7HAdd1 = b.create<Torch::AtenAddScalarOp>(
             freqType, lfDiv7H, oneConst, /*alpha=*/oneConst);
-        Value lfDiv7HAdd1Log10 =
-            b.create<Torch::AtenLog10Op>(freqType, lfDiv7HAdd1);
+        Value lfDiv7HAdd1Ln = b.create<Torch::AtenLogOp>(freqType, lfDiv7HAdd1);
+        Value lfDiv7HAdd1Log10 = b.create<Torch::AtenMulScalarOp>(
+            freqType, lfDiv7HAdd1Ln, LnToLog10Const);
+
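+        // log10(x) is computed as ln(x) * M_LOG10E (i.e. log10(e)); together
+        // with the multiply by 2595 below this realizes the HTK mel mapping
+        // from the reference implementation:
+        //   mel = 2595 * log10(1 + hz / 700)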
         Value lfMel = b.create<Torch::AtenMulScalarOp>(
             freqType, lfDiv7HAdd1Log10, twoFiveNineFiveConst);
@@ -731,226 +728,235 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             b.create<Torch::PrimNumToTensorScalarOp>(freqType, hfDiv7Hfloat);
         Value hfDiv7HAdd1 = b.create<Torch::AtenAddScalarOp>(
             freqType, hfDiv7H, oneConst, /*alpha=*/oneConst);
-        Value hfDiv7HAdd1Log10 =
-            b.create<Torch::AtenLog10Op>(freqType, hfDiv7HAdd1);
+        Value hfDiv7HAdd1Ln = b.create<Torch::AtenLogOp>(freqType, hfDiv7HAdd1);
+        Value hfDiv7HAdd1Log10 = b.create<Torch::AtenMulScalarOp>(
+            freqType, hfDiv7HAdd1Ln, LnToLog10Const);
+
         Value hfMel = b.create<Torch::AtenMulScalarOp>(
             freqType, hfDiv7HAdd1Log10, twoFiveNineFiveConst);
         Value hfSubLf = b.create<Torch::AtenSubTensorOp>(
             hfMel.getType(), hfMel, lfMel, /*alpha=*/oneConst);
+        Value numMelBinsPlus2 =
+            b.create<Torch::AtenAddIntOp>(numMelBinsItem, twoConst);
         Value melStep = b.create<Torch::AtenDivScalarOp>(
-            hfSubLf.getType(), hfSubLf, numMelBinsItem);
+            hfSubLf.getType(), hfSubLf, numMelBinsPlus2);

-        Value freqBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
-            freqBinsFltType, freqBinsInit, melStep);
-        Value freqBinsScaled = b.create<Torch::AtenAddTensorOp>(
-            freqBinsFltType, freqBinsMulMelStep, lfMel, /*alpha=*/oneConst);
+        Value lowBinsInit = b.create<Torch::AtenArangeOp>(
+            freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst,
+            /*layout=*/noneConst, /*device=*/noneConst,
+            /*pin_memory=*/noneConst);

-        // Mel to Hz conv
+        Value centerBinsInit = b.create<Torch::AtenArangeOp>(
+            freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst,
+            /*layout=*/noneConst, /*device=*/noneConst,
+            /*pin_memory=*/noneConst);

-        Value fbDiv = b.create<Torch::AtenDivScalarOp>(
-            freqBinsFltType, freqBinsScaled, twoFiveNineFiveConst);
-        Value fbClone = b.create<Torch::AtenCloneOp>(
-            freqBinsFltType, freqBinsScaled, /*memory_format=*/noneConst);
-        Value tenTensor = b.create<Torch::AtenFillScalarOp>(freqBinsFltType,
-                                                            fbClone, tenConst);
-        Value fbPow = b.create<Torch::AtenPowTensorTensorOp>(freqBinsFltType,
-                                                             tenTensor, fbDiv);
-        Value fbPowSubOne = b.create<Torch::AtenSubScalarOp>(
-            freqBinsFltType, fbPow, oneConst, /*alpha=*/oneConst);
-        Value freqBinsHz = b.create<Torch::AtenMulScalarOp>(
-            freqBinsFltType, fbPowSubOne, sevenHConst);
+        Value highBinsInit = b.create<Torch::AtenArangeOp>(
+            freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst,
+            /*layout=*/noneConst, /*device=*/noneConst,
+            /*pin_memory=*/noneConst);

-        // Normalize freqBinsHz
+        // Common values used in conversion
         Value dftLenPlusOne = b.create<Torch::AtenAddScalarOp>(
             dftLenType, operands[1], oneConst, /*alpha=*/oneConst);
         Value dftLenPlusOneItem =
             getItemOp<Torch::IntType>(binder, rewriter, dftLenPlusOne);
-        Value fbMulDft = b.create<Torch::AtenMulScalarOp>(
-            freqBinsFltType, freqBinsHz, dftLenPlusOneItem);
-        Value freqBinsNormalized = b.create<Torch::AtenDivScalarOp>(
-            freqBinsFltType, fbMulDft, sampleRateItem);
-
-        // cast to int32
-        Value int32DTypeConst =
-            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(3));
         Value falseConst = b.create<Torch::ConstantBoolOp>(false);
-        Value freqBins = b.create<Torch::AtenToDtypeOp>(
-            freqBinsIntType, freqBinsNormalized, /*dtype=*/int32DTypeConst,
+        Torch::ValueTensorType unsqueezeBinsResType =
+            Torch::ValueTensorType::get(ctx, shape1xNMB, si32Ty);
+
+        // Low bins Mel to hz
+        Value lowBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
+            freqBinsFltType, lowBinsInit, melStep);
+        Value lowBinsScaled = b.create<Torch::AtenAddTensorOp>(
+            freqBinsFltType, lowBinsMulMelStep, lfMel, /*alpha=*/oneConst);
+        Value lbDiv = b.create<Torch::AtenDivScalarOp>(
+            freqBinsFltType, lowBinsScaled, twoFiveNineFiveConst);
+        Value lbClone = b.create<Torch::AtenCloneOp>(
+            freqBinsFltType, lowBinsScaled, /*memory_format=*/noneConst);
+        Value lbTenTensor = b.create<Torch::AtenFillScalarOp>(
+            freqBinsFltType, lbClone, tenConst);
+        Value lbPow = b.create<Torch::AtenPowTensorTensorOp>(
+            freqBinsFltType, lbTenTensor, lbDiv);
+        Value lbPowSubOne = b.create<Torch::AtenSubScalarOp>(
+            freqBinsFltType, lbPow, oneConst, /*alpha=*/oneConst);
+        Value lowBinsHz = b.create<Torch::AtenMulScalarOp>(
+            freqBinsFltType, lbPowSubOne, sevenHConst);
+        // Normalize freqBinsHz
+        Value lbMulDft = b.create<Torch::AtenMulScalarOp>(
+            freqBinsFltType, lowBinsHz, dftLenPlusOneItem);
+        Value lowBinsNormalized = b.create<Torch::AtenDivScalarOp>(
+            freqBinsFltType, lbMulDft, sampleRateItem);
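+        // The chain above inverts the mel mapping, hz = 700 * (10^(mel/2595) - 1),
+        // then rescales hz to an FFT bin index via
+        // bin = hz * (dft_length + 1) / sample_rate before truncating to int32.
+        // The same chain repeats for the center and high bins, whose arange
+        // inits are offset by one and two mel steps respectively.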
+        // cast to int32
+        Value lowBinsInt = b.create<Torch::AtenToDtypeOp>(
+            freqBinsIntType, lowBinsNormalized, /*dtype=*/int32DTypeConst,
             /*non_blocking=*/falseConst, /*copy=*/falseConst,
             /*memory_format=*/noneConst);
+        Value lowBins = b.create<Torch::AtenUnsqueezeOp>(
+            unsqueezeBinsResType, lowBinsInt, /*dim=*/zeroConst);

-        Torch::ValueTensorType sliceResType =
-            Torch::ValueTensorType::get(ctx, shapeNMB, si32Ty);
-        Type unsqueezeResType =
-            sliceResType.getWithSizesAndDtype(shape1xNMB, si32Ty);
-        Value lfTensor = b.create<Torch::AtenSliceTensorOp>(
-            sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/zeroConst,
-            /*end=*/negTwoConst, /*step=*/oneConst);
-        Value lowFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
-            unsqueezeResType, lfTensor, /*dim=*/zeroConst);
+        // Center bins mel to hz
+        Value centerBinsInitInc = b.create<Torch::AtenAddScalarOp>(
+            freqBinsIntType, centerBinsInit, oneConst, /*alpha=*/oneConst);
+        Value centerBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
+            freqBinsFltType, centerBinsInitInc, melStep);
+        Value centerBinsScaled = b.create<Torch::AtenAddTensorOp>(
+            freqBinsFltType, centerBinsMulMelStep, lfMel, /*alpha=*/oneConst);
+        Value cbDiv = b.create<Torch::AtenDivScalarOp>(
+            freqBinsFltType, centerBinsScaled, twoFiveNineFiveConst);
+        Value cbClone = b.create<Torch::AtenCloneOp>(
+            freqBinsFltType, centerBinsScaled, /*memory_format=*/noneConst);
+        Value cbTenTensor = b.create<Torch::AtenFillScalarOp>(
+            freqBinsFltType, cbClone, tenConst);
+        Value cbPow = b.create<Torch::AtenPowTensorTensorOp>(
+            freqBinsFltType, cbTenTensor, cbDiv);
+        Value cbPowSubOne = b.create<Torch::AtenSubScalarOp>(
+            freqBinsFltType, cbPow, oneConst, /*alpha=*/oneConst);
+        Value centerBinsHz = b.create<Torch::AtenMulScalarOp>(
+            freqBinsFltType, cbPowSubOne, sevenHConst);
+        // Normalize freqBinsHz
+        Value cbMulDft = b.create<Torch::AtenMulScalarOp>(
+            freqBinsFltType, centerBinsHz, dftLenPlusOneItem);
+        Value centerBinsNormalized = b.create<Torch::AtenDivScalarOp>(
+            freqBinsFltType, cbMulDft, sampleRateItem);
+        // cast to int32
+        Value centerBinsInt = b.create<Torch::AtenToDtypeOp>(
+            freqBinsIntType, centerBinsNormalized, /*dtype=*/int32DTypeConst,
+            /*non_blocking=*/falseConst, /*copy=*/falseConst,
+            /*memory_format=*/noneConst);
+        Value centerBins = b.create<Torch::AtenUnsqueezeOp>(
+            unsqueezeBinsResType, centerBinsInt, /*dim=*/zeroConst);

-        Value cfTensor = b.create<Torch::AtenSliceTensorOp>(
-            sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/oneConst,
-            /*end=*/negOneConst, /*step=*/oneConst);
-        Value centerFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
-            unsqueezeResType, cfTensor, /*dim=*/zeroConst);
+        // High bins mel to hz
+        Value highBinsInitInc = b.create<Torch::AtenAddScalarOp>(
+            freqBinsIntType, highBinsInit, twoConst, /*alpha=*/oneConst);
+        Value highBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
+            freqBinsFltType, highBinsInitInc, melStep);
+        Value highBinsScaled = b.create<Torch::AtenAddTensorOp>(
+            freqBinsFltType, highBinsMulMelStep, lfMel, /*alpha=*/oneConst);
+        Value hbDiv = b.create<Torch::AtenDivScalarOp>(
+            freqBinsFltType, highBinsScaled, twoFiveNineFiveConst);
+        Value hbClone = b.create<Torch::AtenCloneOp>(
+            freqBinsFltType, highBinsScaled, /*memory_format=*/noneConst);
+        Value hbTenTensor = b.create<Torch::AtenFillScalarOp>(
+            freqBinsFltType, hbClone, tenConst);
+        Value hbPow = b.create<Torch::AtenPowTensorTensorOp>(
+            freqBinsFltType, hbTenTensor, hbDiv);
+        Value hbPowSubOne = b.create<Torch::AtenSubScalarOp>(
+            freqBinsFltType, hbPow, oneConst, /*alpha=*/oneConst);
+        Value highBinsHz = b.create<Torch::AtenMulScalarOp>(
+            freqBinsFltType, hbPowSubOne, sevenHConst);
+        // Normalize freqBinsHz
+        Value hbMulDft = b.create<Torch::AtenMulScalarOp>(
+            freqBinsFltType, highBinsHz, dftLenPlusOneItem);
+        Value highBinsNormalized = b.create<Torch::AtenDivScalarOp>(
+            freqBinsFltType, hbMulDft, sampleRateItem);
+        // cast to int32
+        Value highBinsInt = b.create<Torch::AtenToDtypeOp>(
+            freqBinsIntType, highBinsNormalized, /*dtype=*/int32DTypeConst,
+            /*non_blocking=*/falseConst, /*copy=*/falseConst,
+            /*memory_format=*/noneConst);
+        Value highBins = b.create<Torch::AtenUnsqueezeOp>(
+            unsqueezeBinsResType, highBinsInt, /*dim=*/zeroConst);

-        Value hfTensor = b.create<Torch::AtenSliceTensorOp>(
-            sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/zeroConst,
-            /*end=*/noneConst, /*step=*/oneConst);
-        Value highFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
-            unsqueezeResType, hfTensor, /*dim=*/zeroConst);
-
-        Value lowToCenter =
-            b.create<Torch::AtenSubTensorOp>(unsqueezeResType, centerFreqTensor,
-                                             lowFreqTensor, /*alpha=*/oneConst);
-        Value centerToHigh = b.create<Torch::AtenSubTensorOp>(
-            unsqueezeResType, highFreqTensor, centerFreqTensor,
-            /*alpha=*/oneConst);
-
-        Type zeroToNInitType =
-            inputIntType.getWithSizesAndDtype(shapeNSB, f32Ty);
-        Value zeroToNInit = b.create<Torch::AtenArangeOp>(
-            zeroToNInitType, numSpectrogramBinsItem,
-            /*dtype=*/float32DTypeConst,
+        Type iotaInitType = inputIntType.getWithSizesAndDtype(shapeNSB, si32Ty);
+        Value iotaInit = b.create<Torch::AtenArangeOp>(
+            iotaInitType, numSpectrogramBinsItem,
+            /*dtype=*/int32DTypeConst,
             /*layout=*/noneConst, /*device=*/noneConst,
             /*pin_memory=*/noneConst);
-        Type zeroToNBaseType = inputIntType.getWithSizesAndDtype(
-            ArrayRef<int64_t>{numSpectrogramBinsInt, 1}, f32Ty);
-        Value zeroToNBase = b.create<Torch::AtenUnsqueezeOp>(
-            zeroToNBaseType, zeroToNInit, /*dim=*/oneConst);
-        Type zeroToNumElesType =
-            inputIntType.getWithSizesAndDtype(shapeNSBxNMB, f32Ty);
-        Value expandShapeList = b.create<Torch::PrimListConstructOp>(
-            rewriter.getType<Torch::ListType>(
-                rewriter.getType<Torch::IntType>()),
-            SmallVector<Value>{numSpectrogramBinsItem, numMelBinsItem});
-        Value zeroToNumEles = b.create<Torch::AtenExpandOp>(
-            zeroToNumElesType, zeroToNBase, expandShapeList,
-            /*implicit=*/falseConst);
+        Torch::ValueTensorType unsqueezeIotaResType =
+            Torch::ValueTensorType::get(ctx, shapeNSBx1, si32Ty);
+        Value iota = b.create<Torch::AtenUnsqueezeOp>(
+            unsqueezeIotaResType, iotaInit, /*dim=*/oneConst);

-        Type maskType = inputIntType.getWithSizesAndDtype(shape1xNMB, i1Ty);
-        Value maskLowToCenterZero =
-            b.create<Torch::AtenEqScalarOp>(maskType, lowToCenter, zeroConst);
+        Value lowToCenter = b.create<Torch::AtenSubTensorOp>(
+            unsqueezeBinsResType, centerBins, lowBins, /*alpha=*/oneConst);
+        Value centerToHigh = b.create<Torch::AtenSubTensorOp>(
+            unsqueezeBinsResType, highBins, centerBins, /*alpha=*/oneConst);

-        // L2C computation
-        Value lowToCenterNoZero = b.create<Torch::AtenWhereScalarSelfOp>(
-            unsqueezeResType, maskLowToCenterZero, negOneConst, lowToCenter);
-        Type maskL2CAfterCType =
-            inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
-        Value maskL2CAfterC = b.create<Torch::AtenGtTensorOp>(
-            maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
-        Type maxLFResTy =
-            inputIntType.getWithSizesAndDtype(ArrayRef<int64_t>{1}, si32Ty);
-        Value maxLowerFreq =
-            b.create<Torch::AtenMaxOp>(maxLFResTy, lowFreqTensor);
-        Value maxLowerFreqItem =
-            getItemOp<Torch::IntType>(binder, rewriter, maxLowerFreq);
-        Value zeroToNumElesL2C = b.create<Torch::AtenWhereScalarSelfOp>(
-            zeroToNumElesType, maskLowToCenterZero, maxLowerFreqItem,
-            zeroToNumEles);
-        Value upslopeDiff = b.create<Torch::AtenSubTensorOp>(
-            zeroToNumElesType, zeroToNumElesL2C, lowFreqTensor,
-            /*alpha=*/oneConst);
-        Type l2cNZFltTy = inputIntType.getWithSizesAndDtype(shape1xNMB, f32Ty);
-        Value l2cNZFlt = b.create<Torch::AtenToDtypeOp>(
-            l2cNZFltTy, lowToCenterNoZero, /*dtype=*/float32DTypeConst,
-            /*non_blocking=*/falseConst, /*copy=*/falseConst,
-            /*memory_format=*/noneConst);
-        Value upslopeL2C0 = b.create<Torch::AtenDivTensorOp>(
-            zeroToNumElesType, upslopeDiff, l2cNZFlt);
-        Type maskUpslopeL2C0PosType =
-            inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
-        Value maskUpslopeL2C0Pos = b.create<Torch::AtenGtScalarOp>(
-            maskUpslopeL2C0PosType, upslopeL2C0, zeroConst);
-        Value upslopeL2C0PosRanged = b.create<Torch::AtenWhereScalarOtherOp>(
-            zeroToNumElesType, maskUpslopeL2C0Pos, upslopeL2C0, zeroConst);
-        Value maskIdxL2CAfterCList = b.create<Torch::PrimListConstructOp>(
-            rewriter.getType<Torch::ListType>(maskL2CAfterC.getType()),
-            ValueRange{maskL2CAfterC});
-        Value zeroConstTensor = Torch::createRank0Tensor(
-            rewriter, binder.getLoc(),
-            Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), zeroConst);
-        Value upslopeL2C1 = b.create<Torch::AtenIndexPutOp>(
-            zeroToNumElesType, upslopeL2C0PosRanged, maskIdxL2CAfterCList,
-            zeroConstTensor, falseConst);
-        Value maskIdxL2CZeroList = b.create<Torch::PrimListConstructOp>(
-            rewriter.getType<Torch::ListType>(maskLowToCenterZero.getType()),
-            ValueRange{maskLowToCenterZero});
-        Type centerFreqTensorL2CZeroType =
-            inputIntType.getWithSizesAndDtype(ArrayRef<int64_t>{-1}, si32Ty);
-        Value centerFreqTensorL2CZero = b.create<Torch::AtenIndexTensorOp>(
-            centerFreqTensorL2CZeroType, centerFreqTensor, maskIdxL2CZeroList);
-        Type maskSqueezeType =
-            inputIntType.getWithSizesAndDtype(shapeNMB, i1Ty);
-        Value maskLowToCenterZeroSqueeze = b.create<Torch::AtenSqueezeOp>(
-            maskSqueezeType, maskLowToCenterZero);
-        Type maskL2CIntTy = inputIntType.getWithSizesAndDtype(shapeNMB, si32Ty);
-        Value maskLowToCenterInt = b.create<Torch::AtenToDtypeOp>(
-            maskL2CIntTy, maskLowToCenterZeroSqueeze, /*dtype=*/int32DTypeConst,
-            /*non_blocking=*/falseConst, /*copy=*/falseConst,
-            /*memory_format=*/noneConst);
-        Value upslopeOneIdxList = b.create<Torch::PrimListConstructOp>(
-            rewriter.getType<Torch::ListType>(
-                centerFreqTensorL2CZero.getType()),
-            ValueRange{centerFreqTensorL2CZero, maskLowToCenterInt});
         Value oneConstTensor = Torch::createRank0Tensor(
             rewriter, binder.getLoc(),
             Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), oneConst);
-        Value upslopeL2C = b.create<Torch::AtenIndexPutOp>(
-            zeroToNumElesType, upslopeL2C1, upslopeOneIdxList, oneConstTensor,
-            falseConst);

-        // H2C computation
-        Value maskCenterToHighZero =
-            b.create<Torch::AtenEqScalarOp>(maskType, centerToHigh, zeroConst);
-        Value maskH2CBeforeC = b.create<Torch::AtenLtTensorOp>(
-            maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
-        Value centerToHighNoZero = b.create<Torch::AtenWhereScalarSelfOp>(
-            unsqueezeResType, maskCenterToHighZero, negOneConst, centerToHigh);
-        Value c2hNZFlt = b.create<Torch::AtenToDtypeOp>(
-            l2cNZFltTy, centerToHighNoZero, /*dtype=*/float32DTypeConst,
+        Type scaledType = inputIntType.getWithSizesAndDtype(shape1xNMB, f32Ty);
+        Value upscaleInit = b.create<Torch::AtenMaximumOp>(
+            unsqueezeBinsResType, oneConstTensor, lowToCenter);
+        Value upscale = b.create<Torch::AtenToDtypeOp>(
+            scaledType, upscaleInit, /*dtype=*/float32DTypeConst,
             /*non_blocking=*/falseConst, /*copy=*/falseConst,
             /*memory_format=*/noneConst);
-        Value zeroToNumElesC2H = b.create<Torch::AtenWhereScalarSelfOp>(
-            zeroToNumElesType, maskCenterToHighZero, zeroConst, zeroToNumEles);
-        Value downslopeDiff = b.create<Torch::AtenSubTensorOp>(
-            zeroToNumElesType, highFreqTensor, zeroToNumElesC2H,
-            /*alpha=*/oneConst);
-        Value downslopeC2H0 = b.create<Torch::AtenDivTensorOp>(
-            zeroToNumElesType, downslopeDiff, c2hNZFlt);
-        Value maskDownslopeC2H0Pos = b.create<Torch::AtenGtScalarOp>(
-            maskUpslopeL2C0PosType, downslopeC2H0, zeroConst);
-        Value downslopeC2H0Pos = b.create<Torch::AtenWhereScalarOtherOp>(
-            zeroToNumElesType, maskDownslopeC2H0Pos, downslopeC2H0, zeroConst);
-        Value idxH2CBeforeCList = b.create<Torch::PrimListConstructOp>(
-            rewriter.getType<Torch::ListType>(maskH2CBeforeC.getType()),
-            ValueRange{maskH2CBeforeC});
-        Value downslopeC2H = b.create<Torch::AtenIndexPutOp>(
-            zeroToNumElesType, downslopeC2H0Pos, idxH2CBeforeCList,
-            zeroConstTensor, falseConst);

-        // final result Calculation
-        Value maskH2CNonZero = b.create<Torch::AtenNeScalarOp>(
-            maskL2CAfterCType, downslopeC2H, zeroConst);
-        Value idxH2CNZList = b.create<Torch::PrimListConstructOp>(
-            rewriter.getType<Torch::ListType>(maskH2CNonZero.getType()),
-            ValueRange{maskH2CNonZero});
-        Value upslopeL2CMasked = b.create<Torch::AtenIndexPutOp>(
-            zeroToNumElesType, upslopeL2C, idxH2CNZList, zeroConstTensor,
-            falseConst);
+        Value downscaleInit = b.create<Torch::AtenMaximumOp>(
+            unsqueezeBinsResType, oneConstTensor, centerToHigh);
+        Value downscale = b.create<Torch::AtenToDtypeOp>(
+            scaledType, downscaleInit, /*dtype=*/float32DTypeConst,
+            /*non_blocking=*/falseConst, /*copy=*/falseConst,
+            /*memory_format=*/noneConst);

-        Value slopesFinal = b.create<Torch::AtenAddTensorOp>(
-            zeroToNumElesType, upslopeL2CMasked, downslopeC2H,
-            /*alpha=*/oneConst);
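+        // Each mel band m becomes a triangle over the spectrogram bins:
+        // rampUp = (iota - lowBins[m]) / upscale[m] rises to one at
+        // centerBins[m]; rampDown = (highBins[m] - iota) / downscale[m]
+        // falls back to zero at highBins[m]. The maximum-with-one above keeps
+        // zero-width edges from dividing by zero, and the where ops below
+        // pick the proper edge, zero everything outside [lowBins, highBins],
+        // and pin collapsed triangles (centerToHigh == 0) to one.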
+        Torch::ValueTensorType binsDiffType =
+            Torch::ValueTensorType::get(ctx, shapeNSBxNMB, si32Ty);
+        Torch::ValueTensorType diffFloatType =
+            Torch::ValueTensorType::get(ctx, shapeNSBxNMB, f32Ty);
+
+        Value iotaSubLBInt = b.create<Torch::AtenSubTensorOp>(
+            binsDiffType, iota, lowBins, /*alpha=*/oneConst);
+        Value iotaSubLB = b.create<Torch::AtenToDtypeOp>(
+            diffFloatType, iotaSubLBInt, /*dtype=*/float32DTypeConst,
+            /*non_blocking=*/falseConst, /*copy=*/falseConst,
+            /*memory_format=*/noneConst);
+        Value rampUp =
+            b.create<Torch::AtenDivTensorOp>(diffFloatType, iotaSubLB, upscale);
+
+        Value hbSubIotaInt = b.create<Torch::AtenSubTensorOp>(
+            binsDiffType, highBins, iota, /*alpha=*/oneConst);
+        Value hbSubIota = b.create<Torch::AtenToDtypeOp>(
+            diffFloatType, hbSubIotaInt, /*dtype=*/float32DTypeConst,
+            /*non_blocking=*/falseConst, /*copy=*/falseConst,
+            /*memory_format=*/noneConst);
+        Value rampDown = b.create<Torch::AtenDivTensorOp>(diffFloatType,
+                                                          hbSubIota, downscale);
+
+        // ramp values
+        Type iotaCmpBinsType =
+            inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
+
+        // Iota Cmp Bins
+        Value iotaGtEqCBins =
+            b.create<Torch::AtenGeTensorOp>(iotaCmpBinsType, iota, centerBins);
+        Value iotaEqCBins =
+            b.create<Torch::AtenEqTensorOp>(iotaCmpBinsType, iota, centerBins);
+        Value iotaLtLBins =
+            b.create<Torch::AtenLtTensorOp>(iotaCmpBinsType, iota, lowBins);
+        Value iotaGtHBins =
+            b.create<Torch::AtenGtTensorOp>(iotaCmpBinsType, iota, highBins);
+
+        // Create output freq ramps Low-Center-High
+        Type rampInitType =
+            inputIntType.getWithSizesAndDtype(shapeNSBxNMB, f32Ty);
+        Value rampInit = b.create<Torch::AtenWhereSelfOp>(
+            rampInitType, iotaGtEqCBins, rampDown, rampUp);
+        Value rampInitLt = b.create<Torch::AtenWhereScalarSelfOp>(
+            rampInitType, iotaLtLBins, zeroConst, rampInit);
+        Value rampInitLtGt = b.create<Torch::AtenWhereScalarSelfOp>(
+            rampInitType, iotaGtHBins, zeroConst, rampInitLt);
+
+        Type C2HCmpBinsType =
+            inputIntType.getWithSizesAndDtype(shape1xNMB, i1Ty);
+        Value C2HEqZero = b.create<Torch::AtenEqScalarOp>(
+            C2HCmpBinsType, centerToHigh, zeroConst);
+        Value cornerCases = b.create<Torch::AtenLogicalAndOp>(
+            iotaCmpBinsType, iotaEqCBins, C2HEqZero);
+        Value rampOutput = b.create<Torch::AtenWhereScalarSelfOp>(
+            rampInitType, cornerCases, oneFltConst, rampInitLtGt);

         Value outputDTypeConst = b.create<Torch::ConstantIntOp>(
             rewriter.getType<Torch::IntType>(),
             rewriter.getI64IntegerAttr(torchDTypeInt.value()));
         Value finalOutput = b.create<Torch::AtenToDtypeOp>(
-            resultType, slopesFinal, /*dtype=*/outputDTypeConst,
+            resultType, rampOutput, /*dtype=*/outputDTypeConst,
             /*non_blocking=*/falseConst, /*copy=*/falseConst,
             /*memory_format=*/noneConst);
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index f291a5991..43ced2e29 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -1974,113 +1974,118 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
 // CHECK-LABEL: func.func @test_mwm
 func.func @test_mwm(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>, %arg3: !torch.vtensor<[],f32>, %arg4: !torch.vtensor<[],f32>) -> !torch.vtensor<[9,8],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "test_mwm", torch.onnx_meta.producer_version = ""} {
-    // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>, %[[VAL_1:.*]]: !torch.vtensor<[],si64>, %[[VAL_2:.*]]: !torch.vtensor<[],si64>,
-    // CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>,
-    // CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[],f32>
+    // CHECK-SAME: %[[NUM_MEL_BINS_ARG:.*]]: !torch.vtensor<[],si64>, %[[DFT_LENGTH_ARG:.*]]: !torch.vtensor<[],si64>, %[[SAMPLE_RATE_ARG:.*]]: !torch.vtensor<[],si64>,
+    // CHECK-SAME: %[[LOWER_EDGE_HZ_ARG:.*]]: !torch.vtensor<[],f32>,
+    // CHECK-SAME: %[[UPPER_EDGE_HZ_ARG:.*]]: !torch.vtensor<[],f32>
     // CHECK: %[[VAL_5:.*]] = torch.constant.none
-    // CHECK: %[[VAL_6:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int
-    // CHECK: %[[VAL_7:.*]] = torch.aten.item %[[VAL_1]] : !torch.vtensor<[],si64> -> !torch.int
-    // CHECK: %[[VAL_8:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[],si64> -> !torch.int
-    // CHECK: %[[VAL_9:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[],f32> -> !torch.float
-    // CHECK: %[[VAL_10:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[],f32> -> !torch.float
-    // CHECK: %[[VAL_11:.*]] = torch.constant.none
-    // CHECK: %[[VAL_12:.*]] = torch.constant.int -2
-    // CHECK: %[[VAL_13:.*]] = torch.constant.int -1
-    // CHECK: %[[VAL_14:.*]] = torch.constant.int 0
-    // CHECK: %[[VAL_15:.*]] = torch.constant.int 1
-    // CHECK: %[[VAL_16:.*]] = torch.constant.int 2
-    // CHECK: %[[VAL_17:.*]] = torch.constant.int 6
-    // CHECK: %[[VAL_18:.*]] = torch.aten.div.int %[[VAL_7]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.float
-    // CHECK: %[[VAL_19:.*]] = torch.aten.Int.float %[[VAL_18]] : !torch.float -> !torch.int
-    // CHECK: %[[VAL_20:.*]] = torch.aten.add.int %[[VAL_19]], %[[VAL_15]] : !torch.int, !torch.int -> !torch.int
-    // CHECK: %[[VAL_21:.*]] = torch.aten.arange %[[VAL_6]], %[[VAL_17]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],si32>
-    // CHECK: %[[VAL_22:.*]] = torch.constant.float 2.595000e+03
-    // CHECK: %[[VAL_23:.*]] = torch.constant.float 7.000000e+02
-    // CHECK: %[[VAL_24:.*]] = torch.constant.float 1.000000e+01
-    // CHECK: %[[VAL_25:.*]] = torch.aten.div.float %[[VAL_9]], %[[VAL_23]] : !torch.float, !torch.float -> !torch.float
-    // CHECK: %[[VAL_26:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_25]] : !torch.float -> !torch.vtensor<[],f32>
-    // CHECK: %[[VAL_27:.*]] = torch.aten.add.Scalar %[[VAL_26]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
-    // CHECK: %[[VAL_28:.*]] = torch.aten.log10 %[[VAL_27]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
-    // CHECK: %[[VAL_29:.*]] = torch.aten.mul.Scalar %[[VAL_28]], %[[VAL_22]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
-    // CHECK: %[[VAL_30:.*]] = torch.aten.div.float %[[VAL_10]], %[[VAL_23]] : !torch.float, !torch.float -> !torch.float
+    // CHECK: %[[NUM_MEL_BINS_ITEM:.*]] = torch.aten.item %[[NUM_MEL_BINS_ARG]] : !torch.vtensor<[],si64> -> !torch.int
+    // CHECK: %[[SAMPLE_RATE_ITEM:.*]] = torch.aten.item %[[SAMPLE_RATE_ARG]] : !torch.vtensor<[],si64> -> !torch.int
+    // CHECK: %[[LOWER_EDGE_HZ_ITEM:.*]] = torch.aten.item %[[LOWER_EDGE_HZ_ARG]] : !torch.vtensor<[],f32> -> !torch.float
+    // CHECK: %[[UPPER_EDGE_HZ_ITEM:.*]] = torch.aten.item %[[UPPER_EDGE_HZ_ARG]] : !torch.vtensor<[],f32> -> !torch.float
+    // CHECK: %[[VAL_10:.*]] = torch.constant.none
+    // CHECK: %[[VAL_11:.*]] = torch.constant.int 0
+    // CHECK: %[[VAL_12:.*]] = torch.constant.int 1
+    // CHECK: %[[VAL_13:.*]] = torch.constant.int 2
+    // CHECK: %[[VAL_14:.*]] = torch.constant.int 3
+    // CHECK: %[[VAL_15:.*]] = torch.constant.int 6
+    // CHECK: %[[VAL_16:.*]] = torch.aten.floor_divide.Scalar %[[DFT_LENGTH_ARG]], %[[VAL_13]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
+    // CHECK: %[[VAL_17:.*]] = torch.aten.add.Scalar %[[VAL_16]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
+    // CHECK: %[[NUM_SPECTROGRAM_BINS_ITEM:.*]] = torch.aten.item %[[VAL_17]] : !torch.vtensor<[],si64> -> !torch.int
+    // CHECK: %[[VAL_19:.*]] = torch.constant.float 2.595000e+03
+    // CHECK: %[[VAL_20:.*]] = torch.constant.float 7.000000e+02
+    // CHECK: %[[VAL_21:.*]] = torch.constant.float 1.000000e+01
+    // CHECK: %[[VAL_22:.*]] = torch.constant.float 1.000000e+00
+    // CHECK: %[[CONST_LN_TO_LOG10:.*]] = torch.constant.float 0.43429448190325182
+    // CHECK: %[[VAL_24:.*]] = torch.aten.div.float %[[LOWER_EDGE_HZ_ITEM]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.float
+    // CHECK: %[[VAL_25:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_24]] : !torch.float -> !torch.vtensor<[],f32>
+    // CHECK: %[[VAL_26:.*]] = torch.aten.add.Scalar %[[VAL_25]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
+    // CHECK: %[[VAL_27:.*]] = torch.aten.log %[[VAL_26]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
+    // CHECK: %[[VAL_28:.*]] = torch.aten.mul.Scalar %[[VAL_27]], %[[CONST_LN_TO_LOG10]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+    // CHECK: %[[LOW_FREQ_MEL:.*]] = torch.aten.mul.Scalar %[[VAL_28]], %[[VAL_19]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+    // CHECK: %[[VAL_30:.*]] = torch.aten.div.float %[[UPPER_EDGE_HZ_ITEM]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.float
     // CHECK: %[[VAL_31:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_30]] : !torch.float -> !torch.vtensor<[],f32>
-    // CHECK: %[[VAL_32:.*]] = torch.aten.add.Scalar %[[VAL_31]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
-    // CHECK: %[[VAL_33:.*]] = torch.aten.log10 %[[VAL_32]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
-    // CHECK: %[[VAL_34:.*]] = torch.aten.mul.Scalar %[[VAL_33]], %[[VAL_22]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
-    // CHECK: %[[VAL_35:.*]] = torch.aten.sub.Tensor %[[VAL_34]], %[[VAL_29]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
-    // CHECK: %[[VAL_36:.*]] = torch.aten.div.Scalar %[[VAL_35]], %[[VAL_6]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
-    // CHECK: %[[VAL_37:.*]] = torch.aten.mul.Tensor %[[VAL_21]], %[[VAL_36]] : !torch.vtensor<[10],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
-    // CHECK: %[[VAL_38:.*]] = torch.aten.add.Tensor %[[VAL_37]], %[[VAL_29]], %[[VAL_15]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[10],f32>
-    // CHECK: %[[VAL_39:.*]] = torch.aten.div.Scalar %[[VAL_38]], %[[VAL_22]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
-    // CHECK: %[[VAL_40:.*]] = torch.aten.clone %[[VAL_38]], %[[VAL_11]] : !torch.vtensor<[10],f32>, !torch.none -> !torch.vtensor<[10],f32>
-    // CHECK: %[[VAL_41:.*]] = torch.aten.fill.Scalar %[[VAL_40]], %[[VAL_24]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
-    // CHECK: %[[VAL_42:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_41]], %[[VAL_39]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
-    // CHECK: %[[VAL_43:.*]] = torch.aten.sub.Scalar %[[VAL_42]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[10],f32>, !torch.int, !torch.int -> !torch.vtensor<[10],f32>
-    // CHECK: %[[VAL_44:.*]] = torch.aten.mul.Scalar %[[VAL_43]], %[[VAL_23]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
-    // CHECK: %[[VAL_45:.*]] = torch.aten.add.Scalar %[[VAL_1]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
-    // CHECK: %[[VAL_46:.*]] = torch.aten.item %[[VAL_45]] : !torch.vtensor<[],si64> -> !torch.int
-    // CHECK: %[[VAL_47:.*]] = torch.aten.mul.Scalar %[[VAL_44]], %[[VAL_46]] : !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
-    // CHECK: %[[VAL_48:.*]] = torch.aten.div.Scalar %[[VAL_47]], %[[VAL_8]] : !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
-    // CHECK: %[[VAL_49:.*]] = torch.constant.int 3
-    // CHECK: %[[VAL_50:.*]] = torch.constant.bool false
-    // CHECK: %[[VAL_51:.*]] = torch.aten.to.dtype %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],si32>
-    // CHECK: %[[VAL_52:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_14]], %[[VAL_12]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
-    // CHECK: %[[VAL_53:.*]] = torch.aten.unsqueeze %[[VAL_52]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
-    // CHECK: %[[VAL_54:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_15]], %[[VAL_13]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
-    // CHECK: %[[VAL_55:.*]] = torch.aten.unsqueeze %[[VAL_54]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
-    // CHECK: %[[VAL_56:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_14]], %[[VAL_11]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[8],si32>
-    // CHECK: %[[VAL_57:.*]] = torch.aten.unsqueeze %[[VAL_56]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
-    // CHECK: %[[VAL_58:.*]] = torch.aten.sub.Tensor %[[VAL_55]], %[[VAL_53]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
-    // CHECK: %[[VAL_59:.*]] = torch.aten.sub.Tensor %[[VAL_57]], %[[VAL_55]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
-    // CHECK: %[[VAL_60:.*]] = torch.aten.arange %[[VAL_20]], %[[VAL_17]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[9],f32>
-    // CHECK: %[[VAL_61:.*]] = torch.aten.unsqueeze %[[VAL_60]], %[[VAL_15]] : !torch.vtensor<[9],f32>, !torch.int -> !torch.vtensor<[9,1],f32>
-    // CHECK: %[[VAL_62:.*]] = torch.prim.ListConstruct %[[VAL_20]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
-    // CHECK: %[[VAL_63:.*]] = torch.aten.expand %[[VAL_61]], %[[VAL_62]], %[[VAL_50]] : !torch.vtensor<[9,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_64:.*]] = torch.aten.eq.Scalar %[[VAL_58]], %[[VAL_14]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1>
-    // CHECK: %[[VAL_65:.*]] = torch.aten.where.ScalarSelf %[[VAL_64]], %[[VAL_13]], %[[VAL_58]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
-    // CHECK: %[[VAL_66:.*]] = torch.aten.gt.Tensor %[[VAL_63]], %[[VAL_55]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
-    // CHECK: %[[VAL_67:.*]] = torch.aten.max %[[VAL_53]] : !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1],si32>
-    // CHECK: %[[VAL_68:.*]] = torch.aten.item %[[VAL_67]] : !torch.vtensor<[1],si32> -> !torch.int
-    // CHECK: %[[VAL_69:.*]] = torch.aten.where.ScalarSelf %[[VAL_64]], %[[VAL_68]], %[[VAL_63]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_70:.*]] = torch.aten.sub.Tensor %[[VAL_69]], %[[VAL_53]], %[[VAL_15]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_71:.*]] = torch.aten.to.dtype %[[VAL_65]], %[[VAL_17]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
-    // CHECK: %[[VAL_72:.*]] = torch.aten.div.Tensor %[[VAL_70]], %[[VAL_71]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_73:.*]] = torch.aten.gt.Scalar %[[VAL_72]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
-    // CHECK: %[[VAL_74:.*]] = torch.aten.where.ScalarOther %[[VAL_73]], %[[VAL_72]], %[[VAL_14]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_75:.*]] = torch.prim.ListConstruct %[[VAL_66]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<optional<vtensor>>
-    // CHECK: %[[VAL_76:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-    // CHECK: %[[VAL_77:.*]] = torch.constant.none
-    // CHECK: %[[VAL_78:.*]] = torch.constant.int 6
-    // CHECK: %[[VAL_79:.*]] = torch.aten.full %[[VAL_76]], %[[VAL_14]], %[[VAL_78]], %[[VAL_77]], %[[VAL_77]], %[[VAL_77]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
-    // CHECK: %[[VAL_80:.*]] = torch.aten.index_put %[[VAL_74]], %[[VAL_75]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_81:.*]] = torch.prim.ListConstruct %[[VAL_64]] : (!torch.vtensor<[1,8],i1>) -> !torch.list<optional<vtensor>>
-    // CHECK: %[[VAL_82:.*]] = torch.aten.index.Tensor %[[VAL_55]], %[[VAL_81]] : !torch.vtensor<[1,8],si32>, !torch.list<optional<vtensor>> -> !torch.vtensor<[?],si32>
-    // CHECK: %[[VAL_83:.*]] = torch.aten.squeeze %[[VAL_64]] : !torch.vtensor<[1,8],i1> -> !torch.vtensor<[8],i1>
-    // CHECK: %[[VAL_84:.*]] = torch.aten.to.dtype %[[VAL_83]], %[[VAL_49]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[8],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
-    // CHECK: %[[VAL_85:.*]] = torch.prim.ListConstruct %[[VAL_82]], %[[VAL_84]] : (!torch.vtensor<[?],si32>, !torch.vtensor<[8],si32>) -> !torch.list<optional<vtensor>>
-    // CHECK: %[[VAL_86:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-    // CHECK: %[[VAL_87:.*]] = torch.constant.none
-    // CHECK: %[[VAL_88:.*]] = torch.constant.int 6
-    // CHECK: %[[VAL_89:.*]] = torch.aten.full %[[VAL_86]], %[[VAL_15]], %[[VAL_88]], %[[VAL_87]], %[[VAL_87]], %[[VAL_87]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
-    // CHECK: %[[VAL_90:.*]] = torch.aten.index_put %[[VAL_80]], %[[VAL_85]], %[[VAL_89]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_91:.*]] = torch.aten.eq.Scalar %[[VAL_59]], %[[VAL_14]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1>
-    // CHECK: %[[VAL_92:.*]] = torch.aten.lt.Tensor %[[VAL_63]], %[[VAL_55]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
-    // CHECK: %[[VAL_93:.*]] = torch.aten.where.ScalarSelf %[[VAL_91]], %[[VAL_13]], %[[VAL_59]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
-    // CHECK: %[[VAL_94:.*]] = torch.aten.to.dtype %[[VAL_93]], %[[VAL_17]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
-    // CHECK: %[[VAL_95:.*]] = torch.aten.where.ScalarSelf %[[VAL_91]], %[[VAL_14]], %[[VAL_63]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_96:.*]] = torch.aten.sub.Tensor %[[VAL_57]], %[[VAL_95]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_97:.*]] = torch.aten.div.Tensor %[[VAL_96]], %[[VAL_94]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_98:.*]] = torch.aten.gt.Scalar %[[VAL_97]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
-    // CHECK: %[[VAL_99:.*]] = torch.aten.where.ScalarOther %[[VAL_98]], %[[VAL_97]], %[[VAL_14]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_100:.*]] = torch.prim.ListConstruct %[[VAL_92]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<optional<vtensor>>
-    // CHECK: %[[VAL_101:.*]] = torch.aten.index_put %[[VAL_99]], %[[VAL_100]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_102:.*]] = torch.aten.ne.Scalar %[[VAL_101]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
-    // CHECK: %[[VAL_103:.*]] = torch.prim.ListConstruct %[[VAL_102]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<optional<vtensor>>
-    // CHECK: %[[VAL_104:.*]] = torch.aten.index_put %[[VAL_90]], %[[VAL_103]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_105:.*]] = torch.aten.add.Tensor %[[VAL_104]], %[[VAL_101]], %[[VAL_15]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
-    // CHECK: %[[VAL_106:.*]] = torch.constant.int 6
-    // CHECK: %[[VAL_107:.*]] = torch.aten.to.dtype %[[VAL_105]], %[[VAL_106]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[9,8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
-    // CHECK: return %[[VAL_107]] : !torch.vtensor<[9,8],f32>
+    // CHECK: %[[VAL_32:.*]] = torch.aten.add.Scalar %[[VAL_31]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
+    // CHECK: %[[VAL_33:.*]] = torch.aten.log %[[VAL_32]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
+    // CHECK: %[[VAL_34:.*]] = torch.aten.mul.Scalar %[[VAL_33]], %[[CONST_LN_TO_LOG10]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+    // CHECK: %[[HIGH_FREQ_MEL:.*]] = torch.aten.mul.Scalar %[[VAL_34]], %[[VAL_19]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
+    // CHECK: %[[VAL_36:.*]] = torch.aten.sub.Tensor %[[HIGH_FREQ_MEL]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
+    // CHECK: %[[VAL_37:.*]] = torch.aten.add.int %[[NUM_MEL_BINS_ITEM]], %[[VAL_13]] : !torch.int, !torch.int -> !torch.int
+    // CHECK: %[[MEL_STEP:.*]] = torch.aten.div.Scalar %[[VAL_36]], %[[VAL_37]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
+    // CHECK: %[[LOW_BINS_INIT:.*]] = torch.aten.arange %[[NUM_MEL_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[8],si32>
+    // CHECK: %[[CENTER_BINS_INIT:.*]] = torch.aten.arange %[[NUM_MEL_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[8],si32>
+    // CHECK: %[[HIGH_BINS_INIT:.*]] = torch.aten.arange %[[NUM_MEL_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[8],si32>
+    // CHECK: %[[VAL_42:.*]] = torch.aten.add.Scalar %[[DFT_LENGTH_ARG]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
+    // CHECK: %[[VAL_43:.*]] = torch.aten.item %[[VAL_42]] : !torch.vtensor<[],si64> -> !torch.int
+    // CHECK: %[[VAL_44:.*]] = torch.constant.bool false
+    // CHECK: %[[VAL_45:.*]] = torch.aten.mul.Tensor %[[LOW_BINS_INIT]], %[[MEL_STEP]] : !torch.vtensor<[8],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_46:.*]] = torch.aten.add.Tensor %[[VAL_45]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_47:.*]] = torch.aten.div.Scalar %[[VAL_46]], %[[VAL_19]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_48:.*]] = torch.aten.clone %[[VAL_46]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.none -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_49:.*]] = torch.aten.fill.Scalar %[[VAL_48]], %[[VAL_21]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_50:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_49]], %[[VAL_47]] : !torch.vtensor<[8],f32>, !torch.vtensor<[8],f32> -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_51:.*]] = torch.aten.sub.Scalar %[[VAL_50]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_52:.*]] = torch.aten.mul.Scalar %[[VAL_51]], %[[VAL_20]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_53:.*]] = torch.aten.mul.Scalar %[[VAL_52]], %[[VAL_43]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_54:.*]] = torch.aten.div.Scalar %[[VAL_53]], %[[SAMPLE_RATE_ITEM]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_55:.*]] = torch.aten.to.dtype %[[VAL_54]], %[[VAL_14]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
+    // CHECK: %[[LOW_BINS:.*]] = torch.aten.unsqueeze %[[VAL_55]], %[[VAL_11]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
+    // CHECK: %[[VAL_57:.*]] = torch.aten.add.Scalar %[[CENTER_BINS_INIT]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],si32>, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
+    // CHECK: %[[VAL_58:.*]] = torch.aten.mul.Tensor %[[VAL_57]], %[[MEL_STEP]] : !torch.vtensor<[8],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_59:.*]] = torch.aten.add.Tensor %[[VAL_58]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_60:.*]] = torch.aten.div.Scalar %[[VAL_59]], %[[VAL_19]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_61:.*]] = torch.aten.clone %[[VAL_59]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.none -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_62:.*]] = torch.aten.fill.Scalar %[[VAL_61]], %[[VAL_21]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_63:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_62]], %[[VAL_60]] : !torch.vtensor<[8],f32>, !torch.vtensor<[8],f32> -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_64:.*]] = torch.aten.sub.Scalar %[[VAL_63]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_65:.*]] = torch.aten.mul.Scalar %[[VAL_64]], %[[VAL_20]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_66:.*]] = torch.aten.mul.Scalar %[[VAL_65]], %[[VAL_43]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_67:.*]] = torch.aten.div.Scalar %[[VAL_66]], %[[SAMPLE_RATE_ITEM]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_68:.*]] = torch.aten.to.dtype %[[VAL_67]], %[[VAL_14]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
+    // CHECK: %[[CENTER_BINS:.*]] = torch.aten.unsqueeze %[[VAL_68]], %[[VAL_11]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
+    // CHECK: %[[VAL_70:.*]] = torch.aten.add.Scalar %[[HIGH_BINS_INIT]], %[[VAL_13]], %[[VAL_12]] : !torch.vtensor<[8],si32>, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
+    // CHECK: %[[VAL_71:.*]] = torch.aten.mul.Tensor %[[VAL_70]], %[[MEL_STEP]] : !torch.vtensor<[8],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_72:.*]] = torch.aten.add.Tensor %[[VAL_71]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_73:.*]] = torch.aten.div.Scalar %[[VAL_72]], %[[VAL_19]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_74:.*]] = torch.aten.clone %[[VAL_72]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.none -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_75:.*]] = torch.aten.fill.Scalar %[[VAL_74]], %[[VAL_21]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_76:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_75]], %[[VAL_73]] : !torch.vtensor<[8],f32>, !torch.vtensor<[8],f32> -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_77:.*]] = torch.aten.sub.Scalar %[[VAL_76]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_78:.*]] = torch.aten.mul.Scalar %[[VAL_77]], %[[VAL_20]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_79:.*]] = torch.aten.mul.Scalar %[[VAL_78]], %[[VAL_43]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_80:.*]] = torch.aten.div.Scalar %[[VAL_79]], %[[SAMPLE_RATE_ITEM]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
+    // CHECK: %[[VAL_81:.*]] = torch.aten.to.dtype %[[VAL_80]], %[[VAL_14]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
+    // CHECK: %[[HIGH_BINS:.*]] = torch.aten.unsqueeze %[[VAL_81]], %[[VAL_11]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
+    // CHECK: %[[IOTA_INIT:.*]] = torch.aten.arange %[[NUM_SPECTROGRAM_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[9],si32>
+    // CHECK: %[[IOTA:.*]] = torch.aten.unsqueeze %[[IOTA_INIT]], %[[VAL_12]] : !torch.vtensor<[9],si32>, !torch.int -> !torch.vtensor<[9,1],si32>
+    // CHECK: %[[LOW_TO_CENTER:.*]] = torch.aten.sub.Tensor %[[CENTER_BINS]], %[[LOW_BINS]], %[[VAL_12]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
+    // CHECK: %[[CENTER_TO_HIGH:.*]] = torch.aten.sub.Tensor %[[HIGH_BINS]], %[[CENTER_BINS]], %[[VAL_12]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
+    // CHECK: %[[VAL_87:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+    // CHECK: %[[VAL_88:.*]] = torch.constant.none
+    // CHECK: %[[VAL_89:.*]] = torch.constant.int 6
+    // CHECK: %[[VAL_90:.*]] = torch.aten.full %[[VAL_87]], %[[VAL_12]], %[[VAL_89]], %[[VAL_88]], %[[VAL_88]], %[[VAL_88]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
+    // CHECK: %[[VAL_91:.*]] = torch.aten.maximum %[[VAL_90]], %[[LOW_TO_CENTER]] : !torch.vtensor<[],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
+    // CHECK: %[[UP_SCALE:.*]] = torch.aten.to.dtype %[[VAL_91]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
+    // CHECK: %[[VAL_93:.*]] = torch.aten.maximum %[[VAL_90]], %[[CENTER_TO_HIGH]] : !torch.vtensor<[],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
+    // CHECK: %[[DOWN_SCALE:.*]] = torch.aten.to.dtype %[[VAL_93]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
+    // CHECK: %[[VAL_95:.*]] = torch.aten.sub.Tensor %[[IOTA]], %[[LOW_BINS]], %[[VAL_12]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[9,8],si32>
+    // CHECK: %[[VAL_96:.*]] = torch.aten.to.dtype %[[VAL_95]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[9,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
+    // CHECK: %[[RAMP_UP:.*]] = torch.aten.div.Tensor %[[VAL_96]], %[[UP_SCALE]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
+    // CHECK: %[[VAL_98:.*]] = torch.aten.sub.Tensor %[[HIGH_BINS]], %[[IOTA]], %[[VAL_12]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[9,1],si32>, !torch.int -> !torch.vtensor<[9,8],si32>
+    // CHECK: %[[VAL_99:.*]] = torch.aten.to.dtype %[[VAL_98]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[9,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
+    // CHECK: %[[RAMP_DOWN:.*]] = torch.aten.div.Tensor %[[VAL_99]], %[[DOWN_SCALE]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
+    // CHECK: %[[VAL_101:.*]] = torch.aten.ge.Tensor %[[IOTA]], %[[CENTER_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
+    // CHECK: %[[VAL_102:.*]] = torch.aten.eq.Tensor %[[IOTA]], %[[CENTER_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
+    // CHECK: %[[VAL_103:.*]] = torch.aten.lt.Tensor %[[IOTA]], %[[LOW_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
+    // CHECK: %[[VAL_104:.*]] = torch.aten.gt.Tensor %[[IOTA]], %[[HIGH_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
+    // CHECK: %[[RAMP_INIT:.*]] = torch.aten.where.self %[[VAL_101]], %[[RAMP_DOWN]], %[[RAMP_UP]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
+    // CHECK: %[[VAL_106:.*]] = torch.aten.where.ScalarSelf %[[VAL_103]], %[[VAL_11]], %[[RAMP_INIT]] : !torch.vtensor<[9,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
+    // CHECK: %[[VAL_107:.*]] = torch.aten.where.ScalarSelf %[[VAL_104]], %[[VAL_11]], %[[VAL_106]] : !torch.vtensor<[9,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
+    // CHECK: %[[VAL_108:.*]] = torch.aten.eq.Scalar %[[CENTER_TO_HIGH]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1>
+    // CHECK: %[[CORNER_CASES:.*]] = torch.aten.logical_and %[[VAL_102]], %[[VAL_108]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[1,8],i1> -> !torch.vtensor<[9,8],i1>
+    // CHECK: %[[RAMP:.*]] = torch.aten.where.ScalarSelf %[[CORNER_CASES]], %[[VAL_22]], %[[VAL_107]] : !torch.vtensor<[9,8],i1>, !torch.float, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
+    // CHECK: %[[VAL_111:.*]] = torch.constant.int 6
+    // CHECK: %[[OUTPUT:.*]] = torch.aten.to.dtype %[[RAMP]], %[[VAL_111]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[9,8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
+    // CHECK: return %[[OUTPUT]] : !torch.vtensor<[9,8],f32>
     %none = torch.constant.none
     %0 = torch.operator "onnx.MelWeightMatrix"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[9,8],f32>
    return %0 : !torch.vtensor<[9,8],f32>
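
For reference, the arithmetic encoded by the reworked lowering can be written out directly. The following standalone sketch (plain C++, illustrative only and not part of the patch; the helper names are invented) follows the same scheme as the ops above and the ONNX reference implementation of MelWeightMatrix: per-band low/center/high FFT bins, clamped triangle edges, and the collapsed-triangle corner case. The [9,8] result type in the test corresponds to dft_length/2 + 1 = 9 spectrogram bins and 8 mel bins.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // HTK mel scale, as in the ONNX reference implementation. log10 is
    // written as ln(x) * M_LOG10E to mirror the emitted aten.log / mul pair.
    static double hzToMel(double hz) {
      return 2595.0 * (std::log(1.0 + hz / 700.0) * M_LOG10E);
    }

    static double melToHz(double mel) {
      return 700.0 * (std::pow(10.0, mel / 2595.0) - 1.0);
    }

    // Row-major [dftLength / 2 + 1, numMelBins] matrix of triangle weights.
    std::vector<float> melWeightMatrix(int64_t numMelBins, int64_t dftLength,
                                       int64_t sampleRate, double lowerEdgeHz,
                                       double upperEdgeHz) {
      int64_t numSpectrogramBins = dftLength / 2 + 1;
      double lowMel = hzToMel(lowerEdgeHz);
      double melStep = (hzToMel(upperEdgeHz) - lowMel) / (numMelBins + 2);
      // Mel value -> FFT bin index, truncated to int as in the lowering.
      auto toBin = [&](int64_t m) {
        return (int32_t)(melToHz(lowMel + m * melStep) * (dftLength + 1) /
                         sampleRate);
      };
      std::vector<float> out(numSpectrogramBins * numMelBins, 0.0f);
      for (int64_t m = 0; m < numMelBins; ++m) {
        int32_t low = toBin(m), center = toBin(m + 1), high = toBin(m + 2);
        for (int32_t i = 0; i < numSpectrogramBins; ++i) {
          // Rising edge up to the center bin, falling edge after it; the
          // max(1, ...) guards match upscale/downscale in the lowering.
          float up = (float)(i - low) / (float)std::max(1, center - low);
          float down = (float)(high - i) / (float)std::max(1, high - center);
          float w = i >= center ? down : up;
          if (i < low || i > high)
            w = 0.0f;
          if (i == center && center == high)
            w = 1.0f; // collapsed triangle corner case
          out[i * numMelBins + m] = w;
        }
      }
      return out;
    }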