[onnx] Add IDF and TFIDF modes to TFIDF Vectorizer (#3726)

Address https://github.com/nod-ai/SHARK-Turbine/issues/833
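This change wires the remaining two modes of the ONNX TfIdfVectorizer through the lowering: TF emits the raw n-gram counts, IDF emits the per-n-gram weight whenever the count is nonzero (counts larger than 1 are truncated to 1), and TFIDF emits the count scaled by the weight. A minimal scalar sketch of that per-n-gram math (the helper name vectorizerEntry is illustrative, not part of the patch):

    #include <cstdint>
    #include <string>

    // Per-n-gram output under each TfIdfVectorizer mode: count is the raw
    // occurrence count of the i-th pool n-gram, weight is weights[i]
    // (1.0f when the optional "weights" attribute is absent).
    float vectorizerEntry(const std::string &mode, int64_t count, float weight) {
      if (mode == "TF")
        return static_cast<float>(count);          // raw term frequency
      if (mode == "IDF")
        return (count > 0 ? 1.0f : 0.0f) * weight; // presence indicator * weight
      return static_cast<float>(count) * weight;   // TFIDF: count * weight
    }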
Samu Tamminen 2024-10-02 15:17:58 +02:00 committed by GitHub
parent 617c1c76ce
commit a2bfe47faa
2 changed files with 58 additions and 4 deletions

@@ -338,6 +338,31 @@ struct OpBinder
     return failure();
   }
 
+  ParseResult f32FloatArrayAttr(llvm::SmallVector<float> &values,
+                                StringRef nameSuffix,
+                                ArrayRef<float> defaults) {
+    SmallString<64> name("torch.onnx.");
+    name.append(nameSuffix);
+    auto attr = op->getAttr(name);
+    if (!attr) {
+      values.append(defaults.begin(), defaults.end());
+      return success();
+    }
+    if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+      for (auto element : arrayAttr) {
+        auto floatAttr = dyn_cast<FloatAttr>(element);
+        if (!floatAttr)
+          return failure();
+        FloatType t = cast<FloatType>(floatAttr.getType());
+        if (t.getWidth() != 32)
+          return failure();
+        values.push_back(floatAttr.getValue().convertToFloat());
+      }
+      return success();
+    }
+    return failure();
+  }
+
   ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
                               StringRef nameSuffix) {
     SmallString<64> name("torch.onnx.");
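The new binder mirrors the existing array-attribute helpers such as stringArrayAttr: a missing attribute is not a failure but fills values from the caller-supplied defaults, while a present attribute whose elements are not 32-bit floats rejects the match. A hedged call-site sketch, assuming a binder in scope and an illustrative numNgrams:

    // Hypothetical call site (same shape as the TfIdfVectorizer change below).
    llvm::SmallVector<float> weights;
    llvm::SmallVector<float> defaults(numNgrams, 1.0f);
    if (binder.f32FloatArrayAttr(weights, "weights", defaults))
      return failure(); // attribute present but not an array of f32
    // weights now holds the parsed values, or the defaults if absent.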

@@ -4339,6 +4339,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         llvm::SmallVector<int64_t> ngram_counts;
         llvm::SmallVector<int64_t> ngram_indexes;
         llvm::SmallVector<int64_t> pool_int64s;
+        llvm::SmallVector<float> weights;
         std::string mode;
         int64_t min_gram_length;
         int64_t max_gram_length;
@@ -4356,9 +4357,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.tensorOperand(input) || binder.tensorResultType(resultType))
           return failure();
-        if (mode != "TF")
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "TF mode supported only");
+        llvm::SmallVector<float> defaultWeights(ngram_indexes.size(), 1.0f);
+        if (binder.f32FloatArrayAttr(weights, "weights", defaultWeights))
+          return failure();
         if (pool_int64s.size() == 0)
           return rewriter.notifyMatchFailure(
               binder.op, "pool_int64s empty, only integers supported");
@@ -4584,9 +4586,36 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
               binder.getLoc(), loopConditionTrue, ValueRange({count}));
         }
         count = skipLoop.getResult(0);
+        // Insert the raw count ("tf") into the output.
         Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
             binder.getLoc(), count);
+        if (mode == "IDF" || mode == "TFIDF") {
+          // Both IDF and TFIDF modes scale the result by the i-th weight.
+          float weight = weights[ngram_i];
+          Value constWeight = rewriter.create<Torch::ConstantFloatOp>(
+              binder.getLoc(), rewriter.getF64FloatAttr(weight));
+          // TFIDF: multiply the raw count by the weight.
+          Value multiplier = countFloat;
+          if (mode == "IDF") {
+            // IDF: all counts larger than 1 are truncated to 1, and the
+            // i-th element in weights scales (by multiplication) the count
+            // of the i-th n-gram in the pool.
+            Value intCount = rewriter.create<Torch::AtenIntScalarOp>(
+                binder.getLoc(), count);
+            // Compare intCount > 0.
+            Value gtZeroCount = rewriter.create<Torch::AtenGtIntOp>(
+                binder.getLoc(), intCount, zero);
+            gtZeroCount = rewriter.create<Torch::AtenIntBoolOp>(
+                binder.getLoc(), gtZeroCount);
+            Value gtZeroCountFloat =
+                rewriter.create<Torch::AtenFloatScalarOp>(binder.getLoc(),
+                                                          gtZeroCount);
+            multiplier = gtZeroCountFloat;
+          }
+          countFloat = rewriter.create<Torch::AtenMulFloatOp>(
+              binder.getLoc(), multiplier, constWeight);
+        }
         Value dataList = rewriter.create<Torch::PrimListConstructOp>(
             binder.getLoc(),
             rewriter.getType<Torch::ListType>(
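As a quick numeric check of the lowering above, using the illustrative vectorizerEntry helper from the sketch near the top of this page: a count of 3 with weight 0.5 yields 3.0 in TF mode, 0.5 in IDF mode (the count is truncated to 1 before scaling), and 1.5 in TFIDF mode.

    #include <cassert>

    int main() {
      assert(vectorizerEntry("TF", 3, 0.5f) == 3.0f);    // raw count
      assert(vectorizerEntry("IDF", 3, 0.5f) == 0.5f);   // 1 * 0.5
      assert(vectorizerEntry("TFIDF", 3, 0.5f) == 1.5f); // 3 * 0.5
      return 0;
    }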