mirror of https://github.com/llvm/torch-mlir
onnx.MelWeightMatrix TorchOnnxToTorch (#3503)

Uploading what I have so far: a [Gist](https://gist.github.com/PhaneeshB/761f75f5522d9f4a40ef949a328e93fe) of the PyTorch implementation I'm following for the OnnxToTorch lowering.

Additional details (also pasted as a comment in the gist):
- [Op Description](https://github.com/onnx/onnx/blob/main/docs/Operators.md#melweightmatrix) in the ONNX documentation.
- [Example](https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-93): the same example is used in the test added in this file, and the expected output is shown there.
- [Reference ONNX Impl](4c3ed5e08b/onnx/reference/ops/op_mel_weight_matrix.py (L13)): this is the base for the above code.
pull/3629/head — commit 026dfade64 (parent 334633b738)
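For reference, the constants 2595, 700, and 10 that appear in the lowering below come from the standard HTK-style mel-scale conversion used by the ONNX reference implementation; the two relations, as a sketch:

\mathrm{mel}(f) = 2595 \,\log_{10}\!\left(1 + \frac{f}{700}\right),
\qquad
\mathrm{hz}(m) = 700\,\left(10^{\,m/2595} - 1\right)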
@@ -591,6 +591,373 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
            binder.op, resultType, lhs, rhs);
        return success();
      });

  patterns.onOp(
      "MelWeightMatrix", 17,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        llvm::SmallVector<Value> operands;
        Torch::ValueTensorType resultType;
        int64_t output_dtype_attr;
        if (binder.tensorOperands(operands, 5) ||
            binder.tensorResultType(resultType) || operands.size() != 5 ||
            binder.s64IntegerAttr(output_dtype_attr, "output_datatype", 1)) {
          return failure();
        }
        // operands sequence :
        // num_mel_bins, dft_length, sample_rate -> int32/64 tensors
        // lower_edge_hertz, upper_edge_hertz -> f16/32/64

        // Need to backtrack the values of num_mel_bins and dft_length//2+1 from
        // result shape since the inputs are tensors and we cannot know their
        // values at compile time. if the result type does not contain static
        // shapes, then the implementation will be unsupported.
        if (!resultType.areAllSizesKnown())
          return rewriter.notifyMatchFailure(
              binder.op, "Unknown result sizes, not supported.");

        ArrayRef<int64_t> resShape = resultType.getSizes();
        if (resShape.size() != 2)
          return rewriter.notifyMatchFailure(
              binder.op,
              "Expected result rank to be 2, not supported for other ranks.");

        std::optional<int64_t> torchDTypeInt =
            onnxDtypeIntToTorchDtypeInt(output_dtype_attr);
        if (!torchDTypeInt.has_value()) {
          return rewriter.notifyMatchFailure(
              binder.op, "conversion to given output dtype unsupported");
        }

        // Here Onwards all shapes will be computed using these sizes
        int64_t numSpectrogramBinsInt = resShape[0];
        int64_t numMelBinsInt = resShape[1];
        Torch::ValueTensorType inputIntType = binder.toValidTensorType(
            operands[0].getType()); // Since operands[0 / 1 / 2] will have the
                                    // same int type.
        Torch::ValueTensorType inputFloatType = binder.toValidTensorType(
            operands[3].getType()); // Since operands[3 / 4] will have the same
                                    // float type

        Value numMelBinsItem =
            getItemOp<Torch::IntType>(binder, rewriter, operands[0]);
        Value dftLengthItem =
            getItemOp<Torch::IntType>(binder, rewriter, operands[1]);
        Value sampleRateItem =
            getItemOp<Torch::IntType>(binder, rewriter, operands[2]);
        Value lowerEdgeHzItem =
            getItemOp<Torch::FloatType>(binder, rewriter, operands[3]);
        Value upperEdgeHzItem =
            getItemOp<Torch::FloatType>(binder, rewriter, operands[4]);

        // Helpers
        ImplicitLocOpBuilder b(binder.getLoc(), rewriter);
        auto ctx = binder.op->getContext();

        // Recurring shapes
        SmallVector<int64_t> unranked({});
        SmallVector<int64_t> shapeNMB({numMelBinsInt});
        SmallVector<int64_t> shapeNMBp2({numMelBinsInt + 2});
        SmallVector<int64_t> shape1xNMB({1, numMelBinsInt});
        SmallVector<int64_t> shapeNSB({numSpectrogramBinsInt});
        SmallVector<int64_t> shapeNSBxNMB(
            {numSpectrogramBinsInt, numMelBinsInt});

        // Recurring DTypes
        Type inpFpDType = inputFloatType.getDtype();
        Type inpIntDType = inputIntType.getDtype();
        Type si32Ty = rewriter.getIntegerType(32, true);
        Type f32Ty = rewriter.getF32Type();
        Type i1Ty = rewriter.getI1Type();

        // Value constants
        Value noneConst = b.create<Torch::ConstantNoneOp>();
        Value negTwoConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(-2));
        Value negOneConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(-1));
        Value zeroConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(0));
        Value oneConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(1));
        Value twoConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(2));
        Value float32DTypeConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(6));

        Torch::ValueTensorType dftLenType =
            Torch::ValueTensorType::get(ctx, unranked, inpIntDType);
        Type freqBinsIntType =
            Torch::ValueTensorType::get(ctx, shapeNMBp2, si32Ty);
        Type freqBinsFltType =
            Torch::ValueTensorType::get(ctx, shapeNMBp2, f32Ty);

        Value dftLengthDivTwoFlt =
            b.create<Torch::AtenDivIntOp>(dftLengthItem, twoConst);
        Value dftLengthDivTwo =
            b.create<Torch::AtenIntFloatOp>(dftLengthDivTwoFlt);
        Value numSpectrogramBins =
            b.create<Torch::AtenAddIntOp>(dftLengthDivTwo, oneConst);
        Value numSpectrogramBinsItem = numSpectrogramBins;
        Value freqBinsInit = b.create<Torch::AtenArangeOp>(
            freqBinsIntType, numMelBinsItem, /*dtype=*/float32DTypeConst,
            /*layout=*/noneConst, /*device=*/noneConst,
            /*pin_memory=*/noneConst);

        // From Ref Impl of Onnx.MelWeightMatrix:
        // https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_mel_weight_matrix.py#L25-L32
        // convert input Freq Hz to Mel
        Value twoFiveNineFiveConst =
            b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(2595));
        Value sevenHConst =
            b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(700));
        Value tenConst =
            b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(10));

        Value lfDiv7Hfloat =
            b.create<Torch::AtenDivFloatOp>(lowerEdgeHzItem, sevenHConst);
        Type freqType = Torch::ValueTensorType::get(ctx, unranked, inpFpDType);
        Value lfDiv7H =
            b.create<Torch::PrimNumToTensorScalarOp>(freqType, lfDiv7Hfloat);
        Value lfDiv7HAdd1 = b.create<Torch::AtenAddScalarOp>(
            freqType, lfDiv7H, oneConst, /*alpha =*/oneConst);
        Value lfDiv7HAdd1Log10 =
            b.create<Torch::AtenLog10Op>(freqType, lfDiv7HAdd1);
        Value lfMel = b.create<Torch::AtenMulScalarOp>(
            freqType, lfDiv7HAdd1Log10, twoFiveNineFiveConst);

        Value hfDiv7Hfloat =
            b.create<Torch::AtenDivFloatOp>(upperEdgeHzItem, sevenHConst);
        Value hfDiv7H =
            b.create<Torch::PrimNumToTensorScalarOp>(freqType, hfDiv7Hfloat);
        Value hfDiv7HAdd1 = b.create<Torch::AtenAddScalarOp>(
            freqType, hfDiv7H, oneConst, /*alpha =*/oneConst);
        Value hfDiv7HAdd1Log10 =
            b.create<Torch::AtenLog10Op>(freqType, hfDiv7HAdd1);
        Value hfMel = b.create<Torch::AtenMulScalarOp>(
            freqType, hfDiv7HAdd1Log10, twoFiveNineFiveConst);

        Value hfSubLf = b.create<Torch::AtenSubTensorOp>(
            hfMel.getType(), hfMel, lfMel, /*alpha=*/oneConst);
        Value melStep = b.create<Torch::AtenDivScalarOp>(
            hfSubLf.getType(), hfSubLf, numMelBinsItem);

        Value freqBinsMulMelStep = b.create<Torch::AtenMulTensorOp>(
            freqBinsFltType, freqBinsInit, melStep);
        Value freqBinsScaled = b.create<Torch::AtenAddTensorOp>(
            freqBinsFltType, freqBinsMulMelStep, lfMel, /*alpha=*/oneConst);
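        // freqBinsScaled now holds evenly spaced points on the mel scale
        // between the lower and upper edge frequencies:
        //   bin[i] = low_mel + i * (high_mel - low_mel) / num_mel_bins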

        // Mel to Hz conv

        Value fbDiv = b.create<Torch::AtenDivScalarOp>(
            freqBinsFltType, freqBinsScaled, twoFiveNineFiveConst);
        Value fbClone = b.create<Torch::AtenCloneOp>(
            freqBinsFltType, freqBinsScaled, /*memory_format=*/noneConst);
        Value tenTensor = b.create<Torch::AtenFillScalarOp>(freqBinsFltType,
                                                            fbClone, tenConst);
        Value fbPow = b.create<Torch::AtenPowTensorTensorOp>(freqBinsFltType,
                                                             tenTensor, fbDiv);
        Value fbPowSubOne = b.create<Torch::AtenSubScalarOp>(
            freqBinsFltType, fbPow, oneConst, /*alpha=*/oneConst);
        Value freqBinsHz = b.create<Torch::AtenMulScalarOp>(
            freqBinsFltType, fbPowSubOne, sevenHConst);

        // Normalize freqBinsHz
        Value dftLenPlusOne = b.create<Torch::AtenAddScalarOp>(
            dftLenType, operands[1], oneConst, /*alpha=*/oneConst);
        Value dftLenPlusOneItem =
            getItemOp<Torch::IntType>(binder, rewriter, dftLenPlusOne);
        Value fbMulDft = b.create<Torch::AtenMulScalarOp>(
            freqBinsFltType, freqBinsHz, dftLenPlusOneItem);
        Value freqBinsNormalized = b.create<Torch::AtenDivScalarOp>(
            freqBinsFltType, fbMulDft, sampleRateItem);

        // cast to int32
        Value int32DTypeConst =
            b.create<Torch::ConstantIntOp>(rewriter.getI64IntegerAttr(3));
        Value falseConst = b.create<Torch::ConstantBoolOp>(false);
        Value freqBins = b.create<Torch::AtenToDtypeOp>(
            freqBinsIntType, freqBinsNormalized, /*dtype=*/int32DTypeConst,
            /*non_blocking=*/falseConst, /*copy=*/falseConst,
            /*memory_format=*/noneConst);
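        // freqBins maps each mel point back to an FFT-bin index:
        //   bin = int32(hz * (dft_length + 1) / sample_rate)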

        Torch::ValueTensorType sliceResType =
            Torch::ValueTensorType::get(ctx, shapeNMB, si32Ty);
        Type unsqueezeResType =
            sliceResType.getWithSizesAndDtype(shape1xNMB, si32Ty);
        Value lfTensor = b.create<Torch::AtenSliceTensorOp>(
            sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/zeroConst,
            /*end=*/negTwoConst, /*step=*/oneConst);
        Value lowFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
            unsqueezeResType, lfTensor, /*dim=*/zeroConst);

        Value cfTensor = b.create<Torch::AtenSliceTensorOp>(
            sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/oneConst,
            /*end=*/negOneConst, /*step=*/oneConst);
        Value centerFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
            unsqueezeResType, cfTensor, /*dim=*/zeroConst);

        Value hfTensor = b.create<Torch::AtenSliceTensorOp>(
            sliceResType, freqBins, /*dim=*/zeroConst, /*start=*/zeroConst,
            /*end=*/noneConst, /*step=*/oneConst);
        Value highFreqTensor = b.create<Torch::AtenUnsqueezeOp>(
            unsqueezeResType, hfTensor, /*dim=*/zeroConst);

        Value lowToCenter =
            b.create<Torch::AtenSubTensorOp>(unsqueezeResType, centerFreqTensor,
                                             lowFreqTensor, /*alpha=*/oneConst);
        Value centerToHigh = b.create<Torch::AtenSubTensorOp>(
            unsqueezeResType, highFreqTensor, centerFreqTensor,
            /*alpha=*/oneConst);

        Type zeroToNInitType =
            inputIntType.getWithSizesAndDtype(shapeNSB, f32Ty);
        Value zeroToNInit = b.create<Torch::AtenArangeOp>(
            zeroToNInitType, numSpectrogramBinsItem,
            /*dtype=*/float32DTypeConst,
            /*layout=*/noneConst, /*device=*/noneConst,
            /*pin_memory=*/noneConst);

        Type zeroToNBaseType = inputIntType.getWithSizesAndDtype(
            ArrayRef<int64_t>{numSpectrogramBinsInt, 1}, f32Ty);
        Value zeroToNBase = b.create<Torch::AtenUnsqueezeOp>(
            zeroToNBaseType, zeroToNInit, /*dim=*/oneConst);
        Type zeroToNumElesType =
            inputIntType.getWithSizesAndDtype(shapeNSBxNMB, f32Ty);
        Value expandShapeList = b.create<Torch::PrimListConstructOp>(
            rewriter.getType<Torch::ListType>(
                rewriter.getType<Torch::IntType>()),
            SmallVector<Value>{numSpectrogramBinsItem, numMelBinsItem});
        Value zeroToNumEles = b.create<Torch::AtenExpandOp>(
            zeroToNumElesType, zeroToNBase, expandShapeList,
            /*implicit=*/falseConst);

        Type maskType = inputIntType.getWithSizesAndDtype(shape1xNMB, i1Ty);
        Value maskLowToCenterZero =
            b.create<Torch::AtenEqScalarOp>(maskType, lowToCenter, zeroConst);

        // L2C computation
        Value lowToCenterNoZero = b.create<Torch::AtenWhereScalarSelfOp>(
            unsqueezeResType, maskLowToCenterZero, negOneConst, lowToCenter);
        Type maskL2CAfterCType =
            inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
        Value maskL2CAfterC = b.create<Torch::AtenGtTensorOp>(
            maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
        Type maxLFResTy =
            inputIntType.getWithSizesAndDtype(ArrayRef<int64_t>{1}, si32Ty);
        Value maxLowerFreq =
            b.create<Torch::AtenMaxOp>(maxLFResTy, lowFreqTensor);
        Value maxLowerFreqItem =
            getItemOp<Torch::IntType>(binder, rewriter, maxLowerFreq);
        Value zeroToNumElesL2C = b.create<Torch::AtenWhereScalarSelfOp>(
            zeroToNumElesType, maskLowToCenterZero, maxLowerFreqItem,
            zeroToNumEles);
        Value upslopeDiff = b.create<Torch::AtenSubTensorOp>(
            zeroToNumElesType, zeroToNumElesL2C, lowFreqTensor,
            /*alpha=*/oneConst);
        Type l2cNZFltTy = inputIntType.getWithSizesAndDtype(shape1xNMB, f32Ty);
        Value l2cNZFlt = b.create<Torch::AtenToDtypeOp>(
            l2cNZFltTy, lowToCenterNoZero, /*dtype=*/float32DTypeConst,
            /*non_blocking=*/falseConst, /*copy=*/falseConst,
            /*memory_format=*/noneConst);
        Value upslopeL2C0 = b.create<Torch::AtenDivTensorOp>(
            zeroToNumElesType, upslopeDiff, l2cNZFlt);
        Type maskUpslopeL2C0PosType =
            inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty);
        Value maskUpslopeL2C0Pos = b.create<Torch::AtenGtScalarOp>(
            maskUpslopeL2C0PosType, upslopeL2C0, zeroConst);
        Value upslopeL2C0PosRanged = b.create<Torch::AtenWhereScalarOtherOp>(
            zeroToNumElesType, maskUpslopeL2C0Pos, upslopeL2C0, zeroConst);
        Value maskIdxL2CAfterCList = b.create<Torch::PrimListConstructOp>(
            rewriter.getType<Torch::ListType>(maskL2CAfterC.getType()),
            ValueRange{maskL2CAfterC});
        Value zeroConstTensor = Torch::createRank0Tensor(
            rewriter, binder.getLoc(),
            Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), zeroConst);
        Value upslopeL2C1 = b.create<Torch::AtenIndexPutOp>(
            zeroToNumElesType, upslopeL2C0PosRanged, maskIdxL2CAfterCList,
            zeroConstTensor, falseConst);
        Value maskIdxL2CZeroList = b.create<Torch::PrimListConstructOp>(
            rewriter.getType<Torch::ListType>(maskLowToCenterZero.getType()),
            ValueRange{maskLowToCenterZero});
        Type centerFreqTensorL2CZeroType =
            inputIntType.getWithSizesAndDtype(ArrayRef<int64_t>{-1}, si32Ty);
        Value centerFreqTensorL2CZero = b.create<Torch::AtenIndexTensorOp>(
            centerFreqTensorL2CZeroType, centerFreqTensor, maskIdxL2CZeroList);
        Type maskSqueezeType =
            inputIntType.getWithSizesAndDtype(shapeNMB, i1Ty);
        Value maskLowToCenterZeroSqueeze = b.create<Torch::AtenSqueezeOp>(
            maskSqueezeType, maskLowToCenterZero);
        Type maskL2CIntTy = inputIntType.getWithSizesAndDtype(shapeNMB, si32Ty);
        Value maskLowToCenterInt = b.create<Torch::AtenToDtypeOp>(
            maskL2CIntTy, maskLowToCenterZeroSqueeze, /*dtype=*/int32DTypeConst,
            /*non_blocking=*/falseConst, /*copy=*/falseConst,
            /*memory_format=*/noneConst);
        Value upslopeOneIdxList = b.create<Torch::PrimListConstructOp>(
            rewriter.getType<Torch::ListType>(
                centerFreqTensorL2CZero.getType()),
            ValueRange{centerFreqTensorL2CZero, maskLowToCenterInt});
        Value oneConstTensor = Torch::createRank0Tensor(
            rewriter, binder.getLoc(),
            Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), oneConst);
        Value upslopeL2C = b.create<Torch::AtenIndexPutOp>(
            zeroToNumElesType, upslopeL2C1, upslopeOneIdxList, oneConstTensor,
            falseConst);
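        // upslopeL2C approximates the rising edge of each triangular mel
        // filter: (k - lower_edge) / (center - lower_edge), clamped to be
        // non-negative, zeroed for bins past the center, with bands whose
        // lower and center edges coincide special-cased via the index_put
        // above.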

        // H2C computation
        Value maskCenterToHighZero =
            b.create<Torch::AtenEqScalarOp>(maskType, centerToHigh, zeroConst);
        Value maskH2CBeforeC = b.create<Torch::AtenLtTensorOp>(
            maskL2CAfterCType, zeroToNumEles, centerFreqTensor);
        Value centerToHighNoZero = b.create<Torch::AtenWhereScalarSelfOp>(
            unsqueezeResType, maskCenterToHighZero, negOneConst, centerToHigh);
        Value c2hNZFlt = b.create<Torch::AtenToDtypeOp>(
            l2cNZFltTy, centerToHighNoZero, /*dtype=*/float32DTypeConst,
            /*non_blocking=*/falseConst, /*copy=*/falseConst,
            /*memory_format=*/noneConst);
        Value zeroToNumElesC2H = b.create<Torch::AtenWhereScalarSelfOp>(
            zeroToNumElesType, maskCenterToHighZero, zeroConst, zeroToNumEles);
        Value downslopeDiff = b.create<Torch::AtenSubTensorOp>(
            zeroToNumElesType, highFreqTensor, zeroToNumElesC2H,
            /*alpha=*/oneConst);
        Value downslopeC2H0 = b.create<Torch::AtenDivTensorOp>(
            zeroToNumElesType, downslopeDiff, c2hNZFlt);
        Value maskDownslopeC2H0Pos = b.create<Torch::AtenGtScalarOp>(
            maskUpslopeL2C0PosType, downslopeC2H0, zeroConst);
        Value downslopeC2H0Pos = b.create<Torch::AtenWhereScalarOtherOp>(
            zeroToNumElesType, maskDownslopeC2H0Pos, downslopeC2H0, zeroConst);
        Value idxH2CBeforeCList = b.create<Torch::PrimListConstructOp>(
            rewriter.getType<Torch::ListType>(maskH2CBeforeC.getType()),
            ValueRange{maskH2CBeforeC});
        Value downslopeC2H = b.create<Torch::AtenIndexPutOp>(
            zeroToNumElesType, downslopeC2H0Pos, idxH2CBeforeCList,
            zeroConstTensor, falseConst);
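        // downslopeC2H mirrors the falling edge of the filter:
        // (upper_edge - k) / (upper_edge - center), clamped to be
        // non-negative and zeroed for bins before the center.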

        // final result Calculation
        Value maskH2CNonZero = b.create<Torch::AtenNeScalarOp>(
            maskL2CAfterCType, downslopeC2H, zeroConst);
        Value idxH2CNZList = b.create<Torch::PrimListConstructOp>(
            rewriter.getType<Torch::ListType>(maskH2CNonZero.getType()),
            ValueRange{maskH2CNonZero});
        Value upslopeL2CMasked = b.create<Torch::AtenIndexPutOp>(
            zeroToNumElesType, upslopeL2C, idxH2CNZList, zeroConstTensor,
            falseConst);

        Value slopesFinal = b.create<Torch::AtenAddTensorOp>(
            zeroToNumElesType, upslopeL2CMasked, downslopeC2H,
            /*alpha=*/oneConst);

        Value outputDTypeConst = b.create<Torch::ConstantIntOp>(
            rewriter.getType<Torch::IntType>(),
            rewriter.getI64IntegerAttr(torchDTypeInt.value()));
        Value finalOutput = b.create<Torch::AtenToDtypeOp>(
            resultType, slopesFinal, /*dtype=*/outputDTypeConst,
            /*non_blocking=*/falseConst, /*copy=*/falseConst,
            /*memory_format=*/noneConst);

        rewriter.replaceOp(binder.op, finalOutput);
        return success();
      });

  patterns.onOp(
      "Multinomial", 7,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -1918,3 +1918,119 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
  %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
  return %0 : !torch.vtensor<[1,3],si64>
}

// -----

// CHECK-LABEL: func.func @test_mwm
func.func @test_mwm(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>, %arg3: !torch.vtensor<[],f32>, %arg4: !torch.vtensor<[],f32>) -> !torch.vtensor<[9,8],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "test_mwm", torch.onnx_meta.producer_version = ""} {
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>, %[[VAL_1:.*]]: !torch.vtensor<[],si64>, %[[VAL_2:.*]]: !torch.vtensor<[],si64>,
// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>,
// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[],f32>
// CHECK: %[[VAL_5:.*]] = torch.constant.none
// CHECK: %[[VAL_6:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_7:.*]] = torch.aten.item %[[VAL_1]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_8:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_9:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[VAL_10:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[VAL_11:.*]] = torch.constant.none
// CHECK: %[[VAL_12:.*]] = torch.constant.int -2
// CHECK: %[[VAL_13:.*]] = torch.constant.int -1
// CHECK: %[[VAL_14:.*]] = torch.constant.int 0
// CHECK: %[[VAL_15:.*]] = torch.constant.int 1
// CHECK: %[[VAL_16:.*]] = torch.constant.int 2
// CHECK: %[[VAL_17:.*]] = torch.constant.int 6
// CHECK: %[[VAL_18:.*]] = torch.aten.div.int %[[VAL_7]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.float
// CHECK: %[[VAL_19:.*]] = torch.aten.Int.float %[[VAL_18]] : !torch.float -> !torch.int
// CHECK: %[[VAL_20:.*]] = torch.aten.add.int %[[VAL_19]], %[[VAL_15]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAL_21:.*]] = torch.aten.arange %[[VAL_6]], %[[VAL_17]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],si32>
// CHECK: %[[VAL_22:.*]] = torch.constant.float 2.595000e+03
// CHECK: %[[VAL_23:.*]] = torch.constant.float 7.000000e+02
// CHECK: %[[VAL_24:.*]] = torch.constant.float 1.000000e+01
// CHECK: %[[VAL_25:.*]] = torch.aten.div.float %[[VAL_9]], %[[VAL_23]] : !torch.float, !torch.float -> !torch.float
// CHECK: %[[VAL_26:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_25]] : !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_27:.*]] = torch.aten.add.Scalar %[[VAL_26]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_28:.*]] = torch.aten.log10 %[[VAL_27]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_29:.*]] = torch.aten.mul.Scalar %[[VAL_28]], %[[VAL_22]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_30:.*]] = torch.aten.div.float %[[VAL_10]], %[[VAL_23]] : !torch.float, !torch.float -> !torch.float
// CHECK: %[[VAL_31:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_30]] : !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_32:.*]] = torch.aten.add.Scalar %[[VAL_31]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_33:.*]] = torch.aten.log10 %[[VAL_32]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_34:.*]] = torch.aten.mul.Scalar %[[VAL_33]], %[[VAL_22]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_35:.*]] = torch.aten.sub.Tensor %[[VAL_34]], %[[VAL_29]], %[[VAL_15]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_36:.*]] = torch.aten.div.Scalar %[[VAL_35]], %[[VAL_6]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_37:.*]] = torch.aten.mul.Tensor %[[VAL_21]], %[[VAL_36]] : !torch.vtensor<[10],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_38:.*]] = torch.aten.add.Tensor %[[VAL_37]], %[[VAL_29]], %[[VAL_15]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_39:.*]] = torch.aten.div.Scalar %[[VAL_38]], %[[VAL_22]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_40:.*]] = torch.aten.clone %[[VAL_38]], %[[VAL_11]] : !torch.vtensor<[10],f32>, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_41:.*]] = torch.aten.fill.Scalar %[[VAL_40]], %[[VAL_24]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_42:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_41]], %[[VAL_39]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_43:.*]] = torch.aten.sub.Scalar %[[VAL_42]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[10],f32>, !torch.int, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_44:.*]] = torch.aten.mul.Scalar %[[VAL_43]], %[[VAL_23]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_45:.*]] = torch.aten.add.Scalar %[[VAL_1]], %[[VAL_15]], %[[VAL_15]] : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[VAL_46:.*]] = torch.aten.item %[[VAL_45]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_47:.*]] = torch.aten.mul.Scalar %[[VAL_44]], %[[VAL_46]] : !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_48:.*]] = torch.aten.div.Scalar %[[VAL_47]], %[[VAL_8]] : !torch.vtensor<[10],f32>, !torch.int -> !torch.vtensor<[10],f32>
// CHECK: %[[VAL_49:.*]] = torch.constant.int 3
// CHECK: %[[VAL_50:.*]] = torch.constant.bool false
// CHECK: %[[VAL_51:.*]] = torch.aten.to.dtype %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],si32>
// CHECK: %[[VAL_52:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_14]], %[[VAL_12]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_53:.*]] = torch.aten.unsqueeze %[[VAL_52]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_54:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_15]], %[[VAL_13]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_55:.*]] = torch.aten.unsqueeze %[[VAL_54]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_56:.*]] = torch.aten.slice.Tensor %[[VAL_51]], %[[VAL_14]], %[[VAL_14]], %[[VAL_11]], %[[VAL_15]] : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_57:.*]] = torch.aten.unsqueeze %[[VAL_56]], %[[VAL_14]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_58:.*]] = torch.aten.sub.Tensor %[[VAL_55]], %[[VAL_53]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_59:.*]] = torch.aten.sub.Tensor %[[VAL_57]], %[[VAL_55]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_60:.*]] = torch.aten.arange %[[VAL_20]], %[[VAL_17]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[9],f32>
// CHECK: %[[VAL_61:.*]] = torch.aten.unsqueeze %[[VAL_60]], %[[VAL_15]] : !torch.vtensor<[9],f32>, !torch.int -> !torch.vtensor<[9,1],f32>
// CHECK: %[[VAL_62:.*]] = torch.prim.ListConstruct %[[VAL_20]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_63:.*]] = torch.aten.expand %[[VAL_61]], %[[VAL_62]], %[[VAL_50]] : !torch.vtensor<[9,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_64:.*]] = torch.aten.eq.Scalar %[[VAL_58]], %[[VAL_14]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1>
// CHECK: %[[VAL_65:.*]] = torch.aten.where.ScalarSelf %[[VAL_64]], %[[VAL_13]], %[[VAL_58]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_66:.*]] = torch.aten.gt.Tensor %[[VAL_63]], %[[VAL_55]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_67:.*]] = torch.aten.max %[[VAL_53]] : !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1],si32>
// CHECK: %[[VAL_68:.*]] = torch.aten.item %[[VAL_67]] : !torch.vtensor<[1],si32> -> !torch.int
// CHECK: %[[VAL_69:.*]] = torch.aten.where.ScalarSelf %[[VAL_64]], %[[VAL_68]], %[[VAL_63]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_70:.*]] = torch.aten.sub.Tensor %[[VAL_69]], %[[VAL_53]], %[[VAL_15]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_71:.*]] = torch.aten.to.dtype %[[VAL_65]], %[[VAL_17]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
// CHECK: %[[VAL_72:.*]] = torch.aten.div.Tensor %[[VAL_70]], %[[VAL_71]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_73:.*]] = torch.aten.gt.Scalar %[[VAL_72]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_74:.*]] = torch.aten.where.ScalarOther %[[VAL_73]], %[[VAL_72]], %[[VAL_14]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_75:.*]] = torch.prim.ListConstruct %[[VAL_66]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<vtensor<[9,8],i1>>
// CHECK: %[[VAL_76:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[VAL_77:.*]] = torch.constant.none
// CHECK: %[[VAL_78:.*]] = torch.constant.int 6
// CHECK: %[[VAL_79:.*]] = torch.aten.full %[[VAL_76]], %[[VAL_14]], %[[VAL_78]], %[[VAL_77]], %[[VAL_77]], %[[VAL_77]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_80:.*]] = torch.aten.index_put %[[VAL_74]], %[[VAL_75]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[9,8],i1>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_81:.*]] = torch.prim.ListConstruct %[[VAL_64]] : (!torch.vtensor<[1,8],i1>) -> !torch.list<vtensor<[1,8],i1>>
// CHECK: %[[VAL_82:.*]] = torch.aten.index.Tensor %[[VAL_55]], %[[VAL_81]] : !torch.vtensor<[1,8],si32>, !torch.list<vtensor<[1,8],i1>> -> !torch.vtensor<[?],si32>
// CHECK: %[[VAL_83:.*]] = torch.aten.squeeze %[[VAL_64]] : !torch.vtensor<[1,8],i1> -> !torch.vtensor<[8],i1>
// CHECK: %[[VAL_84:.*]] = torch.aten.to.dtype %[[VAL_83]], %[[VAL_49]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[8],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32>
// CHECK: %[[VAL_85:.*]] = torch.prim.ListConstruct %[[VAL_82]], %[[VAL_84]] : (!torch.vtensor<[?],si32>, !torch.vtensor<[8],si32>) -> !torch.list<vtensor<[?],si32>>
// CHECK: %[[VAL_86:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[VAL_87:.*]] = torch.constant.none
// CHECK: %[[VAL_88:.*]] = torch.constant.int 6
// CHECK: %[[VAL_89:.*]] = torch.aten.full %[[VAL_86]], %[[VAL_15]], %[[VAL_88]], %[[VAL_87]], %[[VAL_87]], %[[VAL_87]] : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_90:.*]] = torch.aten.index_put %[[VAL_80]], %[[VAL_85]], %[[VAL_89]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[?],si32>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_91:.*]] = torch.aten.eq.Scalar %[[VAL_59]], %[[VAL_14]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1>
// CHECK: %[[VAL_92:.*]] = torch.aten.lt.Tensor %[[VAL_63]], %[[VAL_55]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_93:.*]] = torch.aten.where.ScalarSelf %[[VAL_91]], %[[VAL_13]], %[[VAL_59]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32>
// CHECK: %[[VAL_94:.*]] = torch.aten.to.dtype %[[VAL_93]], %[[VAL_17]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32>
// CHECK: %[[VAL_95:.*]] = torch.aten.where.ScalarSelf %[[VAL_91]], %[[VAL_14]], %[[VAL_63]] : !torch.vtensor<[1,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_96:.*]] = torch.aten.sub.Tensor %[[VAL_57]], %[[VAL_95]], %[[VAL_15]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_97:.*]] = torch.aten.div.Tensor %[[VAL_96]], %[[VAL_94]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_98:.*]] = torch.aten.gt.Scalar %[[VAL_97]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_99:.*]] = torch.aten.where.ScalarOther %[[VAL_98]], %[[VAL_97]], %[[VAL_14]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_100:.*]] = torch.prim.ListConstruct %[[VAL_92]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<vtensor<[9,8],i1>>
// CHECK: %[[VAL_101:.*]] = torch.aten.index_put %[[VAL_99]], %[[VAL_100]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[9,8],i1>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_102:.*]] = torch.aten.ne.Scalar %[[VAL_101]], %[[VAL_14]] : !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],i1>
// CHECK: %[[VAL_103:.*]] = torch.prim.ListConstruct %[[VAL_102]] : (!torch.vtensor<[9,8],i1>) -> !torch.list<vtensor<[9,8],i1>>
// CHECK: %[[VAL_104:.*]] = torch.aten.index_put %[[VAL_90]], %[[VAL_103]], %[[VAL_79]], %[[VAL_50]] : !torch.vtensor<[9,8],f32>, !torch.list<vtensor<[9,8],i1>>, !torch.vtensor<[],f32>, !torch.bool -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_105:.*]] = torch.aten.add.Tensor %[[VAL_104]], %[[VAL_101]], %[[VAL_15]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[9,8],f32>, !torch.int -> !torch.vtensor<[9,8],f32>
// CHECK: %[[VAL_106:.*]] = torch.constant.int 6
// CHECK: %[[VAL_107:.*]] = torch.aten.to.dtype %[[VAL_105]], %[[VAL_106]], %[[VAL_50]], %[[VAL_50]], %[[VAL_11]] : !torch.vtensor<[9,8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32>
// CHECK: return %[[VAL_107]] : !torch.vtensor<[9,8],f32>
  %none = torch.constant.none
  %0 = torch.operator "onnx.MelWeightMatrix"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[9,8],f32>
  return %0 : !torch.vtensor<[9,8],f32>
}
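Note on the expected shapes in this test: the lowering reads num_mel_bins and floor(dft_length / 2) + 1 from the static result type, so a result of !torch.vtensor<[9,8],f32> corresponds to num_mel_bins = 8 and floor(dft_length / 2) + 1 = 9 (e.g. dft_length = 16, as in the ONNX documentation example). That is where the [10], [9,1], and [9,8] shapes in the CHECK lines come from, [10] being num_mel_bins + 2 band-edge points.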