mirror of https://github.com/llvm/torch-mlir
[ONNX] add support for tfidfvectorizer (#3553)
1-D and 2-D input and output are implemented, based on the description and example test cases in https://github.com/onnx/onnx/blob/main/docs/Operators.md#TfIdfVectorizer and some notes from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_tfidf_vectorizer.py#L128
---------
Co-authored-by: zjgarvey <zjgarvey@gmail.com>
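For orientation, the following is a minimal reference sketch (not part of this commit) of the TF-mode counting that the ONNX operator describes and that the lowering below unrolls into torch ops. The function name tfCounts and its parameter list are illustrative only; the attribute layout (ngram_counts, ngram_indexes, pool_int64s, min/max_gram_length, max_skip_count) follows the ONNX docs linked above.

#include <cstdint>
#include <vector>

// Counts, for every n-gram stored flat in `pool` (grouped by length via
// `ngramCounts`, as in the ONNX attribute layout), how often it occurs in
// `seq` for any skip in [0, maxSkipCount], and writes the count to
// out[ngramIndexes[i]].
std::vector<float> tfCounts(const std::vector<int64_t> &seq,
                            const std::vector<int64_t> &pool,
                            const std::vector<int64_t> &ngramCounts,
                            const std::vector<int64_t> &ngramIndexes,
                            int64_t minGramLength, int64_t maxGramLength,
                            int64_t maxSkipCount, int64_t outputSize) {
  std::vector<float> out(outputSize, 0.0f);
  int64_t ngramIdx = 0;
  for (int64_t j = 0; j < (int64_t)ngramCounts.size(); ++j) {
    int64_t n = j + 1; // n-gram length of this block of the pool
    int64_t startIdx = ngramCounts[j];
    int64_t endIdx = (j + 1 < (int64_t)ngramCounts.size())
                         ? ngramCounts[j + 1]
                         : (int64_t)pool.size();
    if (n < minGramLength || n > maxGramLength) {
      ngramIdx += (endIdx - startIdx) / n; // skip this n-gram size entirely
      continue;
    }
    for (int64_t p = startIdx; p < endIdx; p += n, ++ngramIdx) {
      int64_t count = 0;
      // 1-grams have no notion of a skip, so only skip = 0 is counted once.
      int64_t skipBound = (n == 1) ? 1 : maxSkipCount + 1;
      for (int64_t skip = 0; skip < skipBound; ++skip) {
        int64_t stride = skip + 1; // gap between consecutive n-gram elements
        for (int64_t s = 0; s + (n - 1) * stride < (int64_t)seq.size(); ++s) {
          bool match = true;
          for (int64_t i = 0; i < n; ++i)
            match &= (seq[s + i * stride] == pool[p + i]);
          count += match ? 1 : 0;
        }
      }
      out[ngramIndexes[ngramIdx]] = (float)count;
    }
  }
  return out;
}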
parent d3695a97a0
commit a4ba02eef5
@@ -4305,6 +4305,308 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                                 uniqueResults[1], uniqueResults[2]});
        return success();
      });
  patterns.onOp(
      "TfIdfVectorizer", 9,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        llvm::SmallVector<int64_t> ngram_counts;
        llvm::SmallVector<int64_t> ngram_indexes;
        llvm::SmallVector<int64_t> pool_int64s;
        std::string mode;
        int64_t min_gram_length;
        int64_t max_gram_length;
        int64_t max_skip_count;
        Value input;
        Torch::ValueTensorType resultType;

        if (binder.s64IntegerArrayAttr(ngram_counts, "ngram_counts", {}) ||
            binder.s64IntegerArrayAttr(ngram_indexes, "ngram_indexes", {}) ||
            binder.s64IntegerArrayAttr(pool_int64s, "pool_int64s", {}) ||
            binder.customOpNameStringAttr(mode, "mode", "") ||
            binder.s64IntegerAttr(min_gram_length, "min_gram_length", 0) ||
            binder.s64IntegerAttr(max_gram_length, "max_gram_length", 0) ||
            binder.s64IntegerAttr(max_skip_count, "max_skip_count", 0) ||
            binder.tensorOperand(input) || binder.tensorResultType(resultType))
          return failure();

        if (mode != "TF")
          return rewriter.notifyMatchFailure(binder.op,
                                             "TF mode supported only");
        if (pool_int64s.size() == 0)
          return rewriter.notifyMatchFailure(
              binder.op, "pool_int64s empty, only integers supported");
        auto inputType = dyn_cast<Torch::ValueTensorType>(input.getType());
        auto inputSizes =
            dyn_cast<Torch::ValueTensorType>(input.getType()).getSizes();
        SmallVector<int64_t> inputShape(inputSizes);
        bool is_2d = (inputShape.size() > 1) ? true : false;
        if (is_2d && inputShape[0] == ShapedType::kDynamic)
          return rewriter.notifyMatchFailure(
              binder.op, "input batch dimension cannot be dynamic");
        int batch_size = (is_2d) ? inputShape[0] : 1;
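        // For example, a [2,6] input (as in the test added below) is processed
        // as two sequences of length 6 (batch_size = 2), while a rank-1 input
        // is treated as a single sequence with batch_size = 1.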

        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value zero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
        Value one = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(1));
        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
            binder.getLoc(), rewriter.getBoolAttr(false));

        auto intType = rewriter.getType<Torch::IntType>();
        Value loopConditionTrue = rewriter.create<Torch::ConstantBoolOp>(
            binder.getLoc(), rewriter.getBoolAttr(true));
        Type loopIndexType = intType;
        // create a zero tensor for output
        SmallVector<int64_t> resultShape(resultType.getSizes());
        int64_t rank = resultShape.size();
        SmallVector<Value> zerosShapeValues;
        for (int j = 0; j < rank; j++) {
          Value dimSize = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getI64IntegerAttr(resultShape[j]));
          zerosShapeValues.push_back(dimSize);
        }
        Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            rewriter.getType<Torch::ListType>(
                rewriter.getType<Torch::IntType>()),
            zerosShapeValues);
        Value output = rewriter.create<Torch::AtenZerosOp>(
            binder.getLoc(), resultType, zerosShapeList, none, none, none,
            none);

        Value batchSize = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(batch_size));
        auto batchLoop = rewriter.create<Torch::PrimLoopOp>(
            binder.getLoc(), TypeRange({output.getType()}), batchSize,
            loopConditionTrue, ValueRange({output}));
        {
          PatternRewriter::InsertionGuard guard(rewriter);
          Block *batchLoopBody = rewriter.createBlock(
              &batchLoop.getRegion(), batchLoop.getRegion().begin(),
              TypeRange({loopIndexType, output.getType()}),
              {binder.getLoc(), binder.getLoc()});
          Value batchValue = batchLoopBody->getArgument(0);
          Value output = batchLoopBody->getArgument(1);
          Value outputForBatch = output;
          Value inputSequence = input;
          if (is_2d) {
            // get input sequence from input (ex: [[0,1],[2,3]] -> [[0,1]] ->
            // [0,1])
            SmallVector<int64_t> inputSequenceShape;
            inputSequenceShape.push_back(1);
            inputSequenceShape.push_back(inputShape[1]);
            auto inputSequenceType = rewriter.getType<Torch::ValueTensorType>(
                inputSequenceShape, inputType.getOptionalDtype());
            Value batchPlusOne = rewriter.create<Torch::AtenAddIntOp>(
                binder.getLoc(), batchValue, one);
            inputSequence = rewriter.create<Torch::AtenSliceTensorOp>(
                binder.getLoc(), inputSequenceType, input, /*dim=*/zero,
                batchValue, batchPlusOne, one);
            inputSequence = rewriter.create<Torch::AtenSqueezeDimOp>(
                binder.getLoc(),
                Torch::ValueTensorType::get(binder.op->getContext(),
                                            ArrayRef<int64_t>{inputShape[1]},
                                            inputType.getOptionalDtype()),
                inputSequence, zero);

            SmallVector<int64_t> outputForBatchShape;
            outputForBatchShape.push_back(1);
            outputForBatchShape.push_back(resultShape[1]);
            auto outputForBatchType = rewriter.getType<Torch::ValueTensorType>(
                outputForBatchShape, resultType.getOptionalDtype());
            outputForBatch = rewriter.create<Torch::AtenSliceTensorOp>(
                binder.getLoc(), outputForBatchType, output,
                /*dim=*/zero, batchValue, batchPlusOne, one);
            outputForBatch = rewriter.create<Torch::AtenSqueezeDimOp>(
                binder.getLoc(),
                Torch::ValueTensorType::get(binder.op->getContext(),
                                            ArrayRef<int64_t>{resultShape[1]},
                                            resultType.getOptionalDtype()),
                outputForBatch, zero);
          }
          // ngram_counts[j] records the starting position within pool_int64s
          // of the n-grams of length j+1. The loop below iterates through the
          // different n-gram sizes.
          // ngram_i keeps track of which ngram we are looking at in the pool.
          // The frequency of this ngram will be stored in the output tensor at
          // the position ngram_indexes[ngram_i].
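          // For example, with ngram_counts = [0, 4] and
          // pool_int64s = [2, 3, 5, 4, 5, 6, 7, 8, 6, 7] (the test added
          // below), entries 0-3 are the 1-grams {2, 3, 5, 4} and entries 4-9
          // are the 2-grams {(5,6), (7,8), (6,7)}; the 1-grams map to output
          // positions ngram_indexes[0..3] and the 2-grams to
          // ngram_indexes[4..6] (in that test only the 2-grams are counted,
          // since min_gram_length = 2).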
          int ngram_i = 0;
          for (int j = 0; j < (int)ngram_counts.size(); j++) {
            int ngram_length = j + 1;
            int start_idx = ngram_counts[j];
            int end_idx = (j + 1) < (int)ngram_counts.size()
                              ? ngram_counts[j + 1]
                              : pool_int64s.size();
            if (j + 1 < min_gram_length || j + 1 > max_gram_length) {
              // progress the ngram counter for the skipped (j+1)-grams
              ngram_i += (end_idx - start_idx) / ngram_length;
              continue;
            }

            Value ngramLength = rewriter.create<Torch::ConstantIntOp>(
                binder.getLoc(), rewriter.getI64IntegerAttr(ngram_length));
            for (int start = start_idx; start < end_idx;
                 start += ngram_length, ngram_i++) {
              Value count = rewriter.create<Torch::ConstantIntOp>(
                  binder.getLoc(), rewriter.getI64IntegerAttr(0));
              // for 1-grams, there is no skipping (skip = gap between
              // consecutive values in the n-gram pulled from the input
              // sequence), so we default to skip_count_bound = 1 in that case
              // to avoid repeating the same count multiple times.
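              // For instance, a 2-gram checked with skip = 1 is drawn from
              // input positions (i, i + 2); with max_skip_count = 5 the skip
              // loop below tries skips 0 through 5, i.e. element strides 1
              // through 6.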
              int skip_count_bound =
                  (ngram_length == 1) ? 1 : (max_skip_count + 1);
              Value skipCountBound = rewriter.create<Torch::ConstantIntOp>(
                  binder.getLoc(), intType,
                  rewriter.getI64IntegerAttr(skip_count_bound));
              // given an n-gram to search for, and the input sequence to
              // search in, we need to count how many times that n-gram
              // appears in the input for each skip between 0 and
              // max_skip_count (inclusive).
              auto skipLoop = rewriter.create<Torch::PrimLoopOp>(
                  binder.getLoc(), TypeRange({count.getType()}), skipCountBound,
                  loopConditionTrue, ValueRange({count}));
              {
                PatternRewriter::InsertionGuard guard(rewriter);
                Block *skipLoopBody = rewriter.createBlock(
                    &skipLoop.getRegion(), skipLoop.getRegion().begin(),
                    TypeRange({loopIndexType, count.getType()}),
                    {binder.getLoc(), binder.getLoc()});
                Value skipCount = skipLoopBody->getArgument(0);
                Value skipCountPlusOne = rewriter.create<Torch::AtenAddIntOp>(
                    binder.getLoc(), skipCount, one);
                count = skipLoopBody->getArgument(1);

                // max_start_index =
                //   inputSizes.back() - ((ngram_length - 1) * (skip_count + 1));
                // the index one higher than the last possible start index
                // without the input ngram going out of bounds
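                // e.g. with a length-6 sequence, ngram_length = 2 and
                // skip_count = 1: max_start_index = 6 - (2 - 1) * (1 + 1) = 4,
                // so the count loop below tries start positions 0 through 3.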
                Value seqLen = rewriter.create<Torch::ConstantIntOp>(
                    binder.getLoc(), intType,
                    rewriter.getI64IntegerAttr(inputSizes.back()));
                Value ngramLengthMinusOne =
                    rewriter.create<Torch::AtenSubIntOp>(binder.getLoc(),
                                                         ngramLength, one);
                Value ngramSkipLength = rewriter.create<Torch::AtenMulIntOp>(
                    binder.getLoc(), ngramLengthMinusOne, skipCountPlusOne);
                Value maxStartIndex = rewriter.create<Torch::AtenSubIntOp>(
                    binder.getLoc(), seqLen, ngramSkipLength);
                // This loop extracts the n-gram with the given skip_count from
                // the input sequence at each start index, and increments the
                // count if it matches the n-gram taken from pool_int64s.
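                // e.g. for the pool 2-gram (5, 6) with skip_count = 0, a match
                // at start index s requires seq[s] == 5 and seq[s + 1] == 6;
                // foundNgram below stays 1 only if every element matches.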
                auto countLoop = rewriter.create<Torch::PrimLoopOp>(
                    binder.getLoc(), TypeRange({count.getType()}),
                    maxStartIndex, loopConditionTrue, ValueRange({count}));
                {
                  PatternRewriter::InsertionGuard guard(rewriter);
                  Block *countLoopBody = rewriter.createBlock(
                      &countLoop.getRegion(), countLoop.getRegion().begin(),
                      TypeRange({loopIndexType, count.getType()}),
                      {binder.getLoc(), binder.getLoc()});

                  Value startInputIdx = countLoopBody->getArgument(0);
                  count = countLoopBody->getArgument(1);

                  // extract input ngram and compare to pool ngram
                  Torch::BaseTensorType inputSequenceType =
                      cast<Torch::BaseTensorType>(inputSequence.getType());
                  SmallVector<int64_t> selectSizes;
                  selectSizes.push_back(1);
                  Type selectResultType =
                      inputSequenceType.getWithSizesAndDtype(
                          llvm::ArrayRef(selectSizes),
                          inputSequenceType.getOptionalDtype());
                  Value foundNgram = rewriter.create<Torch::ConstantIntOp>(
                      binder.getLoc(), rewriter.getI64IntegerAttr(1));
                  for (int i = 0; i < ngram_length; i++) {
                    Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
                        binder.getLoc(), rewriter.getType<Torch::IntType>(),
                        rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                                i));
                    selectIndex = rewriter.create<Torch::AtenMulIntOp>(
                        binder.getLoc(), selectIndex, skipCountPlusOne);
                    selectIndex = rewriter.create<Torch::AtenAddIntOp>(
                        binder.getLoc(), selectIndex, startInputIdx);
                    Value inputExtract =
                        rewriter.create<Torch::AtenSelectIntOp>(
                            binder.getLoc(), selectResultType, inputSequence,
                            zero, selectIndex);
                    Value inputNgram_i = rewriter.create<Torch::AtenItemOp>(
                        binder.getLoc(), rewriter.getType<Torch::IntType>(),
                        inputExtract);

                    Value poolNgram_i = rewriter.create<Torch::ConstantIntOp>(
                        binder.getLoc(),
                        rewriter.getI64IntegerAttr(pool_int64s[start + i]));
                    Value isEqual = rewriter.create<Torch::AtenEqIntOp>(
                        binder.getLoc(), inputNgram_i, poolNgram_i);
                    isEqual = rewriter.create<Torch::AtenIntBoolOp>(
                        binder.getLoc(), isEqual);
                    foundNgram = rewriter.create<Torch::AtenMulIntOp>(
                        binder.getLoc(), isEqual, foundNgram);
                  }

                  count = rewriter.create<Torch::AtenAddIntOp>(
                      binder.getLoc(), count, foundNgram);
                  rewriter.create<Torch::PrimLoopConditionOp>(
                      binder.getLoc(), loopConditionTrue, ValueRange({count}));
                }
                count = countLoop.getResult(0);
                rewriter.create<Torch::PrimLoopConditionOp>(
                    binder.getLoc(), loopConditionTrue, ValueRange({count}));
              }
              count = skipLoop.getResult(0);
              // insert count "tf" into output
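              // e.g. if the 2-gram (5, 6) from the test below (ngram_i = 4)
              // was found twice, a 1-element tensor holding 2.0 is scattered
              // into the batch's output row at position ngram_indexes[4] = 4.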
              Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
                  binder.getLoc(), count);
              Value dataList = rewriter.create<Torch::PrimListConstructOp>(
                  binder.getLoc(),
                  rewriter.getType<Torch::ListType>(
                      rewriter.getType<Torch::FloatType>()),
                  SmallVector<Value>{countFloat});
              Value cstDtype = rewriter.create<Torch::ConstantIntOp>(
                  binder.getLoc(), rewriter.getI64IntegerAttr(
                                       (int)torch_upstream::ScalarType::Float));
              SmallVector<int64_t> countShape{1};
              auto countType = rewriter.getType<Torch::ValueTensorType>(
                  countShape, resultType.getOptionalDtype());
              Value countTensor = rewriter.create<Torch::AtenTensorOp>(
                  binder.getLoc(), countType, dataList, /*dtype=*/cstDtype,
                  /*layout=*/none, /*requires_grad=*/cstFalse);

              Value insertStart = rewriter.create<Torch::ConstantIntOp>(
                  binder.getLoc(),
                  rewriter.getI64IntegerAttr(ngram_indexes[ngram_i]));
              Value insertEnd = rewriter.create<Torch::AtenAddIntOp>(
                  binder.getLoc(), insertStart, one);
              outputForBatch = rewriter.create<Torch::AtenSliceScatterOp>(
                  binder.getLoc(), outputForBatch.getType(), outputForBatch,
                  countTensor,
                  /*dim=*/zero, insertStart, insertEnd, /*step=*/one);
            } // start
          }
          if (is_2d) {
            Value batchPlusOne = rewriter.create<Torch::AtenAddIntOp>(
                binder.getLoc(), batchValue, one);
            outputForBatch = rewriter.create<Torch::AtenUnsqueezeOp>(
                binder.getLoc(),
                rewriter.getType<Torch::ValueTensorType>(
                    llvm::SmallVector<int64_t>{1, resultShape[1]},
                    resultType.getDtype()),
                outputForBatch, zero);
            output = rewriter.create<Torch::AtenSliceScatterOp>(
                binder.getLoc(), resultType, output, outputForBatch,
                /*dim=*/zero, batchValue, batchPlusOne, /*step=*/one);
          } else {
            output = outputForBatch;
          }
          rewriter.create<Torch::PrimLoopConditionOp>(
              binder.getLoc(), loopConditionTrue, ValueRange({output}));
        }
        output = batchLoop.getResult(0);
        rewriter.replaceOp(binder.op, output);
        return success();
      });
  patterns.onOp(
      "Scan", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Location loc = binder.getLoc();
@@ -1915,6 +1915,57 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>
  return %0 : !torch.vtensor<[2],si32>
}

// -----

// CHECK-LABEL : func.func @test_tfidfvectorizer_tf_batch_onlybigrams_skip5
func.func @test_tfidfvectorizer_tf_batch_onlybigrams_skip5(%arg0: !torch.vtensor<[2,6],si32>) -> !torch.vtensor<[2,7],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK : %[[output_init:.*]] = torch.aten.zeros %[[x0:.*]], %[[none_0:.*]], %[[none_0]], %[[none_0]], %[[none_0]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,7],f32>
  // CHECK : %[[int2_1:.*]] = torch.constant.int 2
  // CHECK : %[[batch_loop:.*]] = torch.prim.Loop %[[int2_1]], %[[true:.*]], init(%[[output_init]]) {
  // CHECK : ^bb0(%[[arg1:.*]]: !torch.int, %[[arg2:.*]]: !torch.vtensor<[2,7],f32>):
  // CHECK : %[[x3:.*]] = torch.aten.add.int %[[arg1]], %[[int1:.*]] : !torch.int, !torch.int -> !torch.int
  // CHECK : %[[x4:.*]] = torch.aten.slice.Tensor %arg0, %[[int0:.*]], %[[arg1]], %[[x3]], %[[int1]] : !torch.vtensor<[2,6],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,6],si32>
  // CHECK : %[[inputbatch:.*]] = torch.aten.squeeze.dim %[[x4]], %[[int0]] : !torch.vtensor<[1,6],si32>, !torch.int -> !torch.vtensor<[6],si32>
  // CHECK : %[[x6:.*]] = torch.aten.slice.Tensor %[[arg2]], %[[int0]], %[[arg1]], %[[x3]], %[[int1]] : !torch.vtensor<[2,7],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,7],f32>
  // CHECK : %[[outputbatch:.*]] = torch.aten.squeeze.dim %[[x6]], %[[int0]] : !torch.vtensor<[1,7],f32>, !torch.int -> !torch.vtensor<[7],f32>
  // CHECK : %[[int2_2:.*]] = torch.constant.int 2
  // CHECK : %[[int0_3:.*]] = torch.constant.int 0
  // CHECK : %[[max_skip_count:.*]] = torch.constant.int 6
  // CHECK : %[[skip_loop:.*]] = torch.prim.Loop %[[max_skip_count]], %[[true]], init(%[[int0_3]]) {
  // CHECK : ^bb0(%[[arg3:.*]]: !torch.int, %[[arg4:.*]]: !torch.int):
  // CHECK : %[[x29:.*]] = torch.aten.add.int %[[arg3]], %[[int1]] : !torch.int, !torch.int -> !torch.int
  // CHECK : %[[int6_12:.*]] = torch.constant.int 6
  // CHECK : %[[x30:.*]] = torch.aten.sub.int %[[int2_2]], %[[int1]] : !torch.int, !torch.int -> !torch.int
  // CHECK : %[[x31:.*]] = torch.aten.mul.int %[[x30]], %[[x29]] : !torch.int, !torch.int -> !torch.int
  // CHECK : %[[x32:.*]] = torch.aten.sub.int %[[int6_12]], %[[x31]] : !torch.int, !torch.int -> !torch.int
  // CHECK : %[[count_loop:.*]] = torch.prim.Loop %[[x32]], %[[true]], init(%[[arg4]]) {
  // CHECK : ^bb0(%[[arg5:.*]]: !torch.int, %[[arg6:.*]]: !torch.int):
  // CHECK : %[[input_2gram0:.*]] = torch.aten.select.int %[[inputbatch]], %[[int0]], %[[position0:.*]] : !torch.vtensor<[6],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
  // CHECK : %[[inputval0:.*]] = torch.aten.item %[[input_2gram0]] : !torch.vtensor<[1],si32> -> !torch.int
  // CHECK : %[[eq0:.*]] = torch.aten.eq.int %[[inputval0]], %[[first2gram0:.*]] : !torch.int, !torch.int -> !torch.bool
  // CHECK : %[[eq0int:.*]] = torch.aten.Int.bool %[[eq0]] : !torch.bool -> !torch.int
  // CHECK : %[[alleq0:.*]] = torch.aten.mul.int %[[eq0int]], %[[int1_13:.*]] : !torch.int, !torch.int -> !torch.int
  // CHECK : %[[input_2gram1:.*]] = torch.aten.select.int %[[inputbatch]], %[[int0]], %[[position1:.*]] : !torch.vtensor<[6],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
  // CHECK : %[[inputval1:.*]] = torch.aten.item %[[input_2gram1]] : !torch.vtensor<[1],si32> -> !torch.int
  // CHECK : %[[eq1:.*]] = torch.aten.eq.int %[[inputval1]], %[[first2gram1:.*]] : !torch.int, !torch.int -> !torch.bool
  // CHECK : %[[eq1int:.*]] = torch.aten.Int.bool %[[eq1]] : !torch.bool -> !torch.int
  // CHECK : %[[alleq1:.*]] = torch.aten.mul.int %[[eq1int]], %[[alleq0]] : !torch.int, !torch.int -> !torch.int
  // CHECK : %[[newcount:.*]] = torch.aten.add.int %[[arg6]], %[[alleq1]] : !torch.int, !torch.int -> !torch.int
  // CHECK : torch.prim.Loop.condition %[[true]], iter(%[[newcount]] : !torch.int)
  // CHECK : } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
  // CHECK : torch.prim.Loop.condition %[[true]], iter(%[[count_loop]] : !torch.int)
  // CHECK : } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
  // CHECK : %[[count_insert0:.*]] = torch.aten.slice_scatter %[[outputbatch]], %[[counttensor0:.*]], %[[int0]], %[[ngram_indices0:.*]], %[[ngram_indices0plus1:.*]], %[[int1]] : !torch.vtensor<[7],f32>, !torch.vtensor<[1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[7],f32>
  // the skip_loop and count_loops repeat for each ngram in the pool_int64s; then, after the last ngram frequency is counted...
  // CHECK : %[[unsqueezecounts:.*]] = torch.aten.unsqueeze %[[lastcountinsert:.*]], %[[int0]] : !torch.vtensor<[7],f32>, !torch.int -> !torch.vtensor<[1,7],f32>
  // CHECK : %[[count_into_output:.*]] = torch.aten.slice_scatter %[[arg2]], %[[unsqueezecounts]], %[[int0]], %[[arg1]], %[[arg1plus1:.*]], %[[int1]] : !torch.vtensor<[2,7],f32>, !torch.vtensor<[1,7],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,7],f32>
  // CHECK : torch.prim.Loop.condition %[[true]], iter(%[[count_into_output]] : !torch.vtensor<[2,7],f32>)
  // CHECK : } : (!torch.int, !torch.bool, !torch.vtensor<[2,7],f32>) -> !torch.vtensor<[2,7],f32>
  // CHECK : return %[[batch_loop]] : !torch.vtensor<[2,7],f32>
  %0 = torch.operator "onnx.TfIdfVectorizer"(%arg0) {torch.onnx.max_gram_length = 2 : si64, torch.onnx.max_skip_count = 5 : si64, torch.onnx.min_gram_length = 2 : si64, torch.onnx.mode = "TF", torch.onnx.ngram_counts = [0 : si64, 4 : si64], torch.onnx.ngram_indexes = [0 : si64, 1 : si64, 2 : si64, 3 : si64, 4 : si64, 5 : si64, 6 : si64], torch.onnx.pool_int64s = [2 : si64, 3 : si64, 5 : si64, 4 : si64, 5 : si64, 6 : si64, 7 : si64, 8 : si64, 6 : si64, 7 : si64]} : (!torch.vtensor<[2,6],si32>) -> !torch.vtensor<[2,7],f32>
  return %0 : !torch.vtensor<[2,7],f32>
}

// -----

// CHECK-LABEL: func.func @test_range_int16_type