mirror of https://github.com/llvm/torch-mlir
[onnx] Add IDF and TFIDF modes to TFIDF Vectorizer (#3726)
Address https://github.com/nod-ai/SHARK-Turbine/issues/833
parent 617c1c76ce
commit a2bfe47faa
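Per the ONNX TfIdfVectorizer spec, the three modes differ only in how a matched n-gram's occurrence count becomes an output value: TF emits the raw count, IDF truncates the count to a 0/1 presence flag and scales it by the n-gram's weight, and TFIDF scales the raw count by the weight. A minimal scalar sketch of that arithmetic (illustrative C++, not code from this patch; vectorizerEntry is a hypothetical name):

#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical helper modeling one output entry of TfIdfVectorizer.
float vectorizerEntry(const std::string &mode, int64_t count, float weight) {
  if (mode == "TF")
    return static_cast<float>(count);           // raw occurrence count
  if (mode == "IDF")                            // count truncated to 0/1,
    return (count > 0 ? 1.0f : 0.0f) * weight;  // then scaled by weight
  assert(mode == "TFIDF");
  return static_cast<float>(count) * weight;    // count scaled by weight
}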
@@ -338,6 +338,31 @@ struct OpBinder {
    return failure();
  }

  ParseResult f32FloatArrayAttr(llvm::SmallVector<float> &values,
                                StringRef nameSuffix,
                                ArrayRef<float> defaults) {
    SmallString<64> name("torch.onnx.");
    name.append(nameSuffix);
    auto attr = op->getAttr(name);
    if (!attr) {
      values.append(defaults.begin(), defaults.end());
      return success();
    }
    if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
      for (auto element : arrayAttr) {
        auto floatAttr = dyn_cast<FloatAttr>(element);
        if (!floatAttr)
          return failure();
        FloatType t = cast<FloatType>(floatAttr.getType());
        if (t.getWidth() != 32)
          return failure();
        values.push_back(floatAttr.getValue().convertToFloat());
      }
      return success();
    }
    return failure();
  }

  ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
                              StringRef nameSuffix) {
    SmallString<64> name("torch.onnx.");
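The new binder helper mirrors the existing attribute parsers: it reads torch.onnx.<suffix>, falls back to defaults when the attribute is absent, and rejects anything that is not an array of 32-bit floats. Its call site appears in the TfIdfVectorizer hunk further down; isolated here for readability (same code as below):

        llvm::SmallVector<float> weights;
        llvm::SmallVector<float> defaultWeights(ngram_indexes.size(), 1.0f);
        if (binder.f32FloatArrayAttr(weights, "weights", defaultWeights))
          return failure();  // attribute present but not an f32 array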
@@ -4339,6 +4339,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        llvm::SmallVector<int64_t> ngram_counts;
        llvm::SmallVector<int64_t> ngram_indexes;
        llvm::SmallVector<int64_t> pool_int64s;
        llvm::SmallVector<float> weights;
        std::string mode;
        int64_t min_gram_length;
        int64_t max_gram_length;
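For orientation, the ONNX spec lays these attributes out as follows: pool_int64s holds every n-gram back to back, ngram_counts[i] gives the start offset of the (i+1)-gram slice within the pool, and ngram_indexes maps each pooled n-gram to its output coordinate. A toy example (values are illustrative, not from the patch):

        // pool: 1-grams {2}, {3}, {5}; 2-grams {2,3}, {5,4}
        llvm::SmallVector<int64_t> pool_int64s = {2, 3, 5, 2, 3, 5, 4};
        llvm::SmallVector<int64_t> ngram_counts = {0, 3};   // 1-grams at 0, 2-grams at 3
        llvm::SmallVector<int64_t> ngram_indexes = {0, 1, 2, 3, 4};  // output columns
        llvm::SmallVector<float> weights = {1.f, 1.f, 1.f, 1.f, 1.f};  // one per n-gram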
@@ -4356,9 +4357,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
            binder.tensorOperand(input) || binder.tensorResultType(resultType))
          return failure();

        if (mode != "TF")
          return rewriter.notifyMatchFailure(binder.op,
                                             "TF mode supported only");
        llvm::SmallVector<float> defaultWeights(ngram_indexes.size(), 1.0f);
        if (binder.f32FloatArrayAttr(weights, "weights", defaultWeights))
          return failure();

        if (pool_int64s.size() == 0)
          return rewriter.notifyMatchFailure(
              binder.op, "pool_int64s empty, only integers supported");
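Since the hunk above still shows the old TF-only guard, the patch presumably relaxes it to admit the two new modes; a sketch of the expected shape (an assumption, not the literal patch text):

        if (mode != "TF" && mode != "IDF" && mode != "TFIDF")
          return rewriter.notifyMatchFailure(
              binder.op, "TF, IDF and TFIDF modes supported only");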
@@ -4584,9 +4586,36 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
              binder.getLoc(), loopConditionTrue, ValueRange({count}));
        }
        count = skipLoop.getResult(0);
        // insert count "tf" into output
        Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
            binder.getLoc(), count);
        if (mode == "IDF" || mode == "TFIDF") {
          // both IDF and TFIDF modes use weights
          float weight = weights[ngram_i];
          Value constWeight = rewriter.create<Torch::ConstantFloatOp>(
              binder.getLoc(), rewriter.getF64FloatAttr(weight));

          // TFIDF
          Value multiplier = countFloat;
          if (mode == "IDF") {
            // All the counts larger than 1 would be truncated to 1
            // and the i-th element in weights would be used to scale
            // (by multiplication) the count of the i-th n-gram in pool.
            Value intCount = rewriter.create<Torch::AtenIntScalarOp>(
                binder.getLoc(), count);
            // compare intCount > 0
            Value gtZeroCount = rewriter.create<Torch::AtenGtIntOp>(
                binder.getLoc(), intCount, zero);
            gtZeroCount = rewriter.create<Torch::AtenIntBoolOp>(
                binder.getLoc(), gtZeroCount);
            Value gtZeroCountFloat =
                rewriter.create<Torch::AtenFloatScalarOp>(binder.getLoc(),
                                                          gtZeroCount);
            multiplier = gtZeroCountFloat;
          }
          countFloat = rewriter.create<Torch::AtenMulFloatOp>(
              binder.getLoc(), multiplier, constWeight);
        }
        Value dataList = rewriter.create<Torch::PrimListConstructOp>(
            binder.getLoc(),
            rewriter.getType<Torch::ListType>(
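The IDF branch builds a chain of scalar ops (AtenIntScalarOp, AtenGtIntOp, AtenIntBoolOp, AtenFloatScalarOp) that reduces the count to a 0/1 presence flag before the multiply. A plain-C++ model of what the emitted chain computes, with worked numbers (illustrative only, not patch code):

#include <cstdint>
#include <cstdio>
#include <string>

int main() {
  int64_t count = 3;    // n-gram occurrences found by the search loop
  float weight = 0.5f;  // weights[ngram_i]
  for (const char *mode : {"TF", "IDF", "TFIDF"}) {
    float countFloat = static_cast<float>(count);  // AtenFloatScalarOp
    float multiplier = countFloat;
    if (std::string(mode) == "IDF")
      multiplier = count > 0 ? 1.0f : 0.0f;  // Gt + IntBool + FloatScalar chain
    float out = (std::string(mode) == "TF")
                    ? countFloat
                    : multiplier * weight;   // AtenMulFloatOp
    std::printf("%s -> %g\n", mode, out);    // TF -> 3, IDF -> 0.5, TFIDF -> 1.5
  }
  return 0;
}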