onnx.MelWeightMatrix TorchOnnxToTorch (#3503)

Just uploading what I have till now

[Gist](https://gist.github.com/PhaneeshB/761f75f5522d9f4a40ef949a328e93fe)
of pytorch impl that I'm following to implement the OnnxToTorch lowering

Additional Details - (also pasted as comment in gist)
[Op
Description](https://github.com/onnx/onnx/blob/main/docs/Operators.md#melweightmatrix)
in Onnx Documentation

[Example](https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-93)
Used the same example in this file.
the Expected output is shown in the example

[Reference Onnx
Impl](4c3ed5e08b/onnx/reference/ops/op_mel_weight_matrix.py (L13))
- This is the base for the above code.
pull/3629/head
Phaneesh Barwaria 2024-08-12 08:48:29 -07:00 committed by GitHub
parent 334633b738
commit 026dfade64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 483 additions and 0 deletions

View File

@ -591,6 +591,373 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp(
"MelWeightMatrix", 17,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
llvm::SmallVector<Value> operands;
Torch::ValueTensorType resultType;
int64_t output_dtype_attr;
if (binder.tensorOperands(operands, 5) ||
binder.tensorResultType(resultType) || operands.size() != 5 ||
binder.s64IntegerAttr(output_dtype_attr, "output_datatype", 1)) {
return failure();
}
// operands sequence :
// num_mel_bins, dft_length, sample_rate -> int32/64 tensors
// lower_edge_hertz, upper_edge_hertz -> f16/32/64
// Need to backtrack the values of num_mel_bins and dft_length//2+1 from
// result shape since the inputs are tensors and we cannot know their
// values at compile time. if the result type does not contain static
// shapes, then the implementation will be unsupported.
if (!resultType.areAllSizesKnown())
return rewriter.notifyMatchFailure(
binder.op, "Unknown result sizes, not supported.");
ArrayRef<int64_t> resShape = resultType.getSizes();
if (resShape.size() != 2)
return rewriter.notifyMatchFailure(
binder.op,
"Expected result rank to be 2, not supported for other ranks.");
std::optional<int64_t> torchDTypeInt =
onnxDtypeIntToTorchDtypeInt(output_dtype_attr);
if (!torchDTypeInt.has_value()) {
return rewriter.notifyMatchFailure(
binder.op, "conversion to given output dtype unsupported");
}
// Here Onwards all shapes will be computed using these sizes
int64_t numSpectrogramBinsInt = resShape[0];
int64_t numMelBinsInt = resShape[1];
Torch::ValueTensorType inputIntType = binder.toValidTensorType(
operands[0].getType()); // Since operands[0 / 1 / 2] will have the
// same int type.
Torch::ValueTensorType inputFloatType = binder.toValidTensorType(
operands[3].getType()); // Since operands[3 / 4] will have the same
// float type
Value numMelBinsItem =
getItemOp<Torch::IntType>(binder, rewriter, operands[0]);
Value dftLengthItem =
getItemOp<Torch::IntType>(binder, rewriter, operands[1]);
Value sampleRateItem =
getItemOp<Torch::IntType>(binder, rewriter, operands[2]);
Value lowerEdgeHzItem =
getItemOp<Torch::FloatType>(binder, rewriter, operands[3]);
Value upperEdgeHzItem =
getItemOp<Torch::FloatType>(binder, rewriter, operands[4]);
// Helpers
ImplicitLocOpBuilder b(binder.getLoc(), rewriter);
auto ctx = binder.op->getContext();
// Recurring shapes
SmallVector<int64_t> unranked({});
SmallVector<int64_t> shapeNMB({numMelBinsInt});
SmallVector<int64_t> shapeNMBp2({numMelBinsInt + 2});
SmallVector<int64_t> shape1xNMB({1, numMelBinsInt});
SmallVector<int64_t> shapeNSB({numSpectrogramBinsInt});
SmallVector<int64_t> shapeNSBxNMB(
{numSpectrogramBinsInt, numMelBinsInt});
// Recurring DTypes
Type inpFpDType = inputFloatType.getDtype();
Type inpIntDType = inputIntType.getDtype();
Type si32Ty = rewriter.getIntegerType(32, true);
Type f32Ty = rewriter.getF32Type();
Type i1Ty = rewriter.getI1Type();
// Value constants
Value noneConst = b.create<Torch::ConstantNoneOp>();
Value negTwoConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(-2));
Value negOneConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(-1));
Value zeroConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(0));
Value oneConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(1));
Value twoConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(2));
Value float32DTypeConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(6));
Torch::ValueTensorType dftLenType =
Torch::ValueTensorType::get(ctx, unranked, inpIntDType);
Type freqBinsIntType =
Torch::ValueTensorType::get(ctx, shapeNMBp2, si32Ty);
Type freqBinsFltType =
Torch::ValueTensorType::get(ctx, shapeNMBp2, f32Ty);
Value dftLengthDivTwoFlt =
b.create<Torch::AtenDivIntOp>(dftLengthItem, twoConst);
Value dftLengthDivTwo =
b.create<Torch::AtenIntFloatOp>(dftLengthDivTwoFlt);
Value numSpectrogramBins =
b.create<Torch::AtenAddIntOp>(dftLengthDivTwo, oneConst);
Value numSpectrogramBinsItem = numSpectrogramBins;
Value freqBinsInit = b.create<Torch::AtenArangeOp>(
freqBinsIntType, numMelBinsItem, /*dtype=*/float32DTypeConst,
/*layout=*/noneConst, /*device=*/noneConst,
/*pin_memory=*/noneConst);
// From Ref Impl of Onnx.MelWeightMatrix:
// https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_mel_weight_matrix.py#L25-L32
// convert input Freq Hz to Mel
Value twoFiveNineFiveConst =
b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(2595));
Value sevenHConst =
b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(700));
Value tenConst =
b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(10));
Value lfDiv7Hfloat =
b.create<Torch::AtenDivFloatOp>(lowerEdgeHzItem, sevenHConst);
Type freqType = Torch::ValueTensorType::get(ctx, unranked, inpFpDType);
Value lfDiv7H =
b.create<Torch::PrimNumToTensorScalarOp>(freqType, lfDiv7Hfloat);
Value lfDiv7HAdd1 = b.create<Torch::AtenAddScalarOp>(
freqType, lfDiv7H, oneConst, /*alpha =*/oneConst);
Value lfDiv7HAdd1Log10 =
b.create<Torch::AtenLog10Op>(freqType, lfDiv7HAdd1);
Value lfMel = b.create<Torch::AtenMulScalarOp>(
freqType, lfDiv7HAdd1Log10, twoFiveNineFiveConst);
Value hfDiv7Hfloat =
b.create<Torch::AtenDivFloatOp>(upperEdgeHzItem, sevenHConst);
Value hfDiv7H =
b.create<Torch::PrimNumToTensorScalarOp>(freqType, hfDiv7Hfloat);
Value hfDiv7HAdd1 = b.create<Torch::AtenAddScalarOp>(
freqType, hfDiv7H, oneConst, /*alpha =*/oneConst);
Value hfDiv7HAdd1Log10 =
b.create<Torch::AtenLog10Op>(freqType, hfDiv7HAdd1);
Value hfMel = b.create<Torch::AtenMulScalarOp>(
freqType, hfDiv7HAdd1Log10, twoFiveNineFiveConst);
Value hfSubLf = b.create<Torch::AtenSubTensorOp>(
hfMel.getType(), hfMel, lfMel, /*alpha=*/oneConst);
Value melStep = b.create<Torch::AtenDivScalarOp>(
hfSubLf.getType(), hfSubLf, numMelBinsItem);
Value freqBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
freqBinsFltType, freqBinsInit, melStep);
Value freqBinsScaled = b.create<Torch::AtenAddTensorOp>(
freqBinsFltType, freqBinsMulMelStep, lfMel, /*alpha=*/oneConst);
// Mel to Hz conv
Value fbDiv = b.create<Torch::AtenDivScalarOp>(
freqBinsFltType, freqBinsScaled, twoFiveNineFiveConst);
Value fbClone = b.create<Torch::AtenCloneOp>(
freqBinsFltType, freqBinsScaled, /*memory_format=*/noneConst);
Value tenTensor = b.create<Torch::AtenFillScalarOp>(freqBinsFltType,
fbClone, tenConst);
Value fbPow = b.create<Torch::AtenPowTensorTensorOp>(freqBinsFltType,
tenTensor, fbDiv);
Value fbPowSubOne = b.create<Torch::AtenSubScalarOp>(
freqBinsFltType, fbPow, oneConst, /*alpha=*/oneConst);
Value freqBinsHz = b.create<Torch::AtenMulScalarOp>(
freqBinsFltType, fbPowSubOne, sevenHConst);
// Normalize freqBinsHz
Value dftLenPlusOne = b.create<Torch::AtenAddScalarOp>(
dftLenType, operands[1], oneConst, /*alpha=*/oneConst);
Value dftLenPlusOneItem =
getItemOp<Torch::IntType>(binder, rewriter, dftLenPlusOne);
Value fbMulDft = b.create<Torch::AtenMulScalarOp>(
freqBinsFltType, freqBinsHz, dftLenPlusOneItem);
Value freqBinsNormalized = b.create<Torch::AtenDivScalarOp>(
freqBinsFltType, fbMulDft, sampleRateItem);
// cast to int32
Value int32DTypeConst =
b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(3));
Value falseConst = b.create<Torch::ConstantBoolOp>(false);
Value freqBins = b.create<Torch::AtenToDtypeOp>(
freqBinsIntType, freqBinsNormalized, /*dtype=*/int32DTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
Torch::ValueTensorType sliceResType =
Torch::ValueTensorType::get(ctx, shapeNMB, si32Ty);
Type unsqueezeResType =
sliceResType.getWithSizesAndDtype(shape1xNMB, si32Ty);
Value lfTensor = b.create<Torch::AtenSliceTensorOp>(
sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/zeroConst,
/*end=*/negTwoConst, /*step=*/oneConst);
Value lowFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
unsqueezeResType, lfTensor, /*dim=*/zeroConst);
Value cfTensor = b.create<Torch::AtenSliceTensorOp>(
sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/oneConst,
/*end=*/negOneConst, /*step=*/oneConst);
Value centerFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
unsqueezeResType, cfTensor, /*dim=*/zeroConst);
Value hfTensor = b.create<Torch::AtenSliceTensorOp>(
sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/zeroConst,
/*end=*/noneConst, /*step=*/oneConst);
Value highFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
unsqueezeResType, hfTensor, /*dim=*/zeroConst);
Value lowToCenter =
b.create<Torch::AtenSubTensorOp>(unsqueezeResType, centerFreqTensor,
lowFreqTensor, /*alpha=*/oneConst);
Value centerToHigh = b.create<Torch::AtenSubTensorOp>(
unsqueezeResType, highFreqTensor, centerFreqTensor,
/*alpha=*/oneConst);
Type zeroToNInitType =
inputIntType.getWithSizesAndDtype(shapeNSB, f32Ty);
Value zeroToNInit = b.create<Torch::AtenArangeOp>(
zeroToNInitType, numSpectrogramBinsItem,
/*dtype=*/float32DTypeConst,
/*layout=*/noneConst, /*device=*/noneConst,
/*pin_memory=*/noneConst);
Type zeroToNBaseType = inputIntType.getWithSizesAndDtype(
ArrayRef<int64_t>{numSpectrogramBinsInt, 1}, f32Ty);
Value zeroToNBase = b.create<Torch::AtenUnsqueezeOp>(
zeroToNBaseType, zeroToNInit, /*dim=*/oneConst);
Type zeroToNumElesType =
inputIntType.getWithSizesAndDtype(shapeNSBxNMB, f32Ty);
Value expandShapeList = b.create<Torch::PrimListConstructOp>(
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
SmallVector<Value>{numSpectrogramBinsItem, numMelBinsItem});
Value zeroToNumEles = b.create<Torch::AtenExpandOp>(
zeroToNumElesType, zeroToNBase, expandShapeList,
/*implicit=*/falseConst);
Type maskType = inputIntType.getWithSizesAndDtype(shape1xNMB, i1Ty);
Value maskLowToCenterZero =
b.create<Torch::AtenEqScalarOp>(maskType, lowToCenter, zeroConst);
// L2C computation
Value lowToCenterNoZero = b.create<Torch::AtenWhereScalarSelfOp>(
unsqueezeResType, maskLowToCenterZero, negOneConst, lowToCenter);
Type maskL2CAfterCType =
inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
Value maskL2CAfterC = b.create<Torch::AtenGtTensorOp>(
maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
Type maxLFResTy =
inputIntType.getWithSizesAndDtype(ArrayRef<int64_t>{1}, si32Ty);
Value maxLowerFreq =
b.create<Torch::AtenMaxOp>(maxLFResTy, lowFreqTensor);
Value maxLowerFreqItem =
getItemOp<Torch::IntType>(binder, rewriter, maxLowerFreq);
Value zeroToNumElesL2C = b.create<Torch::AtenWhereScalarSelfOp>(
zeroToNumElesType, maskLowToCenterZero, maxLowerFreqItem,
zeroToNumEles);
Value upslopeDiff = b.create<Torch::AtenSubTensorOp>(
zeroToNumElesType, zeroToNumElesL2C, lowFreqTensor,
/*alpha=*/oneConst);
Type l2cNZFltTy = inputIntType.getWithSizesAndDtype(shape1xNMB, f32Ty);
Value l2cNZFlt = b.create<Torch::AtenToDtypeOp>(
l2cNZFltTy, lowToCenterNoZero, /*dtype=*/float32DTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
Value upslopeL2C0 = b.create<Torch::AtenDivTensorOp>(
zeroToNumElesType, upslopeDiff, l2cNZFlt);
Type maskUpslopeL2C0PosType =
inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
Value maskUpslopeL2C0Pos = b.create<Torch::AtenGtScalarOp>(
maskUpslopeL2C0PosType, upslopeL2C0, zeroConst);
Value upslopeL2C0PosRanged = b.create<Torch::AtenWhereScalarOtherOp>(
zeroToNumElesType, maskUpslopeL2C0Pos, upslopeL2C0, zeroConst);
Value maskIdxL2CAfterCList = b.create<Torch::PrimListConstructOp>(
rewriter.getType<Torch::ListType>(maskL2CAfterC.getType()),
ValueRange{maskL2CAfterC});
Value zeroConstTensor = Torch::createRank0Tensor(
rewriter, binder.getLoc(),
Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), zeroConst);
Value upslopeL2C1 = b.create<Torch::AtenIndexPutOp>(
zeroToNumElesType, upslopeL2C0PosRanged, maskIdxL2CAfterCList,
zeroConstTensor, falseConst);
Value maskIdxL2CZeroList = b.create<Torch::PrimListConstructOp>(
rewriter.getType<Torch::ListType>(maskLowToCenterZero.getType()),
ValueRange{maskLowToCenterZero});
Type centerFreqTensorL2CZeroType =
inputIntType.getWithSizesAndDtype(ArrayRef<int64_t>{-1}, si32Ty);
Value centerFreqTensorL2CZero = b.create<Torch::AtenIndexTensorOp>(
centerFreqTensorL2CZeroType, centerFreqTensor, maskIdxL2CZeroList);
Type maskSqueezeType =
inputIntType.getWithSizesAndDtype(shapeNMB, i1Ty);
Value maskLowToCenterZeroSqueeze = b.create<Torch::AtenSqueezeOp>(
maskSqueezeType, maskLowToCenterZero);
Type maskL2CIntTy = inputIntType.getWithSizesAndDtype(shapeNMB, si32Ty);
Value maskLowToCenterInt = b.create<Torch::AtenToDtypeOp>(
maskL2CIntTy, maskLowToCenterZeroSqueeze, /*dtype=*/int32DTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
Value upslopeOneIdxList = b.create<Torch::PrimListConstructOp>(
rewriter.getType<Torch::ListType>(
centerFreqTensorL2CZero.getType()),
ValueRange{centerFreqTensorL2CZero, maskLowToCenterInt});
Value oneConstTensor = Torch::createRank0Tensor(
rewriter, binder.getLoc(),
Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), oneConst);
Value upslopeL2C = b.create<Torch::AtenIndexPutOp>(
zeroToNumElesType, upslopeL2C1, upslopeOneIdxList, oneConstTensor,
falseConst);
// H2C computation
Value maskCenterToHighZero =
b.create<Torch::AtenEqScalarOp>(maskType, centerToHigh, zeroConst);
Value maskH2CBeforeC = b.create<Torch::AtenLtTensorOp>(
maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
Value centerToHighNoZero = b.create<Torch::AtenWhereScalarSelfOp>(
unsqueezeResType, maskCenterToHighZero, negOneConst, centerToHigh);
Value c2hNZFlt = b.create<Torch::AtenToDtypeOp>(
l2cNZFltTy, centerToHighNoZero, /*dtype=*/float32DTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
Value zeroToNumElesC2H = b.create<Torch::AtenWhereScalarSelfOp>(
zeroToNumElesType, maskCenterToHighZero, zeroConst, zeroToNumEles);
Value downslopeDiff = b.create<Torch::AtenSubTensorOp>(
zeroToNumElesType, highFreqTensor, zeroToNumElesC2H,
/*alpha=*/oneConst);
Value downslopeC2H0 = b.create<Torch::AtenDivTensorOp>(
zeroToNumElesType, downslopeDiff, c2hNZFlt);
Value maskDownslopeC2H0Pos = b.create<Torch::AtenGtScalarOp>(
maskUpslopeL2C0PosType, downslopeC2H0, zeroConst);
Value downslopeC2H0Pos = b.create<Torch::AtenWhereScalarOtherOp>(
zeroToNumElesType, maskDownslopeC2H0Pos, downslopeC2H0, zeroConst);
Value idxH2CBeforeCList = b.create<Torch::PrimListConstructOp>(
rewriter.getType<Torch::ListType>(maskH2CBeforeC.getType()),
ValueRange{maskH2CBeforeC});
Value downslopeC2H = b.create<Torch::AtenIndexPutOp>(
zeroToNumElesType, downslopeC2H0Pos, idxH2CBeforeCList,
zeroConstTensor, falseConst);
// final result Calculation
Value maskH2CNonZero = b.create<Torch::AtenNeScalarOp>(
maskL2CAfterCType, downslopeC2H, zeroConst);
Value idxH2CNZList = b.create<Torch::PrimListConstructOp>(
rewriter.getType<Torch::ListType>(maskH2CNonZero.getType()),
ValueRange{maskH2CNonZero});
Value upslopeL2CMasked = b.create<Torch::AtenIndexPutOp>(
zeroToNumElesType, upslopeL2C, idxH2CNZList, zeroConstTensor,
falseConst);
Value slopesFinal = b.create<Torch::AtenAddTensorOp>(
zeroToNumElesType, upslopeL2CMasked, downslopeC2H,
/*alpha=*/oneConst);
Value outputDTypeConst = b.create<Torch::ConstantIntOp>(
rewriter.getType<Torch::IntType>(),
rewriter.getI64IntegerAttr(torchDTypeInt.value()));
Value finalOutput = b.create<Torch::AtenToDtypeOp>(
resultType, slopesFinal, /*dtype=*/outputDTypeConst,
/*non_blocking=*/falseConst, /*copy=*/falseConst,
/*memory_format=*/noneConst);
rewriter.replaceOp(binder.op, finalOutput);
return success();
});
patterns.onOp(
"Multinomial", 7,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {

View File

@ -1918,3 +1918,119 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
%0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
return %0 : !torch.vtensor<[1,3],si64>
}
// -----
// CHECK-LABEL: func.func @test_mwm
func.func @test_mwm(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>, %arg3: !torch.vtensor<[],f32>, %arg4: !torch.vtensor<[],f32>) -> !torch.vtensor<[9,8],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "test_mwm", torch.onnx_meta.producer_version = ""} {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>, %[[VAL_1:.*]]: !torch.vtensor<[],si64>, %[[VAL_2:.*]]: !torch.vtensor<[],si64>,
// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>,
// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[],f32>
// CHECK: %[[VAL_5:.*]] = torch.constant.none
// CHECK: %[[VAL_6:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_7:.*]] = torch.aten.item %[[VAL_1]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_8:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_9:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[VAL_10:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[VAL_11:.*]] = torch.constant.none
// CHECK: %[[VAL_12:.*]] = torch.constant.int -2
// CHECK: %[[VAL_13:.*]] = torch.constant.int -1
// CHECK: %[[VAL_14:.*]] = torch.constant.int 0
// CHECK: %[[VAL_15:.*]] = torch.constant.int 1
// CHECK: %[[VAL_16:.*]] = torch.constant.int 2
// CHECK: %[[VAL_17:.*]] = torch.constant.int 6
// CHECK: %[[VAL_18:.*]] = torch.aten.div.int %[[VAL_7]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.float
// CHECK: %[[VAL_19:.*]] = torch.aten.Int.float %[[VAL_18]] : !torch.float -> !torch.int
// CHECK: %[[VAL_20:.*]] = torch.aten.add.int %[[VAL_19]], %[[VAL_15]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAL_21:.*]] = torch.aten.arange %[[VAL_6]], %[[VAL_17]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],si32>
// CHECK: %[[VAL_22:.*]] = torch.constant.float 2.595000e+03
// CHECK: %[[VAL_23:.*]] = torch.constant.float 7.000000e+02
// CHECK: %[[VAL_24:.*]] = torch.constant.float 1.000000e+01
// CHECK: %[[VAL_25:.*]] = torch.aten.div.float %[[VAL_9]], %[[VAL_23]] : !torch.float, !torch.float -> !torch.float
// CHECK: %[[VAL_26:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_25]] : !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_27:.*]] = torch.aten.add.Scalar %[[VAL_26]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_28:.*]] = torch.aten.log10 %[[VAL_27]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_29:.*]] = torch.aten.mul.Scalar %[[VAL_28]], %[[VAL_22]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_30:.*]] = torch.aten.div.float %[[VAL_10]], %[[VAL_23]] : !torch.float, !torch.float -> !torch.float
// CHECK: %[[VAL_31:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_30]] : !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_32:.*]] = torch.aten.add.Scalar %[[VAL_31]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_33:.*]] = torch.aten.log10 %[[VAL_32]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_34:.*]] = torch.aten.mul.Scalar %[[VAL_33]], %[[VAL_22]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_35:.*]] = torch.aten.sub.Tensor %[[VAL_34]], %[[VAL_29]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_36:.*]] = torch.aten.div.Scalar %[[VAL_35]], %[[VAL_6]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_37:.*]] = torch.aten.mul.Tensor %[[VAL_21]], %[[VAL_36]] : !torch.vtensor<[10],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_38:.*]] = torch.aten.add.Tensor %[[VAL_37]], %[[VAL_29]], %[[VAL_15]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_39:.*]] = torch.aten.div.Scalar %[[VAL_38]], %[[VAL_22]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_40:.*]] = torch.aten.clone %[[VAL_38]], %[[VAL_11]] : !torch.vtensor<[10],f32>, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_41:.*]] = torch.aten.fill.Scalar %[[VAL_40]], %[[VAL_24]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_42:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_41]], %[[VAL_39]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_43:.*]] = torch.aten.sub.Scalar %[[VAL_42]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[10],f32>, !torch.int, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_44:.*]] = torch.aten.mul.Scalar %[[VAL_43]], %[[VAL_23]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_45:.*]] = torch.aten.add.Scalar %[[VAL_1]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[VAL_46:.*]] = torch.aten.item %[[VAL_45]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_47:.*]] = torch.aten.mul.Scalar %[[VAL_44]], %[[VAL_46]] : !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_48:.*]] = torch.aten.div.Scalar %[[VAL_47]], %[[VAL_8]] : !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_49:.*]] = torch.constant.int 3
// CHECK: %[[VAL_50:.*]] = torch.constant.bool false
// CHECK: %[[VAL_51:.*]] = torch.aten.to.dtype %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],si32>
// CHECK: %[[VAL_52:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_14]], %[[VAL_12]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_53:.*]] = torch.aten.unsqueeze %[[VAL_52]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_54:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_15]], %[[VAL_13]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_55:.*]] = torch.aten.unsqueeze %[[VAL_54]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_56:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_14]], %[[VAL_11]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_57:.*]] = torch.aten.unsqueeze %[[VAL_56]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_58:.*]] = torch.aten.sub.Tensor %[[VAL_55]], %[[VAL_53]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_59:.*]] = torch.aten.sub.Tensor %[[VAL_57]], %[[VAL_55]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_60:.*]] = torch.aten.arange %[[VAL_20]], %[[VAL_17]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[9],f32>
// CHECK: %[[VAL_61:.*]] = torch.aten.unsqueeze %[[VAL_60]], %[[VAL_15]] : !torch.vtensor<[9],f32>, !torch.int -> !torch.vtensor<[9,1],f32>
// CHECK: %[[VAL_62:.*]] = torch.prim.ListConstruct %[[VAL_20]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_63:.*]] = torch.aten.expand %[[VAL_61]], %[[VAL_62]], %[[VAL_50]] : !torch.vtensor<[9,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_64:.*]] = torch.aten.eq.Scalar %[[VAL_58]], %[[VAL_14]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1>
// CHECK: %[[VAL_65:.*]] = torch.aten.where.ScalarSelf %[[VAL_64]], %[[VAL_13]], %[[VAL_58]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_66:.*]] = torch.aten.gt.Tensor %[[VAL_63]], %[[VAL_55]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_67:.*]] = torch.aten.max %[[VAL_53]] : !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1],si32>
// CHECK: %[[VAL_68:.*]] = torch.aten.item %[[VAL_67]] : !torch.vtensor<[1],si32> -> !torch.int
// CHECK: %[[VAL_69:.*]] = torch.aten.where.ScalarSelf %[[VAL_64]], %[[VAL_68]], %[[VAL_63]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_70:.*]] = torch.aten.sub.Tensor %[[VAL_69]], %[[VAL_53]], %[[VAL_15]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_71:.*]] = torch.aten.to.dtype %[[VAL_65]], %[[VAL_17]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
// CHECK: %[[VAL_72:.*]] = torch.aten.div.Tensor %[[VAL_70]], %[[VAL_71]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_73:.*]] = torch.aten.gt.Scalar %[[VAL_72]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_74:.*]] = torch.aten.where.ScalarOther %[[VAL_73]], %[[VAL_72]], %[[VAL_14]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_75:.*]] = torch.prim.ListConstruct %[[VAL_66]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<vtensor<[9,8],i1>>
// CHECK: %[[VAL_76:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[VAL_77:.*]] = torch.constant.none
// CHECK: %[[VAL_78:.*]] = torch.constant.int 6
// CHECK: %[[VAL_79:.*]] = torch.aten.full %[[VAL_76]], %[[VAL_14]], %[[VAL_78]], %[[VAL_77]], %[[VAL_77]], %[[VAL_77]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_80:.*]] = torch.aten.index_put %[[VAL_74]], %[[VAL_75]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[9,8],i1>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_81:.*]] = torch.prim.ListConstruct %[[VAL_64]] : (!torch.vtensor<[1,8],i1>) -> !torch.list<vtensor<[1,8],i1>>
// CHECK: %[[VAL_82:.*]] = torch.aten.index.Tensor %[[VAL_55]], %[[VAL_81]] : !torch.vtensor<[1,8],si32>, !torch.list<vtensor<[1,8],i1>> -> !torch.vtensor<[?],si32>
// CHECK: %[[VAL_83:.*]] = torch.aten.squeeze %[[VAL_64]] : !torch.vtensor<[1,8],i1> -> !torch.vtensor<[8],i1>
// CHECK: %[[VAL_84:.*]] = torch.aten.to.dtype %[[VAL_83]], %[[VAL_49]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[8],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_85:.*]] = torch.prim.ListConstruct %[[VAL_82]], %[[VAL_84]] : (!torch.vtensor<[?],si32>, !torch.vtensor<[8],si32>) -> !torch.list<vtensor<[?],si32>>
// CHECK: %[[VAL_86:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[VAL_87:.*]] = torch.constant.none
// CHECK: %[[VAL_88:.*]] = torch.constant.int 6
// CHECK: %[[VAL_89:.*]] = torch.aten.full %[[VAL_86]], %[[VAL_15]], %[[VAL_88]], %[[VAL_87]], %[[VAL_87]], %[[VAL_87]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_90:.*]] = torch.aten.index_put %[[VAL_80]], %[[VAL_85]], %[[VAL_89]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[?],si32>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_91:.*]] = torch.aten.eq.Scalar %[[VAL_59]], %[[VAL_14]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1>
// CHECK: %[[VAL_92:.*]] = torch.aten.lt.Tensor %[[VAL_63]], %[[VAL_55]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_93:.*]] = torch.aten.where.ScalarSelf %[[VAL_91]], %[[VAL_13]], %[[VAL_59]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_94:.*]] = torch.aten.to.dtype %[[VAL_93]], %[[VAL_17]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
// CHECK: %[[VAL_95:.*]] = torch.aten.where.ScalarSelf %[[VAL_91]], %[[VAL_14]], %[[VAL_63]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_96:.*]] = torch.aten.sub.Tensor %[[VAL_57]], %[[VAL_95]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_97:.*]] = torch.aten.div.Tensor %[[VAL_96]], %[[VAL_94]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_98:.*]] = torch.aten.gt.Scalar %[[VAL_97]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_99:.*]] = torch.aten.where.ScalarOther %[[VAL_98]], %[[VAL_97]], %[[VAL_14]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_100:.*]] = torch.prim.ListConstruct %[[VAL_92]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<vtensor<[9,8],i1>>
// CHECK: %[[VAL_101:.*]] = torch.aten.index_put %[[VAL_99]], %[[VAL_100]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[9,8],i1>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_102:.*]] = torch.aten.ne.Scalar %[[VAL_101]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_103:.*]] = torch.prim.ListConstruct %[[VAL_102]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<vtensor<[9,8],i1>>
// CHECK: %[[VAL_104:.*]] = torch.aten.index_put %[[VAL_90]], %[[VAL_103]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[9,8],i1>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_105:.*]] = torch.aten.add.Tensor %[[VAL_104]], %[[VAL_101]], %[[VAL_15]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_106:.*]] = torch.constant.int 6
// CHECK: %[[VAL_107:.*]] = torch.aten.to.dtype %[[VAL_105]], %[[VAL_106]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[9,8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
// CHECK: return %[[VAL_107]] : !torch.vtensor<[9,8],f32>
%none = torch.constant.none
%0 = torch.operator "onnx.MelWeightMatrix"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[9,8],f32>
return %0 : !torch.vtensor<[9,8],f32>
}