mirror of https://github.com/llvm/torch-mlir
onnx.MelWeightMatrix Onnx to Torch to Linalg (#3659)
- This PR adds new (and equivalent) more tensorized impl of MelWeightMatrix which lowers all the way to linalg. - [Ref Pytorch Impl](https://gist.github.com/PhaneeshB/4e6dfcded3007b1b686fbe28f07a67cd) - Thanks to @rsuderman for pointing out the difficulties [earlier impl](#3503) posed during lowering to linalg and also for providing a better numpy impl 🙏pull/3636/merge
parent
fcc5f444cd
commit
9a6fe58a02
|
@ -640,8 +640,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
|
||||
Value numMelBinsItem =
|
||||
getItemOp<Torch::IntType>(binder, rewriter, operands[0]);
|
||||
Value dftLengthItem =
|
||||
getItemOp<Torch::IntType>(binder, rewriter, operands[1]);
|
||||
Value sampleRateItem =
|
||||
getItemOp<Torch::IntType>(binder, rewriter, operands[2]);
|
||||
Value lowerEdgeHzItem =
|
||||
|
@ -656,9 +654,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
// Recurring shapes
|
||||
SmallVector<int64_t> unranked({});
|
||||
SmallVector<int64_t> shapeNMB({numMelBinsInt});
|
||||
SmallVector<int64_t> shapeNMBp2({numMelBinsInt + 2});
|
||||
SmallVector<int64_t> shape1xNMB({1, numMelBinsInt});
|
||||
SmallVector<int64_t> shapeNSB({numSpectrogramBinsInt});
|
||||
SmallVector<int64_t> shapeNSBx1({numSpectrogramBinsInt, 1});
|
||||
SmallVector<int64_t> shapeNSBxNMB(
|
||||
{numSpectrogramBinsInt, numMelBinsInt});
|
||||
|
||||
|
@ -671,37 +669,30 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
|
||||
// Value constants
|
||||
Value noneConst = b.create<Torch::ConstantNoneOp>();
|
||||
Value negTwoConst =
|
||||
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(-2));
|
||||
Value negOneConst =
|
||||
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(-1));
|
||||
Value zeroConst =
|
||||
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(0));
|
||||
Value oneConst =
|
||||
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(1));
|
||||
Value twoConst =
|
||||
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(2));
|
||||
Value int32DTypeConst =
|
||||
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(3));
|
||||
Value float32DTypeConst =
|
||||
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(6));
|
||||
|
||||
Torch::ValueTensorType dftLenType =
|
||||
Torch::ValueTensorType::get(ctx, unranked, inpIntDType);
|
||||
Type freqBinsIntType =
|
||||
Torch::ValueTensorType::get(ctx, shapeNMBp2, si32Ty);
|
||||
Torch::ValueTensorType::get(ctx, shapeNMB, si32Ty);
|
||||
Type freqBinsFltType =
|
||||
Torch::ValueTensorType::get(ctx, shapeNMBp2, f32Ty);
|
||||
Torch::ValueTensorType::get(ctx, shapeNMB, f32Ty);
|
||||
|
||||
Value dftLengthDivTwoFlt =
|
||||
b.create<Torch::AtenDivIntOp>(dftLengthItem, twoConst);
|
||||
Value dftLengthDivTwo =
|
||||
b.create<Torch::AtenIntFloatOp>(dftLengthDivTwoFlt);
|
||||
Value numSpectrogramBins =
|
||||
b.create<Torch::AtenAddIntOp>(dftLengthDivTwo, oneConst);
|
||||
Value numSpectrogramBinsItem = numSpectrogramBins;
|
||||
Value freqBinsInit = b.create<Torch::AtenArangeOp>(
|
||||
freqBinsIntType, numMelBinsItem, /*dtype=*/float32DTypeConst,
|
||||
/*layout=*/noneConst, /*device=*/noneConst,
|
||||
/*pin_memory=*/noneConst);
|
||||
Value dftLengthDivTwoTensor = b.create<Torch::AtenFloorDivideScalarOp>(
|
||||
dftLenType, operands[1], twoConst);
|
||||
Value numSpectrogramBinsTensor = b.create<Torch::AtenAddScalarOp>(
|
||||
dftLenType, dftLengthDivTwoTensor, oneConst, /*alpha =*/oneConst);
|
||||
Value numSpectrogramBinsItem = getItemOp<Torch::IntType>(
|
||||
binder, rewriter, numSpectrogramBinsTensor);
|
||||
|
||||
// From Ref Impl of Onnx.MelWeightMatrix:
|
||||
// https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_mel_weight_matrix.py#L25-L32
|
||||
|
@ -712,6 +703,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(700));
|
||||
Value tenConst =
|
||||
b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(10));
|
||||
Value oneFltConst =
|
||||
b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(1));
|
||||
Value LnToLog10Const = b.create<Torch::ConstantFloatOp>(
|
||||
rewriter.getF64FloatAttr(M_LOG10E));
|
||||
|
||||
Value lfDiv7Hfloat =
|
||||
b.create<Torch::AtenDivFloatOp>(lowerEdgeHzItem, sevenHConst);
|
||||
|
@ -720,8 +715,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
b.create<Torch::PrimNumToTensorScalarOp>(freqType, lfDiv7Hfloat);
|
||||
Value lfDiv7HAdd1 = b.create<Torch::AtenAddScalarOp>(
|
||||
freqType, lfDiv7H, oneConst, /*alpha =*/oneConst);
|
||||
Value lfDiv7HAdd1Log10 =
|
||||
b.create<Torch::AtenLog10Op>(freqType, lfDiv7HAdd1);
|
||||
Value lfDiv7HAdd1Ln = b.create<Torch::AtenLogOp>(freqType, lfDiv7HAdd1);
|
||||
Value lfDiv7HAdd1Log10 = b.create<Torch::AtenMulScalarOp>(
|
||||
freqType, lfDiv7HAdd1Ln, LnToLog10Const);
|
||||
|
||||
Value lfMel = b.create<Torch::AtenMulScalarOp>(
|
||||
freqType, lfDiv7HAdd1Log10, twoFiveNineFiveConst);
|
||||
|
||||
|
@ -731,226 +728,235 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
b.create<Torch::PrimNumToTensorScalarOp>(freqType, hfDiv7Hfloat);
|
||||
Value hfDiv7HAdd1 = b.create<Torch::AtenAddScalarOp>(
|
||||
freqType, hfDiv7H, oneConst, /*alpha =*/oneConst);
|
||||
Value hfDiv7HAdd1Log10 =
|
||||
b.create<Torch::AtenLog10Op>(freqType, hfDiv7HAdd1);
|
||||
Value hfDiv7HAdd1Ln = b.create<Torch::AtenLogOp>(freqType, hfDiv7HAdd1);
|
||||
Value hfDiv7HAdd1Log10 = b.create<Torch::AtenMulScalarOp>(
|
||||
freqType, hfDiv7HAdd1Ln, LnToLog10Const);
|
||||
|
||||
Value hfMel = b.create<Torch::AtenMulScalarOp>(
|
||||
freqType, hfDiv7HAdd1Log10, twoFiveNineFiveConst);
|
||||
|
||||
Value hfSubLf = b.create<Torch::AtenSubTensorOp>(
|
||||
hfMel.getType(), hfMel, lfMel, /*alpha=*/oneConst);
|
||||
Value numMelBinsPlus2 =
|
||||
b.create<Torch::AtenAddIntOp>(numMelBinsItem, twoConst);
|
||||
Value melStep = b.create<Torch::AtenDivScalarOp>(
|
||||
hfSubLf.getType(), hfSubLf, numMelBinsItem);
|
||||
hfSubLf.getType(), hfSubLf, numMelBinsPlus2);
|
||||
|
||||
Value freqBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
|
||||
freqBinsFltType, freqBinsInit, melStep);
|
||||
Value freqBinsScaled = b.create<Torch::AtenAddTensorOp>(
|
||||
freqBinsFltType, freqBinsMulMelStep, lfMel, /*alpha=*/oneConst);
|
||||
Value lowBinsInit = b.create<Torch::AtenArangeOp>(
|
||||
freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst,
|
||||
/*layout=*/noneConst, /*device=*/noneConst,
|
||||
/*pin_memory=*/noneConst);
|
||||
|
||||
// Mel to Hz conv
|
||||
Value centerBinsInit = b.create<Torch::AtenArangeOp>(
|
||||
freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst,
|
||||
/*layout=*/noneConst, /*device=*/noneConst,
|
||||
/*pin_memory=*/noneConst);
|
||||
|
||||
Value fbDiv = b.create<Torch::AtenDivScalarOp>(
|
||||
freqBinsFltType, freqBinsScaled, twoFiveNineFiveConst);
|
||||
Value fbClone = b.create<Torch::AtenCloneOp>(
|
||||
freqBinsFltType, freqBinsScaled, /*memory_format=*/noneConst);
|
||||
Value tenTensor = b.create<Torch::AtenFillScalarOp>(freqBinsFltType,
|
||||
fbClone, tenConst);
|
||||
Value fbPow = b.create<Torch::AtenPowTensorTensorOp>(freqBinsFltType,
|
||||
tenTensor, fbDiv);
|
||||
Value fbPowSubOne = b.create<Torch::AtenSubScalarOp>(
|
||||
freqBinsFltType, fbPow, oneConst, /*alpha=*/oneConst);
|
||||
Value freqBinsHz = b.create<Torch::AtenMulScalarOp>(
|
||||
freqBinsFltType, fbPowSubOne, sevenHConst);
|
||||
Value highBinsInit = b.create<Torch::AtenArangeOp>(
|
||||
freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst,
|
||||
/*layout=*/noneConst, /*device=*/noneConst,
|
||||
/*pin_memory=*/noneConst);
|
||||
|
||||
// Normalize freqBinsHz
|
||||
// Common values used in conversion
|
||||
Value dftLenPlusOne = b.create<Torch::AtenAddScalarOp>(
|
||||
dftLenType, operands[1], oneConst, /*alpha=*/oneConst);
|
||||
Value dftLenPlusOneItem =
|
||||
getItemOp<Torch::IntType>(binder, rewriter, dftLenPlusOne);
|
||||
Value fbMulDft = b.create<Torch::AtenMulScalarOp>(
|
||||
freqBinsFltType, freqBinsHz, dftLenPlusOneItem);
|
||||
Value freqBinsNormalized = b.create<Torch::AtenDivScalarOp>(
|
||||
freqBinsFltType, fbMulDft, sampleRateItem);
|
||||
|
||||
// cast to int32
|
||||
Value int32DTypeConst =
|
||||
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(3));
|
||||
Value falseConst = b.create<Torch::ConstantBoolOp>(false);
|
||||
Value freqBins = b.create<Torch::AtenToDtypeOp>(
|
||||
freqBinsIntType, freqBinsNormalized, /*dtype=*/int32DTypeConst,
|
||||
Torch::ValueTensorType unsqueezeBinsResType =
|
||||
Torch::ValueTensorType::get(ctx, shape1xNMB, si32Ty);
|
||||
|
||||
// Low bins Mel to hz
|
||||
Value lowBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
|
||||
freqBinsFltType, lowBinsInit, melStep);
|
||||
Value lowBinsScaled = b.create<Torch::AtenAddTensorOp>(
|
||||
freqBinsFltType, lowBinsMulMelStep, lfMel, /*alpha=*/oneConst);
|
||||
Value lbDiv = b.create<Torch::AtenDivScalarOp>(
|
||||
freqBinsFltType, lowBinsScaled, twoFiveNineFiveConst);
|
||||
Value lbClone = b.create<Torch::AtenCloneOp>(
|
||||
freqBinsFltType, lowBinsScaled, /*memory_format=*/noneConst);
|
||||
Value lbTenTensor = b.create<Torch::AtenFillScalarOp>(
|
||||
freqBinsFltType, lbClone, tenConst);
|
||||
Value lbPow = b.create<Torch::AtenPowTensorTensorOp>(
|
||||
freqBinsFltType, lbTenTensor, lbDiv);
|
||||
Value lbPowSubOne = b.create<Torch::AtenSubScalarOp>(
|
||||
freqBinsFltType, lbPow, oneConst, /*alpha=*/oneConst);
|
||||
Value lowBinsHz = b.create<Torch::AtenMulScalarOp>(
|
||||
freqBinsFltType, lbPowSubOne, sevenHConst);
|
||||
// Normalize freqBinsHz
|
||||
Value lbMulDft = b.create<Torch::AtenMulScalarOp>(
|
||||
freqBinsFltType, lowBinsHz, dftLenPlusOneItem);
|
||||
Value lowBinsNormalized = b.create<Torch::AtenDivScalarOp>(
|
||||
freqBinsFltType, lbMulDft, sampleRateItem);
|
||||
// cast to int32
|
||||
Value lowBinsInt = b.create<Torch::AtenToDtypeOp>(
|
||||
freqBinsIntType, lowBinsNormalized, /*dtype=*/int32DTypeConst,
|
||||
/*non_blocking=*/falseConst, /*copy=*/falseConst,
|
||||
/*memory_format=*/noneConst);
|
||||
Value lowBins = b.create<Torch::AtenUnsqueezeOp>(
|
||||
unsqueezeBinsResType, lowBinsInt, /*dim=*/zeroConst);
|
||||
|
||||
Torch::ValueTensorType sliceResType =
|
||||
Torch::ValueTensorType::get(ctx, shapeNMB, si32Ty);
|
||||
Type unsqueezeResType =
|
||||
sliceResType.getWithSizesAndDtype(shape1xNMB, si32Ty);
|
||||
Value lfTensor = b.create<Torch::AtenSliceTensorOp>(
|
||||
sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/zeroConst,
|
||||
/*end=*/negTwoConst, /*step=*/oneConst);
|
||||
Value lowFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
|
||||
unsqueezeResType, lfTensor, /*dim=*/zeroConst);
|
||||
// Center bins mel to hz
|
||||
Value centerBinsInitInc = b.create<Torch::AtenAddScalarOp>(
|
||||
freqBinsIntType, centerBinsInit, oneConst, /*alpha=*/oneConst);
|
||||
Value centerBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
|
||||
freqBinsFltType, centerBinsInitInc, melStep);
|
||||
Value centerBinsScaled = b.create<Torch::AtenAddTensorOp>(
|
||||
freqBinsFltType, centerBinsMulMelStep, lfMel, /*alpha=*/oneConst);
|
||||
Value cbDiv = b.create<Torch::AtenDivScalarOp>(
|
||||
freqBinsFltType, centerBinsScaled, twoFiveNineFiveConst);
|
||||
Value cbClone = b.create<Torch::AtenCloneOp>(
|
||||
freqBinsFltType, centerBinsScaled, /*memory_format=*/noneConst);
|
||||
Value cbTenTensor = b.create<Torch::AtenFillScalarOp>(
|
||||
freqBinsFltType, cbClone, tenConst);
|
||||
Value cbPow = b.create<Torch::AtenPowTensorTensorOp>(
|
||||
freqBinsFltType, cbTenTensor, cbDiv);
|
||||
Value cbPowSubOne = b.create<Torch::AtenSubScalarOp>(
|
||||
freqBinsFltType, cbPow, oneConst, /*alpha=*/oneConst);
|
||||
Value centerBinsHz = b.create<Torch::AtenMulScalarOp>(
|
||||
freqBinsFltType, cbPowSubOne, sevenHConst);
|
||||
// Normalize freqBinsHz
|
||||
Value cbMulDft = b.create<Torch::AtenMulScalarOp>(
|
||||
freqBinsFltType, centerBinsHz, dftLenPlusOneItem);
|
||||
Value centerBinsNormalized = b.create<Torch::AtenDivScalarOp>(
|
||||
freqBinsFltType, cbMulDft, sampleRateItem);
|
||||
// cast to int32
|
||||
Value centerBinsInt = b.create<Torch::AtenToDtypeOp>(
|
||||
freqBinsIntType, centerBinsNormalized, /*dtype=*/int32DTypeConst,
|
||||
/*non_blocking=*/falseConst, /*copy=*/falseConst,
|
||||
/*memory_format=*/noneConst);
|
||||
Value centerBins = b.create<Torch::AtenUnsqueezeOp>(
|
||||
unsqueezeBinsResType, centerBinsInt, /*dim=*/zeroConst);
|
||||
|
||||
Value cfTensor = b.create<Torch::AtenSliceTensorOp>(
|
||||
sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/oneConst,
|
||||
/*end=*/negOneConst, /*step=*/oneConst);
|
||||
Value centerFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
|
||||
unsqueezeResType, cfTensor, /*dim=*/zeroConst);
|
||||
// High bins mel to hz
|
||||
Value highBinsInitInc = b.create<Torch::AtenAddScalarOp>(
|
||||
freqBinsIntType, highBinsInit, twoConst, /*alpha=*/oneConst);
|
||||
Value highBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
|
||||
freqBinsFltType, highBinsInitInc, melStep);
|
||||
Value highBinsScaled = b.create<Torch::AtenAddTensorOp>(
|
||||
freqBinsFltType, highBinsMulMelStep, lfMel, /*alpha=*/oneConst);
|
||||
Value hbDiv = b.create<Torch::AtenDivScalarOp>(
|
||||
freqBinsFltType, highBinsScaled, twoFiveNineFiveConst);
|
||||
Value hbClone = b.create<Torch::AtenCloneOp>(
|
||||
freqBinsFltType, highBinsScaled, /*memory_format=*/noneConst);
|
||||
Value hbTenTensor = b.create<Torch::AtenFillScalarOp>(
|
||||
freqBinsFltType, hbClone, tenConst);
|
||||
Value hbPow = b.create<Torch::AtenPowTensorTensorOp>(
|
||||
freqBinsFltType, hbTenTensor, hbDiv);
|
||||
Value hbPowSubOne = b.create<Torch::AtenSubScalarOp>(
|
||||
freqBinsFltType, hbPow, oneConst, /*alpha=*/oneConst);
|
||||
Value highBinsHz = b.create<Torch::AtenMulScalarOp>(
|
||||
freqBinsFltType, hbPowSubOne, sevenHConst);
|
||||
// Normalize freqBinsHz
|
||||
Value hbMulDft = b.create<Torch::AtenMulScalarOp>(
|
||||
freqBinsFltType, highBinsHz, dftLenPlusOneItem);
|
||||
Value highBinsNormalized = b.create<Torch::AtenDivScalarOp>(
|
||||
freqBinsFltType, hbMulDft, sampleRateItem);
|
||||
// cast to int32
|
||||
Value highBinsInt = b.create<Torch::AtenToDtypeOp>(
|
||||
freqBinsIntType, highBinsNormalized, /*dtype=*/int32DTypeConst,
|
||||
/*non_blocking=*/falseConst, /*copy=*/falseConst,
|
||||
/*memory_format=*/noneConst);
|
||||
Value highBins = b.create<Torch::AtenUnsqueezeOp>(
|
||||
unsqueezeBinsResType, highBinsInt, /*dim=*/zeroConst);
|
||||
|
||||
Value hfTensor = b.create<Torch::AtenSliceTensorOp>(
|
||||
sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/zeroConst,
|
||||
/*end=*/noneConst, /*step=*/oneConst);
|
||||
Value highFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
|
||||
unsqueezeResType, hfTensor, /*dim=*/zeroConst);
|
||||
|
||||
Value lowToCenter =
|
||||
b.create<Torch::AtenSubTensorOp>(unsqueezeResType, centerFreqTensor,
|
||||
lowFreqTensor, /*alpha=*/oneConst);
|
||||
Value centerToHigh = b.create<Torch::AtenSubTensorOp>(
|
||||
unsqueezeResType, highFreqTensor, centerFreqTensor,
|
||||
/*alpha=*/oneConst);
|
||||
|
||||
Type zeroToNInitType =
|
||||
inputIntType.getWithSizesAndDtype(shapeNSB, f32Ty);
|
||||
Value zeroToNInit = b.create<Torch::AtenArangeOp>(
|
||||
zeroToNInitType, numSpectrogramBinsItem,
|
||||
/*dtype=*/float32DTypeConst,
|
||||
Type iotaInitType = inputIntType.getWithSizesAndDtype(shapeNSB, si32Ty);
|
||||
Value iotaInit = b.create<Torch::AtenArangeOp>(
|
||||
iotaInitType, numSpectrogramBinsItem,
|
||||
/*dtype=*/int32DTypeConst,
|
||||
/*layout=*/noneConst, /*device=*/noneConst,
|
||||
/*pin_memory=*/noneConst);
|
||||
|
||||
Type zeroToNBaseType = inputIntType.getWithSizesAndDtype(
|
||||
ArrayRef<int64_t>{numSpectrogramBinsInt, 1}, f32Ty);
|
||||
Value zeroToNBase = b.create<Torch::AtenUnsqueezeOp>(
|
||||
zeroToNBaseType, zeroToNInit, /*dim=*/oneConst);
|
||||
Type zeroToNumElesType =
|
||||
inputIntType.getWithSizesAndDtype(shapeNSBxNMB, f32Ty);
|
||||
Value expandShapeList = b.create<Torch::PrimListConstructOp>(
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::IntType>()),
|
||||
SmallVector<Value>{numSpectrogramBinsItem, numMelBinsItem});
|
||||
Value zeroToNumEles = b.create<Torch::AtenExpandOp>(
|
||||
zeroToNumElesType, zeroToNBase, expandShapeList,
|
||||
/*implicit=*/falseConst);
|
||||
Torch::ValueTensorType unsqueezeIotaResType =
|
||||
Torch::ValueTensorType::get(ctx, shapeNSBx1, si32Ty);
|
||||
Value iota = b.create<Torch::AtenUnsqueezeOp>(
|
||||
unsqueezeIotaResType, iotaInit, /*dim=*/oneConst);
|
||||
|
||||
Type maskType = inputIntType.getWithSizesAndDtype(shape1xNMB, i1Ty);
|
||||
Value maskLowToCenterZero =
|
||||
b.create<Torch::AtenEqScalarOp>(maskType, lowToCenter, zeroConst);
|
||||
Value lowToCenter = b.create<Torch::AtenSubTensorOp>(
|
||||
unsqueezeBinsResType, centerBins, lowBins, /*alpha=*/oneConst);
|
||||
Value centerToHigh = b.create<Torch::AtenSubTensorOp>(
|
||||
unsqueezeBinsResType, highBins, centerBins, /*alpha=*/oneConst);
|
||||
|
||||
// L2C computation
|
||||
Value lowToCenterNoZero = b.create<Torch::AtenWhereScalarSelfOp>(
|
||||
unsqueezeResType, maskLowToCenterZero, negOneConst, lowToCenter);
|
||||
Type maskL2CAfterCType =
|
||||
inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
|
||||
Value maskL2CAfterC = b.create<Torch::AtenGtTensorOp>(
|
||||
maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
|
||||
Type maxLFResTy =
|
||||
inputIntType.getWithSizesAndDtype(ArrayRef<int64_t>{1}, si32Ty);
|
||||
Value maxLowerFreq =
|
||||
b.create<Torch::AtenMaxOp>(maxLFResTy, lowFreqTensor);
|
||||
Value maxLowerFreqItem =
|
||||
getItemOp<Torch::IntType>(binder, rewriter, maxLowerFreq);
|
||||
Value zeroToNumElesL2C = b.create<Torch::AtenWhereScalarSelfOp>(
|
||||
zeroToNumElesType, maskLowToCenterZero, maxLowerFreqItem,
|
||||
zeroToNumEles);
|
||||
Value upslopeDiff = b.create<Torch::AtenSubTensorOp>(
|
||||
zeroToNumElesType, zeroToNumElesL2C, lowFreqTensor,
|
||||
/*alpha=*/oneConst);
|
||||
Type l2cNZFltTy = inputIntType.getWithSizesAndDtype(shape1xNMB, f32Ty);
|
||||
Value l2cNZFlt = b.create<Torch::AtenToDtypeOp>(
|
||||
l2cNZFltTy, lowToCenterNoZero, /*dtype=*/float32DTypeConst,
|
||||
/*non_blocking=*/falseConst, /*copy=*/falseConst,
|
||||
/*memory_format=*/noneConst);
|
||||
Value upslopeL2C0 = b.create<Torch::AtenDivTensorOp>(
|
||||
zeroToNumElesType, upslopeDiff, l2cNZFlt);
|
||||
Type maskUpslopeL2C0PosType =
|
||||
inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
|
||||
Value maskUpslopeL2C0Pos = b.create<Torch::AtenGtScalarOp>(
|
||||
maskUpslopeL2C0PosType, upslopeL2C0, zeroConst);
|
||||
Value upslopeL2C0PosRanged = b.create<Torch::AtenWhereScalarOtherOp>(
|
||||
zeroToNumElesType, maskUpslopeL2C0Pos, upslopeL2C0, zeroConst);
|
||||
Value maskIdxL2CAfterCList = b.create<Torch::PrimListConstructOp>(
|
||||
rewriter.getType<Torch::ListType>(maskL2CAfterC.getType()),
|
||||
ValueRange{maskL2CAfterC});
|
||||
Value zeroConstTensor = Torch::createRank0Tensor(
|
||||
rewriter, binder.getLoc(),
|
||||
Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), zeroConst);
|
||||
Value upslopeL2C1 = b.create<Torch::AtenIndexPutOp>(
|
||||
zeroToNumElesType, upslopeL2C0PosRanged, maskIdxL2CAfterCList,
|
||||
zeroConstTensor, falseConst);
|
||||
Value maskIdxL2CZeroList = b.create<Torch::PrimListConstructOp>(
|
||||
rewriter.getType<Torch::ListType>(maskLowToCenterZero.getType()),
|
||||
ValueRange{maskLowToCenterZero});
|
||||
Type centerFreqTensorL2CZeroType =
|
||||
inputIntType.getWithSizesAndDtype(ArrayRef<int64_t>{-1}, si32Ty);
|
||||
Value centerFreqTensorL2CZero = b.create<Torch::AtenIndexTensorOp>(
|
||||
centerFreqTensorL2CZeroType, centerFreqTensor, maskIdxL2CZeroList);
|
||||
Type maskSqueezeType =
|
||||
inputIntType.getWithSizesAndDtype(shapeNMB, i1Ty);
|
||||
Value maskLowToCenterZeroSqueeze = b.create<Torch::AtenSqueezeOp>(
|
||||
maskSqueezeType, maskLowToCenterZero);
|
||||
Type maskL2CIntTy = inputIntType.getWithSizesAndDtype(shapeNMB, si32Ty);
|
||||
Value maskLowToCenterInt = b.create<Torch::AtenToDtypeOp>(
|
||||
maskL2CIntTy, maskLowToCenterZeroSqueeze, /*dtype=*/int32DTypeConst,
|
||||
/*non_blocking=*/falseConst, /*copy=*/falseConst,
|
||||
/*memory_format=*/noneConst);
|
||||
Value upslopeOneIdxList = b.create<Torch::PrimListConstructOp>(
|
||||
rewriter.getType<Torch::ListType>(
|
||||
centerFreqTensorL2CZero.getType()),
|
||||
ValueRange{centerFreqTensorL2CZero, maskLowToCenterInt});
|
||||
Value oneConstTensor = Torch::createRank0Tensor(
|
||||
rewriter, binder.getLoc(),
|
||||
Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), oneConst);
|
||||
Value upslopeL2C = b.create<Torch::AtenIndexPutOp>(
|
||||
zeroToNumElesType, upslopeL2C1, upslopeOneIdxList, oneConstTensor,
|
||||
falseConst);
|
||||
|
||||
// H2C computation
|
||||
Value maskCenterToHighZero =
|
||||
b.create<Torch::AtenEqScalarOp>(maskType, centerToHigh, zeroConst);
|
||||
Value maskH2CBeforeC = b.create<Torch::AtenLtTensorOp>(
|
||||
maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
|
||||
Value centerToHighNoZero = b.create<Torch::AtenWhereScalarSelfOp>(
|
||||
unsqueezeResType, maskCenterToHighZero, negOneConst, centerToHigh);
|
||||
Value c2hNZFlt = b.create<Torch::AtenToDtypeOp>(
|
||||
l2cNZFltTy, centerToHighNoZero, /*dtype=*/float32DTypeConst,
|
||||
Type scaledType = inputIntType.getWithSizesAndDtype(shape1xNMB, f32Ty);
|
||||
Value upscaleInit = b.create<Torch::AtenMaximumOp>(
|
||||
unsqueezeBinsResType, oneConstTensor, lowToCenter);
|
||||
Value upscale = b.create<Torch::AtenToDtypeOp>(
|
||||
scaledType, upscaleInit, /*dtype=*/float32DTypeConst,
|
||||
/*non_blocking=*/falseConst, /*copy=*/falseConst,
|
||||
/*memory_format=*/noneConst);
|
||||
Value zeroToNumElesC2H = b.create<Torch::AtenWhereScalarSelfOp>(
|
||||
zeroToNumElesType, maskCenterToHighZero, zeroConst, zeroToNumEles);
|
||||
Value downslopeDiff = b.create<Torch::AtenSubTensorOp>(
|
||||
zeroToNumElesType, highFreqTensor, zeroToNumElesC2H,
|
||||
/*alpha=*/oneConst);
|
||||
Value downslopeC2H0 = b.create<Torch::AtenDivTensorOp>(
|
||||
zeroToNumElesType, downslopeDiff, c2hNZFlt);
|
||||
Value maskDownslopeC2H0Pos = b.create<Torch::AtenGtScalarOp>(
|
||||
maskUpslopeL2C0PosType, downslopeC2H0, zeroConst);
|
||||
Value downslopeC2H0Pos = b.create<Torch::AtenWhereScalarOtherOp>(
|
||||
zeroToNumElesType, maskDownslopeC2H0Pos, downslopeC2H0, zeroConst);
|
||||
Value idxH2CBeforeCList = b.create<Torch::PrimListConstructOp>(
|
||||
rewriter.getType<Torch::ListType>(maskH2CBeforeC.getType()),
|
||||
ValueRange{maskH2CBeforeC});
|
||||
Value downslopeC2H = b.create<Torch::AtenIndexPutOp>(
|
||||
zeroToNumElesType, downslopeC2H0Pos, idxH2CBeforeCList,
|
||||
zeroConstTensor, falseConst);
|
||||
|
||||
// final result Calculation
|
||||
Value maskH2CNonZero = b.create<Torch::AtenNeScalarOp>(
|
||||
maskL2CAfterCType, downslopeC2H, zeroConst);
|
||||
Value idxH2CNZList = b.create<Torch::PrimListConstructOp>(
|
||||
rewriter.getType<Torch::ListType>(maskH2CNonZero.getType()),
|
||||
ValueRange{maskH2CNonZero});
|
||||
Value upslopeL2CMasked = b.create<Torch::AtenIndexPutOp>(
|
||||
zeroToNumElesType, upslopeL2C, idxH2CNZList, zeroConstTensor,
|
||||
falseConst);
|
||||
Value downscaleInit = b.create<Torch::AtenMaximumOp>(
|
||||
unsqueezeBinsResType, oneConstTensor, centerToHigh);
|
||||
Value downscale = b.create<Torch::AtenToDtypeOp>(
|
||||
scaledType, downscaleInit, /*dtype=*/float32DTypeConst,
|
||||
/*non_blocking=*/falseConst, /*copy=*/falseConst,
|
||||
/*memory_format=*/noneConst);
|
||||
|
||||
Value slopesFinal = b.create<Torch::AtenAddTensorOp>(
|
||||
zeroToNumElesType, upslopeL2CMasked, downslopeC2H,
|
||||
/*alpha=*/oneConst);
|
||||
Torch::ValueTensorType binsDiffType =
|
||||
Torch::ValueTensorType::get(ctx, shapeNSBxNMB, si32Ty);
|
||||
Torch::ValueTensorType diffFloatType =
|
||||
Torch::ValueTensorType::get(ctx, shapeNSBxNMB, f32Ty);
|
||||
|
||||
Value iotaSubLBInt = b.create<Torch::AtenSubTensorOp>(
|
||||
binsDiffType, iota, lowBins, /*alpha=*/oneConst);
|
||||
Value iotaSubLB = b.create<Torch::AtenToDtypeOp>(
|
||||
diffFloatType, iotaSubLBInt, /*dtype=*/float32DTypeConst,
|
||||
/*non_blocking=*/falseConst, /*copy=*/falseConst,
|
||||
/*memory_format=*/noneConst);
|
||||
Value rampUp =
|
||||
b.create<Torch::AtenDivTensorOp>(diffFloatType, iotaSubLB, upscale);
|
||||
|
||||
Value hbSubIotaInt = b.create<Torch::AtenSubTensorOp>(
|
||||
binsDiffType, highBins, iota, /*alpha=*/oneConst);
|
||||
Value hbSubIota = b.create<Torch::AtenToDtypeOp>(
|
||||
diffFloatType, hbSubIotaInt, /*dtype=*/float32DTypeConst,
|
||||
/*non_blocking=*/falseConst, /*copy=*/falseConst,
|
||||
/*memory_format=*/noneConst);
|
||||
Value rampDown = b.create<Torch::AtenDivTensorOp>(diffFloatType,
|
||||
hbSubIota, downscale);
|
||||
|
||||
// ramp values
|
||||
Type iotaCmpBinsType =
|
||||
inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
|
||||
|
||||
// Iota Cmp Bins
|
||||
Value iotaGtEqCBins =
|
||||
b.create<Torch::AtenGeTensorOp>(iotaCmpBinsType, iota, centerBins);
|
||||
Value iotaEqCBins =
|
||||
b.create<Torch::AtenEqTensorOp>(iotaCmpBinsType, iota, centerBins);
|
||||
Value iotaLtLBins =
|
||||
b.create<Torch::AtenLtTensorOp>(iotaCmpBinsType, iota, lowBins);
|
||||
Value iotaGtLBins =
|
||||
b.create<Torch::AtenGtTensorOp>(iotaCmpBinsType, iota, highBins);
|
||||
|
||||
// Create output freq ramps Low-Center-High
|
||||
Type rampInitType =
|
||||
inputIntType.getWithSizesAndDtype(shapeNSBxNMB, f32Ty);
|
||||
Value rampInit = b.create<Torch::AtenWhereSelfOp>(
|
||||
rampInitType, iotaGtEqCBins, rampDown, rampUp);
|
||||
Value rampInitLt = b.create<Torch::AtenWhereScalarSelfOp>(
|
||||
rampInitType, iotaLtLBins, zeroConst, rampInit);
|
||||
Value rampInitLtGt = b.create<Torch::AtenWhereScalarSelfOp>(
|
||||
rampInitType, iotaGtLBins, zeroConst, rampInitLt);
|
||||
|
||||
Type C2HCmpBinsType =
|
||||
inputIntType.getWithSizesAndDtype(shape1xNMB, i1Ty);
|
||||
Value C2HEqZero = b.create<Torch::AtenEqScalarOp>(
|
||||
C2HCmpBinsType, centerToHigh, zeroConst);
|
||||
Value cornerCases = b.create<Torch::AtenLogicalAndOp>(
|
||||
iotaCmpBinsType, iotaEqCBins, C2HEqZero);
|
||||
Value rampOutput = b.create<Torch::AtenWhereScalarSelfOp>(
|
||||
rampInitType, cornerCases, oneFltConst, rampInitLtGt);
|
||||
|
||||
Value outputDTypeConst = b.create<Torch::ConstantIntOp>(
|
||||
rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getI64IntegerAttr(torchDTypeInt.value()));
|
||||
Value finalOutput = b.create<Torch::AtenToDtypeOp>(
|
||||
resultType, slopesFinal, /*dtype=*/outputDTypeConst,
|
||||
resultType, rampOutput, /*dtype=*/outputDTypeConst,
|
||||
/*non_blocking=*/falseConst, /*copy=*/falseConst,
|
||||
/*memory_format=*/noneConst);
|
||||
|
||||
|
|
|
@ -1974,113 +1974,118 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
|
|||
|
||||
// CHECK-LABEL: func.func @test_mwm
|
||||
func.func @test_mwm(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>, %arg3: !torch.vtensor<[],f32>, %arg4: !torch.vtensor<[],f32>) -> !torch.vtensor<[9,8],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "test_mwm", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>, %[[VAL_1:.*]]: !torch.vtensor<[],si64>, %[[VAL_2:.*]]: !torch.vtensor<[],si64>,
|
||||
// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>,
|
||||
// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[],f32>
|
||||
// CHECK-SAME: %[[NUM_MEL_BINS_ARG:.*]]: !torch.vtensor<[],si64>, %[[DFT_LENGTH_ARG:.*]]: !torch.vtensor<[],si64>, %[[SAMPLE_RATE_ARG:.*]]: !torch.vtensor<[],si64>,
|
||||
// CHECK-SAME: %[[LOWER_EDGE_HZ_ARG:.*]]: !torch.vtensor<[],f32>,
|
||||
// CHECK-SAME: %[[UPPER_EDGE_HZ_ARG:.*]]: !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_6:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[VAL_7:.*]] = torch.aten.item %[[VAL_1]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[VAL_8:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[VAL_9:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[],f32> -> !torch.float
|
||||
// CHECK: %[[VAL_10:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[],f32> -> !torch.float
|
||||
// CHECK: %[[VAL_11:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_12:.*]] = torch.constant.int -2
|
||||
// CHECK: %[[VAL_13:.*]] = torch.constant.int -1
|
||||
// CHECK: %[[VAL_14:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_15:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_16:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_17:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[VAL_18:.*]] = torch.aten.div.int %[[VAL_7]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.float
|
||||
// CHECK: %[[VAL_19:.*]] = torch.aten.Int.float %[[VAL_18]] : !torch.float -> !torch.int
|
||||
// CHECK: %[[VAL_20:.*]] = torch.aten.add.int %[[VAL_19]], %[[VAL_15]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[VAL_21:.*]] = torch.aten.arange %[[VAL_6]], %[[VAL_17]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],si32>
|
||||
// CHECK: %[[VAL_22:.*]] = torch.constant.float 2.595000e+03
|
||||
// CHECK: %[[VAL_23:.*]] = torch.constant.float 7.000000e+02
|
||||
// CHECK: %[[VAL_24:.*]] = torch.constant.float 1.000000e+01
|
||||
// CHECK: %[[VAL_25:.*]] = torch.aten.div.float %[[VAL_9]], %[[VAL_23]] : !torch.float, !torch.float -> !torch.float
|
||||
// CHECK: %[[VAL_26:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_25]] : !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_27:.*]] = torch.aten.add.Scalar %[[VAL_26]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_28:.*]] = torch.aten.log10 %[[VAL_27]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_29:.*]] = torch.aten.mul.Scalar %[[VAL_28]], %[[VAL_22]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_30:.*]] = torch.aten.div.float %[[VAL_10]], %[[VAL_23]] : !torch.float, !torch.float -> !torch.float
|
||||
// CHECK: %[[NUM_MEL_BINS_ITEM:.*]] = torch.aten.item %[[NUM_MEL_BINS_ARG]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[SAMPLE_RATE_ITEM:.*]] = torch.aten.item %[[SAMPLE_RATE_ARG]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[LOWER_EDGE_HZ_ITEM:.*]] = torch.aten.item %[[LOWER_EDGE_HZ_ARG]] : !torch.vtensor<[],f32> -> !torch.float
|
||||
// CHECK: %[[UPPER_EDGE_HZ_ITEM:.*]] = torch.aten.item %[[UPPER_EDGE_HZ_ARG]] : !torch.vtensor<[],f32> -> !torch.float
|
||||
// CHECK: %[[VAL_10:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_11:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_12:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_13:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_14:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[VAL_15:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[VAL_16:.*]] = torch.aten.floor_divide.Scalar %[[DFT_LENGTH_ARG]], %[[VAL_13]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[VAL_17:.*]] = torch.aten.add.Scalar %[[VAL_16]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[NUM_SPECTROGRAM_BINS_ITEM:.*]] = torch.aten.item %[[VAL_17]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[VAL_19:.*]] = torch.constant.float 2.595000e+03
|
||||
// CHECK: %[[VAL_20:.*]] = torch.constant.float 7.000000e+02
|
||||
// CHECK: %[[VAL_21:.*]] = torch.constant.float 1.000000e+01
|
||||
// CHECK: %[[VAL_22:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[CONST_LN_TO_LOG10:.*]] = torch.constant.float 0.43429448190325182
|
||||
// CHECK: %[[VAL_24:.*]] = torch.aten.div.float %[[LOWER_EDGE_HZ_ITEM]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.float
|
||||
// CHECK: %[[VAL_25:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_24]] : !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_26:.*]] = torch.aten.add.Scalar %[[VAL_25]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_27:.*]] = torch.aten.log %[[VAL_26]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_28:.*]] = torch.aten.mul.Scalar %[[VAL_27]], %[[CONST_LN_TO_LOG10]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[LOW_FREQ_MEL:.*]] = torch.aten.mul.Scalar %[[VAL_28]], %[[VAL_19]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_30:.*]] = torch.aten.div.float %[[UPPER_EDGE_HZ_ITEM]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.float
|
||||
// CHECK: %[[VAL_31:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_30]] : !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_32:.*]] = torch.aten.add.Scalar %[[VAL_31]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_33:.*]] = torch.aten.log10 %[[VAL_32]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_34:.*]] = torch.aten.mul.Scalar %[[VAL_33]], %[[VAL_22]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_35:.*]] = torch.aten.sub.Tensor %[[VAL_34]], %[[VAL_29]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_36:.*]] = torch.aten.div.Scalar %[[VAL_35]], %[[VAL_6]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_37:.*]] = torch.aten.mul.Tensor %[[VAL_21]], %[[VAL_36]] : !torch.vtensor<[10],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
|
||||
// CHECK: %[[VAL_38:.*]] = torch.aten.add.Tensor %[[VAL_37]], %[[VAL_29]], %[[VAL_15]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[10],f32>
|
||||
// CHECK: %[[VAL_39:.*]] = torch.aten.div.Scalar %[[VAL_38]], %[[VAL_22]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK: %[[VAL_40:.*]] = torch.aten.clone %[[VAL_38]], %[[VAL_11]] : !torch.vtensor<[10],f32>, !torch.none -> !torch.vtensor<[10],f32>
|
||||
// CHECK: %[[VAL_41:.*]] = torch.aten.fill.Scalar %[[VAL_40]], %[[VAL_24]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK: %[[VAL_42:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_41]], %[[VAL_39]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
|
||||
// CHECK: %[[VAL_43:.*]] = torch.aten.sub.Scalar %[[VAL_42]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[10],f32>, !torch.int, !torch.int -> !torch.vtensor<[10],f32>
|
||||
// CHECK: %[[VAL_44:.*]] = torch.aten.mul.Scalar %[[VAL_43]], %[[VAL_23]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
|
||||
// CHECK: %[[VAL_45:.*]] = torch.aten.add.Scalar %[[VAL_1]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[VAL_46:.*]] = torch.aten.item %[[VAL_45]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[VAL_47:.*]] = torch.aten.mul.Scalar %[[VAL_44]], %[[VAL_46]] : !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
|
||||
// CHECK: %[[VAL_48:.*]] = torch.aten.div.Scalar %[[VAL_47]], %[[VAL_8]] : !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
|
||||
// CHECK: %[[VAL_49:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[VAL_50:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[VAL_51:.*]] = torch.aten.to.dtype %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],si32>
|
||||
// CHECK: %[[VAL_52:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_14]], %[[VAL_12]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
|
||||
// CHECK: %[[VAL_53:.*]] = torch.aten.unsqueeze %[[VAL_52]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[VAL_54:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_15]], %[[VAL_13]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
|
||||
// CHECK: %[[VAL_55:.*]] = torch.aten.unsqueeze %[[VAL_54]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[VAL_56:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_14]], %[[VAL_11]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[8],si32>
|
||||
// CHECK: %[[VAL_57:.*]] = torch.aten.unsqueeze %[[VAL_56]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[VAL_58:.*]] = torch.aten.sub.Tensor %[[VAL_55]], %[[VAL_53]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[VAL_59:.*]] = torch.aten.sub.Tensor %[[VAL_57]], %[[VAL_55]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[VAL_60:.*]] = torch.aten.arange %[[VAL_20]], %[[VAL_17]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[9],f32>
|
||||
// CHECK: %[[VAL_61:.*]] = torch.aten.unsqueeze %[[VAL_60]], %[[VAL_15]] : !torch.vtensor<[9],f32>, !torch.int -> !torch.vtensor<[9,1],f32>
|
||||
// CHECK: %[[VAL_62:.*]] = torch.prim.ListConstruct %[[VAL_20]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_63:.*]] = torch.aten.expand %[[VAL_61]], %[[VAL_62]], %[[VAL_50]] : !torch.vtensor<[9,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_64:.*]] = torch.aten.eq.Scalar %[[VAL_58]], %[[VAL_14]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1>
|
||||
// CHECK: %[[VAL_65:.*]] = torch.aten.where.ScalarSelf %[[VAL_64]], %[[VAL_13]], %[[VAL_58]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[VAL_66:.*]] = torch.aten.gt.Tensor %[[VAL_63]], %[[VAL_55]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
|
||||
// CHECK: %[[VAL_67:.*]] = torch.aten.max %[[VAL_53]] : !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1],si32>
|
||||
// CHECK: %[[VAL_68:.*]] = torch.aten.item %[[VAL_67]] : !torch.vtensor<[1],si32> -> !torch.int
|
||||
// CHECK: %[[VAL_69:.*]] = torch.aten.where.ScalarSelf %[[VAL_64]], %[[VAL_68]], %[[VAL_63]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_70:.*]] = torch.aten.sub.Tensor %[[VAL_69]], %[[VAL_53]], %[[VAL_15]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_71:.*]] = torch.aten.to.dtype %[[VAL_65]], %[[VAL_17]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
|
||||
// CHECK: %[[VAL_72:.*]] = torch.aten.div.Tensor %[[VAL_70]], %[[VAL_71]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_73:.*]] = torch.aten.gt.Scalar %[[VAL_72]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
|
||||
// CHECK: %[[VAL_74:.*]] = torch.aten.where.ScalarOther %[[VAL_73]], %[[VAL_72]], %[[VAL_14]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_75:.*]] = torch.prim.ListConstruct %[[VAL_66]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<vtensor<[9,8],i1>>
|
||||
// CHECK: %[[VAL_76:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[VAL_77:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_78:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[VAL_79:.*]] = torch.aten.full %[[VAL_76]], %[[VAL_14]], %[[VAL_78]], %[[VAL_77]], %[[VAL_77]], %[[VAL_77]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_80:.*]] = torch.aten.index_put %[[VAL_74]], %[[VAL_75]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[9,8],i1>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_81:.*]] = torch.prim.ListConstruct %[[VAL_64]] : (!torch.vtensor<[1,8],i1>) -> !torch.list<vtensor<[1,8],i1>>
|
||||
// CHECK: %[[VAL_82:.*]] = torch.aten.index.Tensor %[[VAL_55]], %[[VAL_81]] : !torch.vtensor<[1,8],si32>, !torch.list<vtensor<[1,8],i1>> -> !torch.vtensor<[?],si32>
|
||||
// CHECK: %[[VAL_83:.*]] = torch.aten.squeeze %[[VAL_64]] : !torch.vtensor<[1,8],i1> -> !torch.vtensor<[8],i1>
|
||||
// CHECK: %[[VAL_84:.*]] = torch.aten.to.dtype %[[VAL_83]], %[[VAL_49]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[8],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
|
||||
// CHECK: %[[VAL_85:.*]] = torch.prim.ListConstruct %[[VAL_82]], %[[VAL_84]] : (!torch.vtensor<[?],si32>, !torch.vtensor<[8],si32>) -> !torch.list<vtensor<[?],si32>>
|
||||
// CHECK: %[[VAL_86:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[VAL_87:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_88:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[VAL_89:.*]] = torch.aten.full %[[VAL_86]], %[[VAL_15]], %[[VAL_88]], %[[VAL_87]], %[[VAL_87]], %[[VAL_87]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_90:.*]] = torch.aten.index_put %[[VAL_80]], %[[VAL_85]], %[[VAL_89]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[?],si32>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_91:.*]] = torch.aten.eq.Scalar %[[VAL_59]], %[[VAL_14]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1>
|
||||
// CHECK: %[[VAL_92:.*]] = torch.aten.lt.Tensor %[[VAL_63]], %[[VAL_55]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
|
||||
// CHECK: %[[VAL_93:.*]] = torch.aten.where.ScalarSelf %[[VAL_91]], %[[VAL_13]], %[[VAL_59]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[VAL_94:.*]] = torch.aten.to.dtype %[[VAL_93]], %[[VAL_17]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
|
||||
// CHECK: %[[VAL_95:.*]] = torch.aten.where.ScalarSelf %[[VAL_91]], %[[VAL_14]], %[[VAL_63]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_96:.*]] = torch.aten.sub.Tensor %[[VAL_57]], %[[VAL_95]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_97:.*]] = torch.aten.div.Tensor %[[VAL_96]], %[[VAL_94]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_98:.*]] = torch.aten.gt.Scalar %[[VAL_97]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
|
||||
// CHECK: %[[VAL_99:.*]] = torch.aten.where.ScalarOther %[[VAL_98]], %[[VAL_97]], %[[VAL_14]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_100:.*]] = torch.prim.ListConstruct %[[VAL_92]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<vtensor<[9,8],i1>>
|
||||
// CHECK: %[[VAL_101:.*]] = torch.aten.index_put %[[VAL_99]], %[[VAL_100]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[9,8],i1>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_102:.*]] = torch.aten.ne.Scalar %[[VAL_101]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
|
||||
// CHECK: %[[VAL_103:.*]] = torch.prim.ListConstruct %[[VAL_102]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<vtensor<[9,8],i1>>
|
||||
// CHECK: %[[VAL_104:.*]] = torch.aten.index_put %[[VAL_90]], %[[VAL_103]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[9,8],i1>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_105:.*]] = torch.aten.add.Tensor %[[VAL_104]], %[[VAL_101]], %[[VAL_15]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_106:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[VAL_107:.*]] = torch.aten.to.dtype %[[VAL_105]], %[[VAL_106]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[9,8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: return %[[VAL_107]] : !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_32:.*]] = torch.aten.add.Scalar %[[VAL_31]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_33:.*]] = torch.aten.log %[[VAL_32]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_34:.*]] = torch.aten.mul.Scalar %[[VAL_33]], %[[CONST_LN_TO_LOG10]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[HIGH_FREQ_MEL:.*]] = torch.aten.mul.Scalar %[[VAL_34]], %[[VAL_19]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_36:.*]] = torch.aten.sub.Tensor %[[HIGH_FREQ_MEL]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_37:.*]] = torch.aten.add.int %[[NUM_MEL_BINS_ITEM]], %[[VAL_13]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[MEL_STEP:.*]] = torch.aten.div.Scalar %[[VAL_36]], %[[VAL_37]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[LOW_BINS_INIT:.*]] = torch.aten.arange %[[NUM_MEL_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[8],si32>
|
||||
// CHECK: %[[CENTER_BINS_INIT:.*]] = torch.aten.arange %[[NUM_MEL_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[8],si32>
|
||||
// CHECK: %[[HIGH_BINS_INIT:.*]] = torch.aten.arange %[[NUM_MEL_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[8],si32>
|
||||
// CHECK: %[[VAL_42:.*]] = torch.aten.add.Scalar %[[DFT_LENGTH_ARG]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[VAL_43:.*]] = torch.aten.item %[[VAL_42]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[VAL_44:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[VAL_45:.*]] = torch.aten.mul.Tensor %[[LOW_BINS_INIT]], %[[MEL_STEP]] : !torch.vtensor<[8],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_46:.*]] = torch.aten.add.Tensor %[[VAL_45]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_47:.*]] = torch.aten.div.Scalar %[[VAL_46]], %[[VAL_19]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_48:.*]] = torch.aten.clone %[[VAL_46]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.none -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_49:.*]] = torch.aten.fill.Scalar %[[VAL_48]], %[[VAL_21]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_50:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_49]], %[[VAL_47]] : !torch.vtensor<[8],f32>, !torch.vtensor<[8],f32> -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_51:.*]] = torch.aten.sub.Scalar %[[VAL_50]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_52:.*]] = torch.aten.mul.Scalar %[[VAL_51]], %[[VAL_20]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_53:.*]] = torch.aten.mul.Scalar %[[VAL_52]], %[[VAL_43]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_54:.*]] = torch.aten.div.Scalar %[[VAL_53]], %[[SAMPLE_RATE_ITEM]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_55:.*]] = torch.aten.to.dtype %[[VAL_54]], %[[VAL_14]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
|
||||
// CHECK: %[[LOW_BINS:.*]] = torch.aten.unsqueeze %[[VAL_55]], %[[VAL_11]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[VAL_57:.*]] = torch.aten.add.Scalar %[[CENTER_BINS_INIT]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],si32>, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
|
||||
// CHECK: %[[VAL_58:.*]] = torch.aten.mul.Tensor %[[VAL_57]], %[[MEL_STEP]] : !torch.vtensor<[8],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_59:.*]] = torch.aten.add.Tensor %[[VAL_58]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_60:.*]] = torch.aten.div.Scalar %[[VAL_59]], %[[VAL_19]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_61:.*]] = torch.aten.clone %[[VAL_59]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.none -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_62:.*]] = torch.aten.fill.Scalar %[[VAL_61]], %[[VAL_21]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_63:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_62]], %[[VAL_60]] : !torch.vtensor<[8],f32>, !torch.vtensor<[8],f32> -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_64:.*]] = torch.aten.sub.Scalar %[[VAL_63]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_65:.*]] = torch.aten.mul.Scalar %[[VAL_64]], %[[VAL_20]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_66:.*]] = torch.aten.mul.Scalar %[[VAL_65]], %[[VAL_43]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_67:.*]] = torch.aten.div.Scalar %[[VAL_66]], %[[SAMPLE_RATE_ITEM]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_68:.*]] = torch.aten.to.dtype %[[VAL_67]], %[[VAL_14]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
|
||||
// CHECK: %[[CENTER_BINS:.*]] = torch.aten.unsqueeze %[[VAL_68]], %[[VAL_11]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[VAL_70:.*]] = torch.aten.add.Scalar %[[HIGH_BINS_INIT]], %[[VAL_13]], %[[VAL_12]] : !torch.vtensor<[8],si32>, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
|
||||
// CHECK: %[[VAL_71:.*]] = torch.aten.mul.Tensor %[[VAL_70]], %[[MEL_STEP]] : !torch.vtensor<[8],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_72:.*]] = torch.aten.add.Tensor %[[VAL_71]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_73:.*]] = torch.aten.div.Scalar %[[VAL_72]], %[[VAL_19]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_74:.*]] = torch.aten.clone %[[VAL_72]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.none -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_75:.*]] = torch.aten.fill.Scalar %[[VAL_74]], %[[VAL_21]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_76:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_75]], %[[VAL_73]] : !torch.vtensor<[8],f32>, !torch.vtensor<[8],f32> -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_77:.*]] = torch.aten.sub.Scalar %[[VAL_76]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_78:.*]] = torch.aten.mul.Scalar %[[VAL_77]], %[[VAL_20]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_79:.*]] = torch.aten.mul.Scalar %[[VAL_78]], %[[VAL_43]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_80:.*]] = torch.aten.div.Scalar %[[VAL_79]], %[[SAMPLE_RATE_ITEM]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32>
|
||||
// CHECK: %[[VAL_81:.*]] = torch.aten.to.dtype %[[VAL_80]], %[[VAL_14]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
|
||||
// CHECK: %[[HIGH_BINS:.*]] = torch.aten.unsqueeze %[[VAL_81]], %[[VAL_11]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[IOTA_INIT:.*]] = torch.aten.arange %[[NUM_SPECTROGRAM_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[9],si32>
|
||||
// CHECK: %[[IOTA:.*]] = torch.aten.unsqueeze %[[IOTA_INIT]], %[[VAL_12]] : !torch.vtensor<[9],si32>, !torch.int -> !torch.vtensor<[9,1],si32>
|
||||
// CHECK: %[[LOW_TO_CENTER:.*]] = torch.aten.sub.Tensor %[[CENTER_BINS]], %[[LOW_BINS]], %[[VAL_12]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[CENTER_TO_HIGH:.*]] = torch.aten.sub.Tensor %[[HIGH_BINS]], %[[CENTER_BINS]], %[[VAL_12]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[VAL_87:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[VAL_88:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_89:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[VAL_90:.*]] = torch.aten.full %[[VAL_87]], %[[VAL_12]], %[[VAL_89]], %[[VAL_88]], %[[VAL_88]], %[[VAL_88]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[VAL_91:.*]] = torch.aten.maximum %[[VAL_90]], %[[LOW_TO_CENTER]] : !torch.vtensor<[],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[UP_SCALE:.*]] = torch.aten.to.dtype %[[VAL_91]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
|
||||
// CHECK: %[[VAL_93:.*]] = torch.aten.maximum %[[VAL_90]], %[[CENTER_TO_HIGH]] : !torch.vtensor<[],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
|
||||
// CHECK: %[[DOWN_SCALE:.*]] = torch.aten.to.dtype %[[VAL_93]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
|
||||
// CHECK: %[[VAL_95:.*]] = torch.aten.sub.Tensor %[[IOTA]], %[[LOW_BINS]], %[[VAL_12]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[9,8],si32>
|
||||
// CHECK: %[[VAL_96:.*]] = torch.aten.to.dtype %[[VAL_95]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[9,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[RAMP_UP:.*]] = torch.aten.div.Tensor %[[VAL_96]], %[[UP_SCALE]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_98:.*]] = torch.aten.sub.Tensor %[[HIGH_BINS]], %[[IOTA]], %[[VAL_12]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[9,1],si32>, !torch.int -> !torch.vtensor<[9,8],si32>
|
||||
// CHECK: %[[VAL_99:.*]] = torch.aten.to.dtype %[[VAL_98]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[9,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[RAMP_DOWN:.*]] = torch.aten.div.Tensor %[[VAL_99]], %[[DOWN_SCALE]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_101:.*]] = torch.aten.ge.Tensor %[[IOTA]], %[[CENTER_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
|
||||
// CHECK: %[[VAL_102:.*]] = torch.aten.eq.Tensor %[[IOTA]], %[[CENTER_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
|
||||
// CHECK: %[[VAL_103:.*]] = torch.aten.lt.Tensor %[[IOTA]], %[[LOW_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
|
||||
// CHECK: %[[VAL_104:.*]] = torch.aten.gt.Tensor %[[IOTA]], %[[HIGH_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
|
||||
// CHECK: %[[RAMP_INIT:.*]] = torch.aten.where.self %[[VAL_101]], %[[RAMP_DOWN]], %[[RAMP_UP]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_106:.*]] = torch.aten.where.ScalarSelf %[[VAL_103]], %[[VAL_11]], %[[RAMP_INIT]] : !torch.vtensor<[9,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_107:.*]] = torch.aten.where.ScalarSelf %[[VAL_104]], %[[VAL_11]], %[[VAL_106]] : !torch.vtensor<[9,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_108:.*]] = torch.aten.eq.Scalar %[[CENTER_TO_HIGH]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1>
|
||||
// CHECK: %[[CORNER_CASES:.*]] = torch.aten.logical_and %[[VAL_102]], %[[VAL_108]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[1,8],i1> -> !torch.vtensor<[9,8],i1>
|
||||
// CHECK: %[[RAMP:.*]] = torch.aten.where.ScalarSelf %[[CORNER_CASES]], %[[VAL_22]], %[[VAL_107]] : !torch.vtensor<[9,8],i1>, !torch.float, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: %[[VAL_111:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[OUTPUT:.*]] = torch.aten.to.dtype %[[RAMP]], %[[VAL_111]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[9,8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
|
||||
// CHECK: return %[[OUTPUT]] : !torch.vtensor<[9,8],f32>
|
||||
%none = torch.constant.none
|
||||
%0 = torch.operator "onnx.MelWeightMatrix"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[9,8],f32>
|
||||
return %0 : !torch.vtensor<[9,8],f32>
|
||||
|
|
Loading…
Reference in New Issue