onnx.MelWeightMatrix Onnx to Torch to Linalg (#3659)

- This PR adds new (and equivalent) more tensorized impl of
MelWeightMatrix which lowers all the way to linalg.
- [Ref Pytorch
Impl](https://gist.github.com/PhaneeshB/4e6dfcded3007b1b686fbe28f07a67cd)
- Thanks to @rsuderman for pointing out the difficulties [earlier
impl](#3503) posed during lowering to linalg and also for providing a
better numpy impl 🙏
pull/3636/merge
Phaneesh Barwaria 2024-08-22 08:55:03 -07:00 committed by GitHub
parent fcc5f444cd
commit 9a6fe58a02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 320 additions and 309 deletions

View File

@ -640,8 +640,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value numMelBinsItem =
getItemOp<Torch::IntType>(binder, rewriter, operands[0]);
Value dftLengthItem =
getItemOp<Torch::IntType>(binder, rewriter, operands[1]);
Value sampleRateItem =
getItemOp<Torch::IntType>(binder, rewriter, operands[2]);
Value lowerEdgeHzItem =
@ -656,9 +654,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
// Recurring shapes
SmallVector<int64_t> unranked({});
SmallVector<int64_t> shapeNMB({numMelBinsInt});
SmallVector<int64_t> shapeNMBp2({numMelBinsInt + 2});
SmallVector<int64_t> shape1xNMB({1, numMelBinsInt});
SmallVector<int64_t> shapeNSB({numSpectrogramBinsInt});
SmallVector<int64_t> shapeNSBx1({numSpectrogramBinsInt, 1});
SmallVector<int64_t> shapeNSBxNMB(
{numSpectrogramBinsInt, numMelBinsInt});
@ -671,37 +669,30 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
// Value constants
Value noneConst = b.create<Torch::ConstantNoneOp>();
Value negTwoConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(-2));
Value negOneConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(-1));
Value zeroConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(0));
Value oneConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(1));
Value twoConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(2));
Value int32DTypeConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(3));
Value float32DTypeConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(6));
Torch::ValueTensorType dftLenType =
Torch::ValueTensorType::get(ctx, unranked, inpIntDType);
Type freqBinsIntType =
Torch::ValueTensorType::get(ctx, shapeNMBp2, si32Ty);
Torch::ValueTensorType::get(ctx, shapeNMB, si32Ty);
Type freqBinsFltType =
Torch::ValueTensorType::get(ctx, shapeNMBp2, f32Ty);
Torch::ValueTensorType::get(ctx, shapeNMB, f32Ty);
Value dftLengthDivTwoFlt =
b.create<Torch::AtenDivIntOp>(dftLengthItem, twoConst);
Value dftLengthDivTwo =
b.create<Torch::AtenIntFloatOp>(dftLengthDivTwoFlt);
Value numSpectrogramBins =
b.create<Torch::AtenAddIntOp>(dftLengthDivTwo, oneConst);
Value numSpectrogramBinsItem = numSpectrogramBins;
Value freqBinsInit = b.create<Torch::AtenArangeOp>(
freqBinsIntType, numMelBinsItem, /*dtype=*/float32DTypeConst,
/*layout=*/noneConst, /*device=*/noneConst,
/*pin_memory=*/noneConst);
Value dftLengthDivTwoTensor = b.create<Torch::AtenFloorDivideScalarOp>(
dftLenType, operands[1], twoConst);
Value numSpectrogramBinsTensor = b.create<Torch::AtenAddScalarOp>(
dftLenType, dftLengthDivTwoTensor, oneConst, /*alpha =*/oneConst);
Value numSpectrogramBinsItem = getItemOp<Torch::IntType>(
binder, rewriter, numSpectrogramBinsTensor);
// From Ref Impl of Onnx.MelWeightMatrix:
// https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_mel_weight_matrix.py#L25-L32
@ -712,6 +703,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(700));
Value tenConst =
b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(10));
Value oneFltConst =
b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(1));
Value LnToLog10Const = b.create<Torch::ConstantFloatOp>(
rewriter.getF64FloatAttr(M_LOG10E));
Value lfDiv7Hfloat =
b.create<Torch::AtenDivFloatOp>(lowerEdgeHzItem, sevenHConst);
@ -720,8 +715,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
b.create<Torch::PrimNumToTensorScalarOp>(freqType, lfDiv7Hfloat);
Value lfDiv7HAdd1 = b.create<Torch::AtenAddScalarOp>(
freqType, lfDiv7H, oneConst, /*alpha =*/oneConst);
Value lfDiv7HAdd1Log10 =
b.create<Torch::AtenLog10Op>(freqType, lfDiv7HAdd1);
Value lfDiv7HAdd1Ln = b.create<Torch::AtenLogOp>(freqType, lfDiv7HAdd1);
Value lfDiv7HAdd1Log10 = b.create<Torch::AtenMulScalarOp>(
freqType, lfDiv7HAdd1Ln, LnToLog10Const);
Value lfMel = b.create<Torch::AtenMulScalarOp>(
freqType, lfDiv7HAdd1Log10, twoFiveNineFiveConst);
@ -731,226 +728,235 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
b.create<Torch::PrimNumToTensorScalarOp>(freqType, hfDiv7Hfloat);
Value hfDiv7HAdd1 = b.create<Torch::AtenAddScalarOp>(
freqType, hfDiv7H, oneConst, /*alpha =*/oneConst);
Value hfDiv7HAdd1Log10 =
b.create<Torch::AtenLog10Op>(freqType, hfDiv7HAdd1);
Value hfDiv7HAdd1Ln = b.create<Torch::AtenLogOp>(freqType, hfDiv7HAdd1);
Value hfDiv7HAdd1Log10 = b.create<Torch::AtenMulScalarOp>(
freqType, hfDiv7HAdd1Ln, LnToLog10Const);
Value hfMel = b.create<Torch::AtenMulScalarOp>(
freqType, hfDiv7HAdd1Log10, twoFiveNineFiveConst);
Value hfSubLf = b.create<Torch::AtenSubTensorOp>(
hfMel.getType(), hfMel, lfMel, /*alpha=*/oneConst);
Value numMelBinsPlus2 =
b.create<Torch::AtenAddIntOp>(numMelBinsItem, twoConst);
Value melStep = b.create<Torch::AtenDivScalarOp>(
hfSubLf.getType(), hfSubLf, numMelBinsItem);
hfSubLf.getType(), hfSubLf, numMelBinsPlus2);
Value freqBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
freqBinsFltType, freqBinsInit, melStep);
Value freqBinsScaled = b.create<Torch::AtenAddTensorOp>(
freqBinsFltType, freqBinsMulMelStep, lfMel, /*alpha=*/oneConst);
Value lowBinsInit = b.create<Torch::AtenArangeOp>(
freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst,
/*layout=*/noneConst, /*device=*/noneConst,
/*pin_memory=*/noneConst);
// Mel to Hz conv
Value centerBinsInit = b.create<Torch::AtenArangeOp>(
freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst,
/*layout=*/noneConst, /*device=*/noneConst,
/*pin_memory=*/noneConst);
Value fbDiv = b.create<Torch::AtenDivScalarOp>(
freqBinsFltType, freqBinsScaled, twoFiveNineFiveConst);
Value fbClone = b.create<Torch::AtenCloneOp>(
freqBinsFltType, freqBinsScaled, /*memory_format=*/noneConst);
Value tenTensor = b.create<Torch::AtenFillScalarOp>(freqBinsFltType,
fbClone, tenConst);
Value fbPow = b.create<Torch::AtenPowTensorTensorOp>(freqBinsFltType,
tenTensor, fbDiv);
Value fbPowSubOne = b.create<Torch::AtenSubScalarOp>(
freqBinsFltType, fbPow, oneConst, /*alpha=*/oneConst);
Value freqBinsHz = b.create<Torch::AtenMulScalarOp>(
freqBinsFltType, fbPowSubOne, sevenHConst);
Value highBinsInit = b.create<Torch::AtenArangeOp>(
freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst,
/*layout=*/noneConst, /*device=*/noneConst,
/*pin_memory=*/noneConst);
// Normalize freqBinsHz
// Common values used in conversion
Value dftLenPlusOne = b.create<Torch::AtenAddScalarOp>(
dftLenType, operands[1], oneConst, /*alpha=*/oneConst);
Value dftLenPlusOneItem =
getItemOp<Torch::IntType>(binder, rewriter, dftLenPlusOne);
Value fbMulDft = b.create<Torch::AtenMulScalarOp>(
freqBinsFltType, freqBinsHz, dftLenPlusOneItem);
Value freqBinsNormalized = b.create<Torch::AtenDivScalarOp>(
freqBinsFltType, fbMulDft, sampleRateItem);
// cast to int32
Value int32DTypeConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(3));
Value falseConst = b.create<Torch::ConstantBoolOp>(false);
Value freqBins = b.create<Torch::AtenToDtypeOp>(
freqBinsIntType, freqBinsNormalized, /*dtype=*/int32DTypeConst,
Torch::ValueTensorType unsqueezeBinsResType =
Torch::ValueTensorType::get(ctx, shape1xNMB, si32Ty);
// Low bins Mel to hz
Value lowBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
freqBinsFltType, lowBinsInit, melStep);
Value lowBinsScaled = b.create<Torch::AtenAddTensorOp>(
freqBinsFltType, lowBinsMulMelStep, lfMel, /*alpha=*/oneConst);
Value lbDiv = b.create<Torch::AtenDivScalarOp>(
freqBinsFltType, lowBinsScaled, twoFiveNineFiveConst);
Value lbClone = b.create<Torch::AtenCloneOp>(
freqBinsFltType, lowBinsScaled, /*memory_format=*/noneConst);
Value lbTenTensor = b.create<Torch::AtenFillScalarOp>(
freqBinsFltType, lbClone, tenConst);
Value lbPow = b.create<Torch::AtenPowTensorTensorOp>(
freqBinsFltType, lbTenTensor, lbDiv);
Value lbPowSubOne = b.create<Torch::AtenSubScalarOp>(
freqBinsFltType, lbPow, oneConst, /*alpha=*/oneConst);
Value lowBinsHz = b.create<Torch::AtenMulScalarOp>(
freqBinsFltType, lbPowSubOne, sevenHConst);
// Normalize freqBinsHz
Value lbMulDft = b.create<Torch::AtenMulScalarOp>(
freqBinsFltType, lowBinsHz, dftLenPlusOneItem);
Value lowBinsNormalized = b.create<Torch::AtenDivScalarOp>(
freqBinsFltType, lbMulDft, sampleRateItem);
// cast to int32
Value lowBinsInt = b.create<Torch::AtenToDtypeOp>(
freqBinsIntType, lowBinsNormalized, /*dtype=*/int32DTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
Value lowBins = b.create<Torch::AtenUnsqueezeOp>(
unsqueezeBinsResType, lowBinsInt, /*dim=*/zeroConst);
Torch::ValueTensorType sliceResType =
Torch::ValueTensorType::get(ctx, shapeNMB, si32Ty);
Type unsqueezeResType =
sliceResType.getWithSizesAndDtype(shape1xNMB, si32Ty);
Value lfTensor = b.create<Torch::AtenSliceTensorOp>(
sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/zeroConst,
/*end=*/negTwoConst, /*step=*/oneConst);
Value lowFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
unsqueezeResType, lfTensor, /*dim=*/zeroConst);
// Center bins mel to hz
Value centerBinsInitInc = b.create<Torch::AtenAddScalarOp>(
freqBinsIntType, centerBinsInit, oneConst, /*alpha=*/oneConst);
Value centerBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
freqBinsFltType, centerBinsInitInc, melStep);
Value centerBinsScaled = b.create<Torch::AtenAddTensorOp>(
freqBinsFltType, centerBinsMulMelStep, lfMel, /*alpha=*/oneConst);
Value cbDiv = b.create<Torch::AtenDivScalarOp>(
freqBinsFltType, centerBinsScaled, twoFiveNineFiveConst);
Value cbClone = b.create<Torch::AtenCloneOp>(
freqBinsFltType, centerBinsScaled, /*memory_format=*/noneConst);
Value cbTenTensor = b.create<Torch::AtenFillScalarOp>(
freqBinsFltType, cbClone, tenConst);
Value cbPow = b.create<Torch::AtenPowTensorTensorOp>(
freqBinsFltType, cbTenTensor, cbDiv);
Value cbPowSubOne = b.create<Torch::AtenSubScalarOp>(
freqBinsFltType, cbPow, oneConst, /*alpha=*/oneConst);
Value centerBinsHz = b.create<Torch::AtenMulScalarOp>(
freqBinsFltType, cbPowSubOne, sevenHConst);
// Normalize freqBinsHz
Value cbMulDft = b.create<Torch::AtenMulScalarOp>(
freqBinsFltType, centerBinsHz, dftLenPlusOneItem);
Value centerBinsNormalized = b.create<Torch::AtenDivScalarOp>(
freqBinsFltType, cbMulDft, sampleRateItem);
// cast to int32
Value centerBinsInt = b.create<Torch::AtenToDtypeOp>(
freqBinsIntType, centerBinsNormalized, /*dtype=*/int32DTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
Value centerBins = b.create<Torch::AtenUnsqueezeOp>(
unsqueezeBinsResType, centerBinsInt, /*dim=*/zeroConst);
Value cfTensor = b.create<Torch::AtenSliceTensorOp>(
sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/oneConst,
/*end=*/negOneConst, /*step=*/oneConst);
Value centerFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
unsqueezeResType, cfTensor, /*dim=*/zeroConst);
// High bins mel to hz
Value highBinsInitInc = b.create<Torch::AtenAddScalarOp>(
freqBinsIntType, highBinsInit, twoConst, /*alpha=*/oneConst);
Value highBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
freqBinsFltType, highBinsInitInc, melStep);
Value highBinsScaled = b.create<Torch::AtenAddTensorOp>(
freqBinsFltType, highBinsMulMelStep, lfMel, /*alpha=*/oneConst);
Value hbDiv = b.create<Torch::AtenDivScalarOp>(
freqBinsFltType, highBinsScaled, twoFiveNineFiveConst);
Value hbClone = b.create<Torch::AtenCloneOp>(
freqBinsFltType, highBinsScaled, /*memory_format=*/noneConst);
Value hbTenTensor = b.create<Torch::AtenFillScalarOp>(
freqBinsFltType, hbClone, tenConst);
Value hbPow = b.create<Torch::AtenPowTensorTensorOp>(
freqBinsFltType, hbTenTensor, hbDiv);
Value hbPowSubOne = b.create<Torch::AtenSubScalarOp>(
freqBinsFltType, hbPow, oneConst, /*alpha=*/oneConst);
Value highBinsHz = b.create<Torch::AtenMulScalarOp>(
freqBinsFltType, hbPowSubOne, sevenHConst);
// Normalize freqBinsHz
Value hbMulDft = b.create<Torch::AtenMulScalarOp>(
freqBinsFltType, highBinsHz, dftLenPlusOneItem);
Value highBinsNormalized = b.create<Torch::AtenDivScalarOp>(
freqBinsFltType, hbMulDft, sampleRateItem);
// cast to int32
Value highBinsInt = b.create<Torch::AtenToDtypeOp>(
freqBinsIntType, highBinsNormalized, /*dtype=*/int32DTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
Value highBins = b.create<Torch::AtenUnsqueezeOp>(
unsqueezeBinsResType, highBinsInt, /*dim=*/zeroConst);
Value hfTensor = b.create<Torch::AtenSliceTensorOp>(
sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/zeroConst,
/*end=*/noneConst, /*step=*/oneConst);
Value highFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
unsqueezeResType, hfTensor, /*dim=*/zeroConst);
Value lowToCenter =
b.create<Torch::AtenSubTensorOp>(unsqueezeResType, centerFreqTensor,
lowFreqTensor, /*alpha=*/oneConst);
Value centerToHigh = b.create<Torch::AtenSubTensorOp>(
unsqueezeResType, highFreqTensor, centerFreqTensor,
/*alpha=*/oneConst);
Type zeroToNInitType =
inputIntType.getWithSizesAndDtype(shapeNSB, f32Ty);
Value zeroToNInit = b.create<Torch::AtenArangeOp>(
zeroToNInitType, numSpectrogramBinsItem,
/*dtype=*/float32DTypeConst,
Type iotaInitType = inputIntType.getWithSizesAndDtype(shapeNSB, si32Ty);
Value iotaInit = b.create<Torch::AtenArangeOp>(
iotaInitType, numSpectrogramBinsItem,
/*dtype=*/int32DTypeConst,
/*layout=*/noneConst, /*device=*/noneConst,
/*pin_memory=*/noneConst);
Type zeroToNBaseType = inputIntType.getWithSizesAndDtype(
ArrayRef<int64_t>{numSpectrogramBinsInt, 1}, f32Ty);
Value zeroToNBase = b.create<Torch::AtenUnsqueezeOp>(
zeroToNBaseType, zeroToNInit, /*dim=*/oneConst);
Type zeroToNumElesType =
inputIntType.getWithSizesAndDtype(shapeNSBxNMB, f32Ty);
Value expandShapeList = b.create<Torch::PrimListConstructOp>(
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
SmallVector<Value>{numSpectrogramBinsItem, numMelBinsItem});
Value zeroToNumEles = b.create<Torch::AtenExpandOp>(
zeroToNumElesType, zeroToNBase, expandShapeList,
/*implicit=*/falseConst);
Torch::ValueTensorType unsqueezeIotaResType =
Torch::ValueTensorType::get(ctx, shapeNSBx1, si32Ty);
Value iota = b.create<Torch::AtenUnsqueezeOp>(
unsqueezeIotaResType, iotaInit, /*dim=*/oneConst);
Type maskType = inputIntType.getWithSizesAndDtype(shape1xNMB, i1Ty);
Value maskLowToCenterZero =
b.create<Torch::AtenEqScalarOp>(maskType, lowToCenter, zeroConst);
Value lowToCenter = b.create<Torch::AtenSubTensorOp>(
unsqueezeBinsResType, centerBins, lowBins, /*alpha=*/oneConst);
Value centerToHigh = b.create<Torch::AtenSubTensorOp>(
unsqueezeBinsResType, highBins, centerBins, /*alpha=*/oneConst);
// L2C computation
Value lowToCenterNoZero = b.create<Torch::AtenWhereScalarSelfOp>(
unsqueezeResType, maskLowToCenterZero, negOneConst, lowToCenter);
Type maskL2CAfterCType =
inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
Value maskL2CAfterC = b.create<Torch::AtenGtTensorOp>(
maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
Type maxLFResTy =
inputIntType.getWithSizesAndDtype(ArrayRef<int64_t>{1}, si32Ty);
Value maxLowerFreq =
b.create<Torch::AtenMaxOp>(maxLFResTy, lowFreqTensor);
Value maxLowerFreqItem =
getItemOp<Torch::IntType>(binder, rewriter, maxLowerFreq);
Value zeroToNumElesL2C = b.create<Torch::AtenWhereScalarSelfOp>(
zeroToNumElesType, maskLowToCenterZero, maxLowerFreqItem,
zeroToNumEles);
Value upslopeDiff = b.create<Torch::AtenSubTensorOp>(
zeroToNumElesType, zeroToNumElesL2C, lowFreqTensor,
/*alpha=*/oneConst);
Type l2cNZFltTy = inputIntType.getWithSizesAndDtype(shape1xNMB, f32Ty);
Value l2cNZFlt = b.create<Torch::AtenToDtypeOp>(
l2cNZFltTy, lowToCenterNoZero, /*dtype=*/float32DTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
Value upslopeL2C0 = b.create<Torch::AtenDivTensorOp>(
zeroToNumElesType, upslopeDiff, l2cNZFlt);
Type maskUpslopeL2C0PosType =
inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
Value maskUpslopeL2C0Pos = b.create<Torch::AtenGtScalarOp>(
maskUpslopeL2C0PosType, upslopeL2C0, zeroConst);
Value upslopeL2C0PosRanged = b.create<Torch::AtenWhereScalarOtherOp>(
zeroToNumElesType, maskUpslopeL2C0Pos, upslopeL2C0, zeroConst);
Value maskIdxL2CAfterCList = b.create<Torch::PrimListConstructOp>(
rewriter.getType<Torch::ListType>(maskL2CAfterC.getType()),
ValueRange{maskL2CAfterC});
Value zeroConstTensor = Torch::createRank0Tensor(
rewriter, binder.getLoc(),
Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), zeroConst);
Value upslopeL2C1 = b.create<Torch::AtenIndexPutOp>(
zeroToNumElesType, upslopeL2C0PosRanged, maskIdxL2CAfterCList,
zeroConstTensor, falseConst);
Value maskIdxL2CZeroList = b.create<Torch::PrimListConstructOp>(
rewriter.getType<Torch::ListType>(maskLowToCenterZero.getType()),
ValueRange{maskLowToCenterZero});
Type centerFreqTensorL2CZeroType =
inputIntType.getWithSizesAndDtype(ArrayRef<int64_t>{-1}, si32Ty);
Value centerFreqTensorL2CZero = b.create<Torch::AtenIndexTensorOp>(
centerFreqTensorL2CZeroType, centerFreqTensor, maskIdxL2CZeroList);
Type maskSqueezeType =
inputIntType.getWithSizesAndDtype(shapeNMB, i1Ty);
Value maskLowToCenterZeroSqueeze = b.create<Torch::AtenSqueezeOp>(
maskSqueezeType, maskLowToCenterZero);
Type maskL2CIntTy = inputIntType.getWithSizesAndDtype(shapeNMB, si32Ty);
Value maskLowToCenterInt = b.create<Torch::AtenToDtypeOp>(
maskL2CIntTy, maskLowToCenterZeroSqueeze, /*dtype=*/int32DTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
Value upslopeOneIdxList = b.create<Torch::PrimListConstructOp>(
rewriter.getType<Torch::ListType>(
centerFreqTensorL2CZero.getType()),
ValueRange{centerFreqTensorL2CZero, maskLowToCenterInt});
Value oneConstTensor = Torch::createRank0Tensor(
rewriter, binder.getLoc(),
Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), oneConst);
Value upslopeL2C = b.create<Torch::AtenIndexPutOp>(
zeroToNumElesType, upslopeL2C1, upslopeOneIdxList, oneConstTensor,
falseConst);
// H2C computation
Value maskCenterToHighZero =
b.create<Torch::AtenEqScalarOp>(maskType, centerToHigh, zeroConst);
Value maskH2CBeforeC = b.create<Torch::AtenLtTensorOp>(
maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
Value centerToHighNoZero = b.create<Torch::AtenWhereScalarSelfOp>(
unsqueezeResType, maskCenterToHighZero, negOneConst, centerToHigh);
Value c2hNZFlt = b.create<Torch::AtenToDtypeOp>(
l2cNZFltTy, centerToHighNoZero, /*dtype=*/float32DTypeConst,
Type scaledType = inputIntType.getWithSizesAndDtype(shape1xNMB, f32Ty);
Value upscaleInit = b.create<Torch::AtenMaximumOp>(
unsqueezeBinsResType, oneConstTensor, lowToCenter);
Value upscale = b.create<Torch::AtenToDtypeOp>(
scaledType, upscaleInit, /*dtype=*/float32DTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
Value zeroToNumElesC2H = b.create<Torch::AtenWhereScalarSelfOp>(
zeroToNumElesType, maskCenterToHighZero, zeroConst, zeroToNumEles);
Value downslopeDiff = b.create<Torch::AtenSubTensorOp>(
zeroToNumElesType, highFreqTensor, zeroToNumElesC2H,
/*alpha=*/oneConst);
Value downslopeC2H0 = b.create<Torch::AtenDivTensorOp>(
zeroToNumElesType, downslopeDiff, c2hNZFlt);
Value maskDownslopeC2H0Pos = b.create<Torch::AtenGtScalarOp>(
maskUpslopeL2C0PosType, downslopeC2H0, zeroConst);
Value downslopeC2H0Pos = b.create<Torch::AtenWhereScalarOtherOp>(
zeroToNumElesType, maskDownslopeC2H0Pos, downslopeC2H0, zeroConst);
Value idxH2CBeforeCList = b.create<Torch::PrimListConstructOp>(
rewriter.getType<Torch::ListType>(maskH2CBeforeC.getType()),
ValueRange{maskH2CBeforeC});
Value downslopeC2H = b.create<Torch::AtenIndexPutOp>(
zeroToNumElesType, downslopeC2H0Pos, idxH2CBeforeCList,
zeroConstTensor, falseConst);
// final result Calculation
Value maskH2CNonZero = b.create<Torch::AtenNeScalarOp>(
maskL2CAfterCType, downslopeC2H, zeroConst);
Value idxH2CNZList = b.create<Torch::PrimListConstructOp>(
rewriter.getType<Torch::ListType>(maskH2CNonZero.getType()),
ValueRange{maskH2CNonZero});
Value upslopeL2CMasked = b.create<Torch::AtenIndexPutOp>(
zeroToNumElesType, upslopeL2C, idxH2CNZList, zeroConstTensor,
falseConst);
Value downscaleInit = b.create<Torch::AtenMaximumOp>(
unsqueezeBinsResType, oneConstTensor, centerToHigh);
Value downscale = b.create<Torch::AtenToDtypeOp>(
scaledType, downscaleInit, /*dtype=*/float32DTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
Value slopesFinal = b.create<Torch::AtenAddTensorOp>(
zeroToNumElesType, upslopeL2CMasked, downslopeC2H,
/*alpha=*/oneConst);
Torch::ValueTensorType binsDiffType =
Torch::ValueTensorType::get(ctx, shapeNSBxNMB, si32Ty);
Torch::ValueTensorType diffFloatType =
Torch::ValueTensorType::get(ctx, shapeNSBxNMB, f32Ty);
Value iotaSubLBInt = b.create<Torch::AtenSubTensorOp>(
binsDiffType, iota, lowBins, /*alpha=*/oneConst);
Value iotaSubLB = b.create<Torch::AtenToDtypeOp>(
diffFloatType, iotaSubLBInt, /*dtype=*/float32DTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
Value rampUp =
b.create<Torch::AtenDivTensorOp>(diffFloatType, iotaSubLB, upscale);
Value hbSubIotaInt = b.create<Torch::AtenSubTensorOp>(
binsDiffType, highBins, iota, /*alpha=*/oneConst);
Value hbSubIota = b.create<Torch::AtenToDtypeOp>(
diffFloatType, hbSubIotaInt, /*dtype=*/float32DTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
Value rampDown = b.create<Torch::AtenDivTensorOp>(diffFloatType,
hbSubIota, downscale);
// ramp values
Type iotaCmpBinsType =
inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
// Iota Cmp Bins
Value iotaGtEqCBins =
b.create<Torch::AtenGeTensorOp>(iotaCmpBinsType, iota, centerBins);
Value iotaEqCBins =
b.create<Torch::AtenEqTensorOp>(iotaCmpBinsType, iota, centerBins);
Value iotaLtLBins =
b.create<Torch::AtenLtTensorOp>(iotaCmpBinsType, iota, lowBins);
Value iotaGtLBins =
b.create<Torch::AtenGtTensorOp>(iotaCmpBinsType, iota, highBins);
// Create output freq ramps Low-Center-High
Type rampInitType =
inputIntType.getWithSizesAndDtype(shapeNSBxNMB, f32Ty);
Value rampInit = b.create<Torch::AtenWhereSelfOp>(
rampInitType, iotaGtEqCBins, rampDown, rampUp);
Value rampInitLt = b.create<Torch::AtenWhereScalarSelfOp>(
rampInitType, iotaLtLBins, zeroConst, rampInit);
Value rampInitLtGt = b.create<Torch::AtenWhereScalarSelfOp>(
rampInitType, iotaGtLBins, zeroConst, rampInitLt);
Type C2HCmpBinsType =
inputIntType.getWithSizesAndDtype(shape1xNMB, i1Ty);
Value C2HEqZero = b.create<Torch::AtenEqScalarOp>(
C2HCmpBinsType, centerToHigh, zeroConst);
Value cornerCases = b.create<Torch::AtenLogicalAndOp>(
iotaCmpBinsType, iotaEqCBins, C2HEqZero);
Value rampOutput = b.create<Torch::AtenWhereScalarSelfOp>(
rampInitType, cornerCases, oneFltConst, rampInitLtGt);
Value outputDTypeConst = b.create<Torch::ConstantIntOp>(
rewriter.getType<Torch::IntType>(),
rewriter.getI64IntegerAttr(torchDTypeInt.value()));
Value finalOutput = b.create<Torch::AtenToDtypeOp>(
resultType, slopesFinal, /*dtype=*/outputDTypeConst,
resultType, rampOutput, /*dtype=*/outputDTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);

View File

@ -1974,113 +1974,118 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
// CHECK-LABEL: func.func @test_mwm
func.func @test_mwm(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>, %arg3: !torch.vtensor<[],f32>, %arg4: !torch.vtensor<[],f32>) -> !torch.vtensor<[9,8],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "test_mwm", torch.onnx_meta.producer_version = ""} {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>, %[[VAL_1:.*]]: !torch.vtensor<[],si64>, %[[VAL_2:.*]]: !torch.vtensor<[],si64>,
// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>,
// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[],f32>
// CHECK-SAME: %[[NUM_MEL_BINS_ARG:.*]]: !torch.vtensor<[],si64>, %[[DFT_LENGTH_ARG:.*]]: !torch.vtensor<[],si64>, %[[SAMPLE_RATE_ARG:.*]]: !torch.vtensor<[],si64>,
// CHECK-SAME: %[[LOWER_EDGE_HZ_ARG:.*]]: !torch.vtensor<[],f32>,
// CHECK-SAME: %[[UPPER_EDGE_HZ_ARG:.*]]: !torch.vtensor<[],f32>
// CHECK: %[[VAL_5:.*]] = torch.constant.none
// CHECK: %[[VAL_6:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_7:.*]] = torch.aten.item %[[VAL_1]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_8:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_9:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[VAL_10:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[VAL_11:.*]] = torch.constant.none
// CHECK: %[[VAL_12:.*]] = torch.constant.int -2
// CHECK: %[[VAL_13:.*]] = torch.constant.int -1
// CHECK: %[[VAL_14:.*]] = torch.constant.int 0
// CHECK: %[[VAL_15:.*]] = torch.constant.int 1
// CHECK: %[[VAL_16:.*]] = torch.constant.int 2
// CHECK: %[[VAL_17:.*]] = torch.constant.int 6
// CHECK: %[[VAL_18:.*]] = torch.aten.div.int %[[VAL_7]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.float
// CHECK: %[[VAL_19:.*]] = torch.aten.Int.float %[[VAL_18]] : !torch.float -> !torch.int
// CHECK: %[[VAL_20:.*]] = torch.aten.add.int %[[VAL_19]], %[[VAL_15]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAL_21:.*]] = torch.aten.arange %[[VAL_6]], %[[VAL_17]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],si32>
// CHECK: %[[VAL_22:.*]] = torch.constant.float 2.595000e+03
// CHECK: %[[VAL_23:.*]] = torch.constant.float 7.000000e+02
// CHECK: %[[VAL_24:.*]] = torch.constant.float 1.000000e+01
// CHECK: %[[VAL_25:.*]] = torch.aten.div.float %[[VAL_9]], %[[VAL_23]] : !torch.float, !torch.float -> !torch.float
// CHECK: %[[VAL_26:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_25]] : !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_27:.*]] = torch.aten.add.Scalar %[[VAL_26]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_28:.*]] = torch.aten.log10 %[[VAL_27]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_29:.*]] = torch.aten.mul.Scalar %[[VAL_28]], %[[VAL_22]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_30:.*]] = torch.aten.div.float %[[VAL_10]], %[[VAL_23]] : !torch.float, !torch.float -> !torch.float
// CHECK: %[[NUM_MEL_BINS_ITEM:.*]] = torch.aten.item %[[NUM_MEL_BINS_ARG]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[SAMPLE_RATE_ITEM:.*]] = torch.aten.item %[[SAMPLE_RATE_ARG]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[LOWER_EDGE_HZ_ITEM:.*]] = torch.aten.item %[[LOWER_EDGE_HZ_ARG]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[UPPER_EDGE_HZ_ITEM:.*]] = torch.aten.item %[[UPPER_EDGE_HZ_ARG]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[VAL_10:.*]] = torch.constant.none
// CHECK: %[[VAL_11:.*]] = torch.constant.int 0
// CHECK: %[[VAL_12:.*]] = torch.constant.int 1
// CHECK: %[[VAL_13:.*]] = torch.constant.int 2
// CHECK: %[[VAL_14:.*]] = torch.constant.int 3
// CHECK: %[[VAL_15:.*]] = torch.constant.int 6
// CHECK: %[[VAL_16:.*]] = torch.aten.floor_divide.Scalar %[[DFT_LENGTH_ARG]], %[[VAL_13]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[VAL_17:.*]] = torch.aten.add.Scalar %[[VAL_16]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[NUM_SPECTROGRAM_BINS_ITEM:.*]] = torch.aten.item %[[VAL_17]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_19:.*]] = torch.constant.float 2.595000e+03
// CHECK: %[[VAL_20:.*]] = torch.constant.float 7.000000e+02
// CHECK: %[[VAL_21:.*]] = torch.constant.float 1.000000e+01
// CHECK: %[[VAL_22:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[CONST_LN_TO_LOG10:.*]] = torch.constant.float 0.43429448190325182
// CHECK: %[[VAL_24:.*]] = torch.aten.div.float %[[LOWER_EDGE_HZ_ITEM]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.float
// CHECK: %[[VAL_25:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_24]] : !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_26:.*]] = torch.aten.add.Scalar %[[VAL_25]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_27:.*]] = torch.aten.log %[[VAL_26]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_28:.*]] = torch.aten.mul.Scalar %[[VAL_27]], %[[CONST_LN_TO_LOG10]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[LOW_FREQ_MEL:.*]] = torch.aten.mul.Scalar %[[VAL_28]], %[[VAL_19]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_30:.*]] = torch.aten.div.float %[[UPPER_EDGE_HZ_ITEM]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.float
// CHECK: %[[VAL_31:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_30]] : !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_32:.*]] = torch.aten.add.Scalar %[[VAL_31]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_33:.*]] = torch.aten.log10 %[[VAL_32]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_34:.*]] = torch.aten.mul.Scalar %[[VAL_33]], %[[VAL_22]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_35:.*]] = torch.aten.sub.Tensor %[[VAL_34]], %[[VAL_29]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_36:.*]] = torch.aten.div.Scalar %[[VAL_35]], %[[VAL_6]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_37:.*]] = torch.aten.mul.Tensor %[[VAL_21]], %[[VAL_36]] : !torch.vtensor<[10],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_38:.*]] = torch.aten.add.Tensor %[[VAL_37]], %[[VAL_29]], %[[VAL_15]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_39:.*]] = torch.aten.div.Scalar %[[VAL_38]], %[[VAL_22]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_40:.*]] = torch.aten.clone %[[VAL_38]], %[[VAL_11]] : !torch.vtensor<[10],f32>, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_41:.*]] = torch.aten.fill.Scalar %[[VAL_40]], %[[VAL_24]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_42:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_41]], %[[VAL_39]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_43:.*]] = torch.aten.sub.Scalar %[[VAL_42]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[10],f32>, !torch.int, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_44:.*]] = torch.aten.mul.Scalar %[[VAL_43]], %[[VAL_23]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_45:.*]] = torch.aten.add.Scalar %[[VAL_1]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[VAL_46:.*]] = torch.aten.item %[[VAL_45]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_47:.*]] = torch.aten.mul.Scalar %[[VAL_44]], %[[VAL_46]] : !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_48:.*]] = torch.aten.div.Scalar %[[VAL_47]], %[[VAL_8]] : !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_49:.*]] = torch.constant.int 3
// CHECK: %[[VAL_50:.*]] = torch.constant.bool false
// CHECK: %[[VAL_51:.*]] = torch.aten.to.dtype %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],si32>
// CHECK: %[[VAL_52:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_14]], %[[VAL_12]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_53:.*]] = torch.aten.unsqueeze %[[VAL_52]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_54:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_15]], %[[VAL_13]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_55:.*]] = torch.aten.unsqueeze %[[VAL_54]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_56:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_14]], %[[VAL_11]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_57:.*]] = torch.aten.unsqueeze %[[VAL_56]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_58:.*]] = torch.aten.sub.Tensor %[[VAL_55]], %[[VAL_53]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_59:.*]] = torch.aten.sub.Tensor %[[VAL_57]], %[[VAL_55]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_60:.*]] = torch.aten.arange %[[VAL_20]], %[[VAL_17]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[9],f32>
// CHECK: %[[VAL_61:.*]] = torch.aten.unsqueeze %[[VAL_60]], %[[VAL_15]] : !torch.vtensor<[9],f32>, !torch.int -> !torch.vtensor<[9,1],f32>
// CHECK: %[[VAL_62:.*]] = torch.prim.ListConstruct %[[VAL_20]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_63:.*]] = torch.aten.expand %[[VAL_61]], %[[VAL_62]], %[[VAL_50]] : !torch.vtensor<[9,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_64:.*]] = torch.aten.eq.Scalar %[[VAL_58]], %[[VAL_14]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1>
// CHECK: %[[VAL_65:.*]] = torch.aten.where.ScalarSelf %[[VAL_64]], %[[VAL_13]], %[[VAL_58]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_66:.*]] = torch.aten.gt.Tensor %[[VAL_63]], %[[VAL_55]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_67:.*]] = torch.aten.max %[[VAL_53]] : !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1],si32>
// CHECK: %[[VAL_68:.*]] = torch.aten.item %[[VAL_67]] : !torch.vtensor<[1],si32> -> !torch.int
// CHECK: %[[VAL_69:.*]] = torch.aten.where.ScalarSelf %[[VAL_64]], %[[VAL_68]], %[[VAL_63]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_70:.*]] = torch.aten.sub.Tensor %[[VAL_69]], %[[VAL_53]], %[[VAL_15]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_71:.*]] = torch.aten.to.dtype %[[VAL_65]], %[[VAL_17]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
// CHECK: %[[VAL_72:.*]] = torch.aten.div.Tensor %[[VAL_70]], %[[VAL_71]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_73:.*]] = torch.aten.gt.Scalar %[[VAL_72]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_74:.*]] = torch.aten.where.ScalarOther %[[VAL_73]], %[[VAL_72]], %[[VAL_14]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_75:.*]] = torch.prim.ListConstruct %[[VAL_66]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<vtensor<[9,8],i1>>
// CHECK: %[[VAL_76:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[VAL_77:.*]] = torch.constant.none
// CHECK: %[[VAL_78:.*]] = torch.constant.int 6
// CHECK: %[[VAL_79:.*]] = torch.aten.full %[[VAL_76]], %[[VAL_14]], %[[VAL_78]], %[[VAL_77]], %[[VAL_77]], %[[VAL_77]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_80:.*]] = torch.aten.index_put %[[VAL_74]], %[[VAL_75]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[9,8],i1>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_81:.*]] = torch.prim.ListConstruct %[[VAL_64]] : (!torch.vtensor<[1,8],i1>) -> !torch.list<vtensor<[1,8],i1>>
// CHECK: %[[VAL_82:.*]] = torch.aten.index.Tensor %[[VAL_55]], %[[VAL_81]] : !torch.vtensor<[1,8],si32>, !torch.list<vtensor<[1,8],i1>> -> !torch.vtensor<[?],si32>
// CHECK: %[[VAL_83:.*]] = torch.aten.squeeze %[[VAL_64]] : !torch.vtensor<[1,8],i1> -> !torch.vtensor<[8],i1>
// CHECK: %[[VAL_84:.*]] = torch.aten.to.dtype %[[VAL_83]], %[[VAL_49]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[8],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_85:.*]] = torch.prim.ListConstruct %[[VAL_82]], %[[VAL_84]] : (!torch.vtensor<[?],si32>, !torch.vtensor<[8],si32>) -> !torch.list<vtensor<[?],si32>>
// CHECK: %[[VAL_86:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[VAL_87:.*]] = torch.constant.none
// CHECK: %[[VAL_88:.*]] = torch.constant.int 6
// CHECK: %[[VAL_89:.*]] = torch.aten.full %[[VAL_86]], %[[VAL_15]], %[[VAL_88]], %[[VAL_87]], %[[VAL_87]], %[[VAL_87]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_90:.*]] = torch.aten.index_put %[[VAL_80]], %[[VAL_85]], %[[VAL_89]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[?],si32>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_91:.*]] = torch.aten.eq.Scalar %[[VAL_59]], %[[VAL_14]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1>
// CHECK: %[[VAL_92:.*]] = torch.aten.lt.Tensor %[[VAL_63]], %[[VAL_55]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_93:.*]] = torch.aten.where.ScalarSelf %[[VAL_91]], %[[VAL_13]], %[[VAL_59]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_94:.*]] = torch.aten.to.dtype %[[VAL_93]], %[[VAL_17]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
// CHECK: %[[VAL_95:.*]] = torch.aten.where.ScalarSelf %[[VAL_91]], %[[VAL_14]], %[[VAL_63]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_96:.*]] = torch.aten.sub.Tensor %[[VAL_57]], %[[VAL_95]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_97:.*]] = torch.aten.div.Tensor %[[VAL_96]], %[[VAL_94]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_98:.*]] = torch.aten.gt.Scalar %[[VAL_97]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_99:.*]] = torch.aten.where.ScalarOther %[[VAL_98]], %[[VAL_97]], %[[VAL_14]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_100:.*]] = torch.prim.ListConstruct %[[VAL_92]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<vtensor<[9,8],i1>>
// CHECK: %[[VAL_101:.*]] = torch.aten.index_put %[[VAL_99]], %[[VAL_100]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[9,8],i1>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_102:.*]] = torch.aten.ne.Scalar %[[VAL_101]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_103:.*]] = torch.prim.ListConstruct %[[VAL_102]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<vtensor<[9,8],i1>>
// CHECK: %[[VAL_104:.*]] = torch.aten.index_put %[[VAL_90]], %[[VAL_103]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[9,8],i1>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_105:.*]] = torch.aten.add.Tensor %[[VAL_104]], %[[VAL_101]], %[[VAL_15]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_106:.*]] = torch.constant.int 6
// CHECK: %[[VAL_107:.*]] = torch.aten.to.dtype %[[VAL_105]], %[[VAL_106]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[9,8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
// CHECK: return %[[VAL_107]] : !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_32:.*]] = torch.aten.add.Scalar %[[VAL_31]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_33:.*]] = torch.aten.log %[[VAL_32]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_34:.*]] = torch.aten.mul.Scalar %[[VAL_33]], %[[CONST_LN_TO_LOG10]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[HIGH_FREQ_MEL:.*]] = torch.aten.mul.Scalar %[[VAL_34]], %[[VAL_19]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_36:.*]] = torch.aten.sub.Tensor %[[HIGH_FREQ_MEL]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_37:.*]] = torch.aten.add.int %[[NUM_MEL_BINS_ITEM]], %[[VAL_13]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEL_STEP:.*]] = torch.aten.div.Scalar %[[VAL_36]], %[[VAL_37]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[LOW_BINS_INIT:.*]] = torch.aten.arange %[[NUM_MEL_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[8],si32>
// CHECK: %[[CENTER_BINS_INIT:.*]] = torch.aten.arange %[[NUM_MEL_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[8],si32>
// CHECK: %[[HIGH_BINS_INIT:.*]] = torch.aten.arange %[[NUM_MEL_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_42:.*]] = torch.aten.add.Scalar %[[DFT_LENGTH_ARG]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[VAL_43:.*]] = torch.aten.item %[[VAL_42]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_44:.*]] = torch.constant.bool false
// CHECK: %[[VAL_45:.*]] = torch.aten.mul.Tensor %[[LOW_BINS_INIT]], %[[MEL_STEP]] : !torch.vtensor<[8],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_46:.*]] = torch.aten.add.Tensor %[[VAL_45]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_47:.*]] = torch.aten.div.Scalar %[[VAL_46]], %[[VAL_19]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_48:.*]] = torch.aten.clone %[[VAL_46]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.none -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_49:.*]] = torch.aten.fill.Scalar %[[VAL_48]], %[[VAL_21]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_50:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_49]], %[[VAL_47]] : !torch.vtensor<[8],f32>, !torch.vtensor<[8],f32> -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_51:.*]] = torch.aten.sub.Scalar %[[VAL_50]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_52:.*]] = torch.aten.mul.Scalar %[[VAL_51]], %[[VAL_20]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_53:.*]] = torch.aten.mul.Scalar %[[VAL_52]], %[[VAL_43]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_54:.*]] = torch.aten.div.Scalar %[[VAL_53]], %[[SAMPLE_RATE_ITEM]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_55:.*]] = torch.aten.to.dtype %[[VAL_54]], %[[VAL_14]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
// CHECK: %[[LOW_BINS:.*]] = torch.aten.unsqueeze %[[VAL_55]], %[[VAL_11]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_57:.*]] = torch.aten.add.Scalar %[[CENTER_BINS_INIT]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],si32>, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_58:.*]] = torch.aten.mul.Tensor %[[VAL_57]], %[[MEL_STEP]] : !torch.vtensor<[8],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_59:.*]] = torch.aten.add.Tensor %[[VAL_58]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_60:.*]] = torch.aten.div.Scalar %[[VAL_59]], %[[VAL_19]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_61:.*]] = torch.aten.clone %[[VAL_59]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.none -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_62:.*]] = torch.aten.fill.Scalar %[[VAL_61]], %[[VAL_21]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_63:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_62]], %[[VAL_60]] : !torch.vtensor<[8],f32>, !torch.vtensor<[8],f32> -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_64:.*]] = torch.aten.sub.Scalar %[[VAL_63]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_65:.*]] = torch.aten.mul.Scalar %[[VAL_64]], %[[VAL_20]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_66:.*]] = torch.aten.mul.Scalar %[[VAL_65]], %[[VAL_43]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_67:.*]] = torch.aten.div.Scalar %[[VAL_66]], %[[SAMPLE_RATE_ITEM]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_68:.*]] = torch.aten.to.dtype %[[VAL_67]], %[[VAL_14]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
// CHECK: %[[CENTER_BINS:.*]] = torch.aten.unsqueeze %[[VAL_68]], %[[VAL_11]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_70:.*]] = torch.aten.add.Scalar %[[HIGH_BINS_INIT]], %[[VAL_13]], %[[VAL_12]] : !torch.vtensor<[8],si32>, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_71:.*]] = torch.aten.mul.Tensor %[[VAL_70]], %[[MEL_STEP]] : !torch.vtensor<[8],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_72:.*]] = torch.aten.add.Tensor %[[VAL_71]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_73:.*]] = torch.aten.div.Scalar %[[VAL_72]], %[[VAL_19]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_74:.*]] = torch.aten.clone %[[VAL_72]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.none -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_75:.*]] = torch.aten.fill.Scalar %[[VAL_74]], %[[VAL_21]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_76:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_75]], %[[VAL_73]] : !torch.vtensor<[8],f32>, !torch.vtensor<[8],f32> -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_77:.*]] = torch.aten.sub.Scalar %[[VAL_76]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_78:.*]] = torch.aten.mul.Scalar %[[VAL_77]], %[[VAL_20]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_79:.*]] = torch.aten.mul.Scalar %[[VAL_78]], %[[VAL_43]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_80:.*]] = torch.aten.div.Scalar %[[VAL_79]], %[[SAMPLE_RATE_ITEM]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
// CHECK: %[[VAL_81:.*]] = torch.aten.to.dtype %[[VAL_80]], %[[VAL_14]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
// CHECK: %[[HIGH_BINS:.*]] = torch.aten.unsqueeze %[[VAL_81]], %[[VAL_11]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[IOTA_INIT:.*]] = torch.aten.arange %[[NUM_SPECTROGRAM_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[9],si32>
// CHECK: %[[IOTA:.*]] = torch.aten.unsqueeze %[[IOTA_INIT]], %[[VAL_12]] : !torch.vtensor<[9],si32>, !torch.int -> !torch.vtensor<[9,1],si32>
// CHECK: %[[LOW_TO_CENTER:.*]] = torch.aten.sub.Tensor %[[CENTER_BINS]], %[[LOW_BINS]], %[[VAL_12]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[CENTER_TO_HIGH:.*]] = torch.aten.sub.Tensor %[[HIGH_BINS]], %[[CENTER_BINS]], %[[VAL_12]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_87:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[VAL_88:.*]] = torch.constant.none
// CHECK: %[[VAL_89:.*]] = torch.constant.int 6
// CHECK: %[[VAL_90:.*]] = torch.aten.full %[[VAL_87]], %[[VAL_12]], %[[VAL_89]], %[[VAL_88]], %[[VAL_88]], %[[VAL_88]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_91:.*]] = torch.aten.maximum %[[VAL_90]], %[[LOW_TO_CENTER]] : !torch.vtensor<[],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
// CHECK: %[[UP_SCALE:.*]] = torch.aten.to.dtype %[[VAL_91]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
// CHECK: %[[VAL_93:.*]] = torch.aten.maximum %[[VAL_90]], %[[CENTER_TO_HIGH]] : !torch.vtensor<[],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
// CHECK: %[[DOWN_SCALE:.*]] = torch.aten.to.dtype %[[VAL_93]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
// CHECK: %[[VAL_95:.*]] = torch.aten.sub.Tensor %[[IOTA]], %[[LOW_BINS]], %[[VAL_12]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[9,8],si32>
// CHECK: %[[VAL_96:.*]] = torch.aten.to.dtype %[[VAL_95]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[9,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
// CHECK: %[[RAMP_UP:.*]] = torch.aten.div.Tensor %[[VAL_96]], %[[UP_SCALE]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_98:.*]] = torch.aten.sub.Tensor %[[HIGH_BINS]], %[[IOTA]], %[[VAL_12]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[9,1],si32>, !torch.int -> !torch.vtensor<[9,8],si32>
// CHECK: %[[VAL_99:.*]] = torch.aten.to.dtype %[[VAL_98]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[9,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
// CHECK: %[[RAMP_DOWN:.*]] = torch.aten.div.Tensor %[[VAL_99]], %[[DOWN_SCALE]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_101:.*]] = torch.aten.ge.Tensor %[[IOTA]], %[[CENTER_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_102:.*]] = torch.aten.eq.Tensor %[[IOTA]], %[[CENTER_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_103:.*]] = torch.aten.lt.Tensor %[[IOTA]], %[[LOW_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_104:.*]] = torch.aten.gt.Tensor %[[IOTA]], %[[HIGH_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
// CHECK: %[[RAMP_INIT:.*]] = torch.aten.where.self %[[VAL_101]], %[[RAMP_DOWN]], %[[RAMP_UP]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_106:.*]] = torch.aten.where.ScalarSelf %[[VAL_103]], %[[VAL_11]], %[[RAMP_INIT]] : !torch.vtensor<[9,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_107:.*]] = torch.aten.where.ScalarSelf %[[VAL_104]], %[[VAL_11]], %[[VAL_106]] : !torch.vtensor<[9,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_108:.*]] = torch.aten.eq.Scalar %[[CENTER_TO_HIGH]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1>
// CHECK: %[[CORNER_CASES:.*]] = torch.aten.logical_and %[[VAL_102]], %[[VAL_108]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[1,8],i1> -> !torch.vtensor<[9,8],i1>
// CHECK: %[[RAMP:.*]] = torch.aten.where.ScalarSelf %[[CORNER_CASES]], %[[VAL_22]], %[[VAL_107]] : !torch.vtensor<[9,8],i1>, !torch.float, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_111:.*]] = torch.constant.int 6
// CHECK: %[[OUTPUT:.*]] = torch.aten.to.dtype %[[RAMP]], %[[VAL_111]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[9,8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
// CHECK: return %[[OUTPUT]] : !torch.vtensor<[9,8],f32>
%none = torch.constant.none
%0 = torch.operator "onnx.MelWeightMatrix"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[9,8],f32>
return %0 : !torch.vtensor<[9,8],f32>