[ONNX] add support for tfidfvectorizer (#3553)

Supports 1-d and 2-d input and output.

Implemented based on the description and example test cases in
https://github.com/onnx/onnx/blob/main/docs/Operators.md#TfIdfVectorizer
and some notes from
https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_tfidf_vectorizer.py#L128
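
For reference, a minimal C++ sketch of the "TF" counting this lowering
mirrors (illustrative only, not part of the patch; the helper name
countNgram and its signature are invented):

    #include <cstdint>
    #include <vector>

    // Count occurrences of `ngram` in `seq`, allowing consecutive n-gram
    // elements to sit (skip + 1) positions apart in `seq`, summed over
    // every skip in [0, maxSkipCount]. 1-grams use a single pass so the
    // same match is not counted once per skip value.
    static int64_t countNgram(const std::vector<int64_t> &seq,
                              const std::vector<int64_t> &ngram,
                              int64_t maxSkipCount) {
      int64_t count = 0;
      int64_t n = (int64_t)ngram.size();
      int64_t stepBound = (n == 1) ? 1 : maxSkipCount + 1;
      for (int64_t step = 1; step <= stepBound; ++step) {
        int64_t maxStart = (int64_t)seq.size() - (n - 1) * step;
        for (int64_t start = 0; start < maxStart; ++start) {
          bool match = true;
          for (int64_t i = 0; i < n; ++i)
            match = match && (seq[start + i * step] == ngram[i]);
          count += match;
        }
      }
      return count;
    }

The output row then holds each such count at position ngram_indexes[i] for
the i-th n-gram in pool_int64s; n-grams whose length falls outside
[min_gram_length, max_gram_length] are skipped and their entries stay zero.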

---------

Co-authored-by: zjgarvey <zjgarvey@gmail.com>
pull/3629/head
aldesilv 2024-08-12 16:10:11 -07:00 committed by GitHub
parent d3695a97a0
commit a4ba02eef5
2 changed files with 353 additions and 0 deletions

@@ -4305,6 +4305,308 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
uniqueResults[1], uniqueResults[2]});
return success();
});
patterns.onOp(
"TfIdfVectorizer", 9,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
llvm::SmallVector<int64_t> ngram_counts;
llvm::SmallVector<int64_t> ngram_indexes;
llvm::SmallVector<int64_t> pool_int64s;
std::string mode;
int64_t min_gram_length;
int64_t max_gram_length;
int64_t max_skip_count;
Value input;
Torch::ValueTensorType resultType;
if (binder.s64IntegerArrayAttr(ngram_counts, "ngram_counts", {}) ||
binder.s64IntegerArrayAttr(ngram_indexes, "ngram_indexes", {}) ||
binder.s64IntegerArrayAttr(pool_int64s, "pool_int64s", {}) ||
binder.customOpNameStringAttr(mode, "mode", "") ||
binder.s64IntegerAttr(min_gram_length, "min_gram_length", 0) ||
binder.s64IntegerAttr(max_gram_length, "max_gram_length", 0) ||
binder.s64IntegerAttr(max_skip_count, "max_skip_count", 0) ||
binder.tensorOperand(input) || binder.tensorResultType(resultType))
return failure();
if (mode != "TF")
return rewriter.notifyMatchFailure(binder.op,
"TF mode supported only");
if (pool_int64s.size() == 0)
return rewriter.notifyMatchFailure(
binder.op, "pool_int64s empty, only integers supported");
auto inputType = dyn_cast<Torch::ValueTensorType>(input.getType());
auto inputSizes = inputType.getSizes();
SmallVector<int64_t> inputShape(inputSizes);
bool is_2d = inputShape.size() > 1;
if (is_2d && inputShape[0] == ShapedType::kDynamic)
return rewriter.notifyMatchFailure(
binder.op, "input batch dimension cannot be dynamic");
int batch_size = (is_2d) ? inputShape[0] : 1;
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
Value one = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(false));
auto intType = rewriter.getType<Torch::IntType>();
Value loopConditionTrue = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(true));
Type loopIndexType = intType;
// create a zero tensor for output
SmallVector<int64_t> resultShape(resultType.getSizes());
int64_t rank = resultShape.size();
SmallVector<Value> zerosShapeValues;
for (int j = 0; j < rank; j++) {
Value dimSize = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(resultShape[j]));
zerosShapeValues.push_back(dimSize);
}
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
zerosShapeValues);
Value output = rewriter.create<Torch::AtenZerosOp>(
binder.getLoc(), resultType, zerosShapeList, none, none, none,
none);
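// Loop over the batch: each iteration computes the term-frequency counts
// for one input sequence and writes them into the corresponding row of
// the output (for 1-d input there is a single iteration).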
Value batchSize = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(batch_size));
auto batchLoop = rewriter.create<Torch::PrimLoopOp>(
binder.getLoc(), TypeRange({output.getType()}), batchSize,
loopConditionTrue, ValueRange({output}));
{
PatternRewriter::InsertionGuard guard(rewriter);
Block *batchLoopBody = rewriter.createBlock(
&batchLoop.getRegion(), batchLoop.getRegion().begin(),
TypeRange({loopIndexType, output.getType()}),
{binder.getLoc(), binder.getLoc()});
Value batchValue = batchLoopBody->getArgument(0);
Value output = batchLoopBody->getArgument(1);
Value outputForBatch = output;
Value inputSequence = input;
if (is_2d) {
// get input sequence from input (ex: [[0,1],[2,3]] -> [[0,1]] ->
// [0,1])
SmallVector<int64_t> inputSequenceShape;
inputSequenceShape.push_back(1);
inputSequenceShape.push_back(inputShape[1]);
auto inputSequenceType = rewriter.getType<Torch::ValueTensorType>(
inputSequenceShape, inputType.getOptionalDtype());
Value batchPlusOne = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), batchValue, one);
inputSequence = rewriter.create<Torch::AtenSliceTensorOp>(
binder.getLoc(), inputSequenceType, input, /*dim=*/zero,
batchValue, batchPlusOne, one);
inputSequence = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{inputShape[1]},
inputType.getOptionalDtype()),
inputSequence, zero);
SmallVector<int64_t> outputForBatchShape;
outputForBatchShape.push_back(1);
outputForBatchShape.push_back(resultShape[1]);
auto outputForBatchType = rewriter.getType<Torch::ValueTensorType>(
outputForBatchShape, resultType.getOptionalDtype());
outputForBatch = rewriter.create<Torch::AtenSliceTensorOp>(
binder.getLoc(), outputForBatchType, output,
/*dim=*/zero, batchValue, batchPlusOne, one);
outputForBatch = rewriter.create<Torch::AtenSqueezeDimOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(),
ArrayRef<int64_t>{resultShape[1]},
resultType.getOptionalDtype()),
outputForBatch, zero);
}
// ngram_counts[j] records the starting position within pool_int64s of
// the n-grams of length j+1. The loop below iterates over the different
// n-gram sizes.
// ngram_i keeps track of which n-gram we are looking at in the pool.
// The frequency of this n-gram will be stored in the output tensor at
// position ngram_indexes[ngram_i].
int ngram_i = 0;
for (int j = 0; j < (int)ngram_counts.size(); j++) {
int ngram_length = j + 1;
int start_idx = ngram_counts[j];
int end_idx = (j + 1) < (int)ngram_counts.size()
? ngram_counts[j + 1]
: pool_int64s.size();
if (j + 1 < min_gram_length || j + 1 > max_gram_length) {
// advance the n-gram counter past the skipped (j+1)-grams
ngram_i += (end_idx - start_idx) / ngram_length;
continue;
}
Value ngramLength = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(ngram_length));
for (int start = start_idx; start < end_idx;
start += ngram_length, ngram_i++) {
Value count = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
// for 1-grams, there is no skipping (skip = gap between
// consecutive values in the n-gram pulled from the input
// sequence), so we default to skip_count_bound = 1 in that case
// to avoid repeating the same count multiple times.
int skip_count_bound =
(ngram_length == 1) ? 1 : (max_skip_count + 1);
Value skipCountBound = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), intType,
rewriter.getI64IntegerAttr(skip_count_bound));
// given a n-gram to search for, and the input sequence to search
// in, we need to count how many times that n-gram appears in the
// input for each skip between 0 and max_skip_count (inclusive).
auto skipLoop = rewriter.create<Torch::PrimLoopOp>(
binder.getLoc(), TypeRange({count.getType()}), skipCountBound,
loopConditionTrue, ValueRange({count}));
{
PatternRewriter::InsertionGuard guard(rewriter);
Block *skipLoopBody = rewriter.createBlock(
&skipLoop.getRegion(), skipLoop.getRegion().begin(),
TypeRange({loopIndexType, count.getType()}),
{binder.getLoc(), binder.getLoc()});
Value skipCount = skipLoopBody->getArgument(0);
Value skipCountPlusOne = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), skipCount, one);
count = skipLoopBody->getArgument(1);
// max_start_index =
// inputSizes.back() - ((ngram_length - 1) * (skip_count + 1));
// the index one higher than the last possible start index
// without the input ngram going out of bounds
Value seqLen = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), intType,
rewriter.getI64IntegerAttr(inputSizes.back()));
Value ngramLengthMinusOne =
rewriter.create<Torch::AtenSubIntOp>(binder.getLoc(),
ngramLength, one);
Value ngramSkipLength = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), ngramLengthMinusOne, skipCountPlusOne);
Value maxStartIndex = rewriter.create<Torch::AtenSubIntOp>(
binder.getLoc(), seqLen, ngramSkipLength);
// For each valid start index, this loop extracts the n-gram with the
// given skip_count from the input sequence and increments the count if
// it matches the n-gram taken from pool_int64s.
auto countLoop = rewriter.create<Torch::PrimLoopOp>(
binder.getLoc(), TypeRange({count.getType()}),
maxStartIndex, loopConditionTrue, ValueRange({count}));
{
PatternRewriter::InsertionGuard guard(rewriter);
Block *countLoopBody = rewriter.createBlock(
&countLoop.getRegion(), countLoop.getRegion().begin(),
TypeRange({loopIndexType, count.getType()}),
{binder.getLoc(), binder.getLoc()});
Value startInputIdx = countLoopBody->getArgument(0);
count = countLoopBody->getArgument(1);
// extract input ngram and compare to pool ngram
Torch::BaseTensorType inputSequenceType =
cast<Torch::BaseTensorType>(inputSequence.getType());
SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Type selectResultType =
inputSequenceType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes),
inputSequenceType.getOptionalDtype());
Value foundNgram = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
for (int i = 0; i < ngram_length; i++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
i));
selectIndex = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), selectIndex, skipCountPlusOne);
selectIndex = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), selectIndex, startInputIdx);
Value inputExtract =
rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, inputSequence,
zero, selectIndex);
Value inputNgram_i = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
inputExtract);
Value poolNgram_i = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(pool_int64s[start + i]));
Value isEqual = rewriter.create<Torch::AtenEqIntOp>(
binder.getLoc(), inputNgram_i, poolNgram_i);
isEqual = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), isEqual);
foundNgram = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isEqual, foundNgram);
}
count = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), count, foundNgram);
rewriter.create<Torch::PrimLoopConditionOp>(
binder.getLoc(), loopConditionTrue, ValueRange({count}));
}
count = countLoop.getResult(0);
rewriter.create<Torch::PrimLoopConditionOp>(
binder.getLoc(), loopConditionTrue, ValueRange({count}));
}
count = skipLoop.getResult(0);
// insert the "TF" count for this n-gram into the output
Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
binder.getLoc(), count);
Value dataList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::FloatType>()),
SmallVector<Value>{countFloat});
Value cstDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
(int)torch_upstream::ScalarType::Float));
SmallVector<int64_t> countShape{1};
auto countType = rewriter.getType<Torch::ValueTensorType>(
countShape, resultType.getOptionalDtype());
Value countTensor = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(), countType, dataList, /*dtype=*/cstDtype,
/*layout=*/none, /*requires_grad=*/cstFalse);
Value insertStart = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(ngram_indexes[ngram_i]));
Value insertEnd = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), insertStart, one);
outputForBatch = rewriter.create<Torch::AtenSliceScatterOp>(
binder.getLoc(), outputForBatch.getType(), outputForBatch,
countTensor,
/*dim=*/zero, insertStart, insertEnd, /*step=*/one);
} // start
}
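// Write the per-batch counts back: for 2-d input, unsqueeze the counts
// to shape [1, C] and scatter them into this batch's row of the output;
// for 1-d input the per-batch counts are the whole output.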
if (is_2d) {
Value batchPlusOne = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), batchValue, one);
outputForBatch = rewriter.create<Torch::AtenUnsqueezeOp>(
binder.getLoc(),
rewriter.getType<Torch::ValueTensorType>(
llvm::SmallVector<int64_t>{1, resultShape[1]},
resultType.getDtype()),
outputForBatch, zero);
output = rewriter.create<Torch::AtenSliceScatterOp>(
binder.getLoc(), resultType, output, outputForBatch,
/*dim=*/zero, batchValue, batchPlusOne, /*step=*/one);
} else {
output = outputForBatch;
}
rewriter.create<Torch::PrimLoopConditionOp>(
binder.getLoc(), loopConditionTrue, ValueRange({output}));
}
output = batchLoop.getResult(0);
rewriter.replaceOp(binder.op, output);
return success();
});
patterns.onOp(
"Scan", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Location loc = binder.getLoc();

@@ -1915,6 +1915,57 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>
return %0 : !torch.vtensor<[2],si32>
}
// -----
// CHECK-LABEL: func.func @test_tfidfvectorizer_tf_batch_onlybigrams_skip5
func.func @test_tfidfvectorizer_tf_batch_onlybigrams_skip5(%arg0: !torch.vtensor<[2,6],si32>) -> !torch.vtensor<[2,7],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[output_init:.*]] = torch.aten.zeros %[[x0:.*]], %[[none_0:.*]], %[[none_0]], %[[none_0]], %[[none_0]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,7],f32>
// CHECK: %[[int2_1:.*]] = torch.constant.int 2
// CHECK: %[[batch_loop:.*]] = torch.prim.Loop %[[int2_1]], %[[true:.*]], init(%[[output_init]]) {
// CHECK: ^bb0(%[[arg1:.*]]: !torch.int, %[[arg2:.*]]: !torch.vtensor<[2,7],f32>):
// CHECK: %[[x3:.*]] = torch.aten.add.int %[[arg1]], %[[int1:.*]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[x4:.*]] = torch.aten.slice.Tensor %arg0, %[[int0:.*]], %[[arg1]], %[[x3]], %[[int1]] : !torch.vtensor<[2,6],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,6],si32>
// CHECK: %[[inputbatch:.*]] = torch.aten.squeeze.dim %[[x4]], %[[int0]] : !torch.vtensor<[1,6],si32>, !torch.int -> !torch.vtensor<[6],si32>
// CHECK: %[[x6:.*]] = torch.aten.slice.Tensor %[[arg2]], %[[int0]], %[[arg1]], %[[x3]], %[[int1]] : !torch.vtensor<[2,7],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,7],f32>
// CHECK: %[[outputbatch:.*]] = torch.aten.squeeze.dim %[[x6]], %[[int0]] : !torch.vtensor<[1,7],f32>, !torch.int -> !torch.vtensor<[7],f32>
// CHECK: %[[int2_2:.*]] = torch.constant.int 2
// CHECK: %[[int0_3:.*]] = torch.constant.int 0
// CHECK: %[[max_skip_count:.*]] = torch.constant.int 6
// CHECK: %[[skip_loop:.*]] = torch.prim.Loop %[[max_skip_count]], %[[true]], init(%[[int0_3]]) {
// CHECK: ^bb0(%[[arg3:.*]]: !torch.int, %[[arg4:.*]]: !torch.int):
// CHECK: %[[x29:.*]] = torch.aten.add.int %[[arg3]], %[[int1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[int6_12:.*]] = torch.constant.int 6
// CHECK: %[[x30:.*]] = torch.aten.sub.int %[[int2_2]], %[[int1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[x31:.*]] = torch.aten.mul.int %[[x30]], %[[x29]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[x32:.*]] = torch.aten.sub.int %[[int6_12]], %[[x31]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[count_loop:.*]] = torch.prim.Loop %[[x32]], %[[true]], init(%[[arg4]]) {
// CHECK: ^bb0(%[[arg5:.*]]: !torch.int, %[[arg6:.*]]: !torch.int):
// CHECK: %[[input_2gram0:.*]] = torch.aten.select.int %[[inputbatch]], %[[int0]], %[[position0:.*]] : !torch.vtensor<[6],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
// CHECK: %[[inputval0:.*]] = torch.aten.item %[[input_2gram0]] : !torch.vtensor<[1],si32> -> !torch.int
// CHECK: %[[eq0:.*]] = torch.aten.eq.int %[[inputval0]], %[[first2gram0:.*]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[eq0int:.*]] = torch.aten.Int.bool %[[eq0]] : !torch.bool -> !torch.int
// CHECK: %[[alleq0:.*]] = torch.aten.mul.int %[[eq0int]], %[[int1_13:.*]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[input_2gram1:.*]] = torch.aten.select.int %[[inputbatch]], %[[int0]], %[[position1:.*]] : !torch.vtensor<[6],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
// CHECK: %[[inputval1:.*]] = torch.aten.item %[[input_2gram1]] : !torch.vtensor<[1],si32> -> !torch.int
// CHECK: %[[eq1:.*]] = torch.aten.eq.int %[[inputval1]], %[[first2gram1:.*]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[eq1int:.*]] = torch.aten.Int.bool %[[eq1]] : !torch.bool -> !torch.int
// CHECK: %[[alleq1:.*]] = torch.aten.mul.int %[[eq1int]], %[[alleq0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[newcount:.*]] = torch.aten.add.int %[[arg6]], %[[alleq1]] : !torch.int, !torch.int -> !torch.int
// CHECK: torch.prim.Loop.condition %[[true]], iter(%[[newcount]] : !torch.int)
// CHECK: } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// CHECK: torch.prim.Loop.condition %[[true]], iter(%[[count_loop]] : !torch.int)
// CHECK: } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
// CHECK: %[[count_insert0:.*]] = torch.aten.slice_scatter %[[outputbatch]], %[[counttensor0:.*]], %[[int0]], %[[ngram_indices0:.*]], %[[ngram_indices0plus1:.*]], %[[int1]] : !torch.vtensor<[7],f32>, !torch.vtensor<[1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[7],f32>
// the skip_loop and count_loops repeat for each n-gram in the pool_int64s; then, after the last n-gram frequency is counted...
// CHECK: %[[unsqueezecounts:.*]] = torch.aten.unsqueeze %[[lastcountinsert:.*]], %[[int0]] : !torch.vtensor<[7],f32>, !torch.int -> !torch.vtensor<[1,7],f32>
// CHECK: %[[count_into_output:.*]] = torch.aten.slice_scatter %[[arg2]], %[[unsqueezecounts]], %[[int0]], %[[arg1]], %[[arg1plus1:.*]], %[[int1]] : !torch.vtensor<[2,7],f32>, !torch.vtensor<[1,7],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,7],f32>
// CHECK: torch.prim.Loop.condition %[[true]], iter(%[[count_into_output]] : !torch.vtensor<[2,7],f32>)
// CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[2,7],f32>) -> !torch.vtensor<[2,7],f32>
// CHECK: return %[[batch_loop]] : !torch.vtensor<[2,7],f32>
%0 = torch.operator "onnx.TfIdfVectorizer"(%arg0) {torch.onnx.max_gram_length = 2 : si64, torch.onnx.max_skip_count = 5 : si64, torch.onnx.min_gram_length = 2 : si64, torch.onnx.mode = "TF", torch.onnx.ngram_counts = [0 : si64, 4 : si64], torch.onnx.ngram_indexes = [0 : si64, 1 : si64, 2 : si64, 3 : si64, 4 : si64, 5 : si64, 6 : si64], torch.onnx.pool_int64s = [2 : si64, 3 : si64, 5 : si64, 4 : si64, 5 : si64, 6 : si64, 7 : si64, 8 : si64, 6 : si64, 7 : si64]} : (!torch.vtensor<[2,6],si32>) -> !torch.vtensor<[2,7],f32>
return %0 : !torch.vtensor<[2,7],f32>
}
// -----
// CHECK-LABEL: func.func @test_range_int16_type