mirror of https://github.com/llvm/torch-mlir
[onnx] Add IDF and TFIDF modes to TFIDF Vectorizer (#3726)
Address https://github.com/nod-ai/SHARK-Turbine/issues/833pull/3756/head
parent
617c1c76ce
commit
a2bfe47faa
|
@ -338,6 +338,31 @@ struct OpBinder {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ParseResult f32FloatArrayAttr(llvm::SmallVector<float> &values,
|
||||||
|
StringRef nameSuffix,
|
||||||
|
ArrayRef<float> defaults) {
|
||||||
|
SmallString<64> name("torch.onnx.");
|
||||||
|
name.append(nameSuffix);
|
||||||
|
auto attr = op->getAttr(name);
|
||||||
|
if (!attr) {
|
||||||
|
values.append(defaults.begin(), defaults.end());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
|
||||||
|
for (auto element : arrayAttr) {
|
||||||
|
auto floatAttr = dyn_cast<FloatAttr>(element);
|
||||||
|
if (!floatAttr)
|
||||||
|
return failure();
|
||||||
|
FloatType t = cast<FloatType>(floatAttr.getType());
|
||||||
|
if (t.getWidth() != 32)
|
||||||
|
return failure();
|
||||||
|
values.push_back(floatAttr.getValue().convertToFloat());
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
|
ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
|
||||||
StringRef nameSuffix) {
|
StringRef nameSuffix) {
|
||||||
SmallString<64> name("torch.onnx.");
|
SmallString<64> name("torch.onnx.");
|
||||||
|
|
|
@ -4339,6 +4339,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
llvm::SmallVector<int64_t> ngram_counts;
|
llvm::SmallVector<int64_t> ngram_counts;
|
||||||
llvm::SmallVector<int64_t> ngram_indexes;
|
llvm::SmallVector<int64_t> ngram_indexes;
|
||||||
llvm::SmallVector<int64_t> pool_int64s;
|
llvm::SmallVector<int64_t> pool_int64s;
|
||||||
|
llvm::SmallVector<float> weights;
|
||||||
std::string mode;
|
std::string mode;
|
||||||
int64_t min_gram_length;
|
int64_t min_gram_length;
|
||||||
int64_t max_gram_length;
|
int64_t max_gram_length;
|
||||||
|
@ -4356,9 +4357,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
binder.tensorOperand(input) || binder.tensorResultType(resultType))
|
binder.tensorOperand(input) || binder.tensorResultType(resultType))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (mode != "TF")
|
llvm::SmallVector<float> defaultWeights(ngram_indexes.size(), 1.0f);
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
if (binder.f32FloatArrayAttr(weights, "weights", defaultWeights))
|
||||||
"TF mode supported only");
|
return failure();
|
||||||
|
|
||||||
if (pool_int64s.size() == 0)
|
if (pool_int64s.size() == 0)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "pool_int64s empty, only integers supported");
|
binder.op, "pool_int64s empty, only integers supported");
|
||||||
|
@ -4584,9 +4586,36 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
binder.getLoc(), loopConditionTrue, ValueRange({count}));
|
binder.getLoc(), loopConditionTrue, ValueRange({count}));
|
||||||
}
|
}
|
||||||
count = skipLoop.getResult(0);
|
count = skipLoop.getResult(0);
|
||||||
// insert count "tf" into output
|
|
||||||
Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
|
Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
|
||||||
binder.getLoc(), count);
|
binder.getLoc(), count);
|
||||||
|
if (mode == "IDF" || mode == "TFIDF") {
|
||||||
|
// both IDF and TFIDF modes use weights
|
||||||
|
float weight = weights[ngram_i];
|
||||||
|
Value constWeight = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
|
binder.getLoc(), rewriter.getF64FloatAttr(weight));
|
||||||
|
|
||||||
|
// TFIDF
|
||||||
|
Value multiplier = countFloat;
|
||||||
|
if (mode == "IDF") {
|
||||||
|
// All the counts larger than 1 would be truncated to 1
|
||||||
|
// and the i-th element in weights would be used to scale
|
||||||
|
// (by multiplication) the count of the i-th n-gram in pool.
|
||||||
|
|
||||||
|
Value intCount = rewriter.create<Torch::AtenIntScalarOp>(
|
||||||
|
binder.getLoc(), count);
|
||||||
|
// compare intCount > 0
|
||||||
|
Value gtZeroCount = rewriter.create<Torch::AtenGtIntOp>(
|
||||||
|
binder.getLoc(), intCount, zero);
|
||||||
|
gtZeroCount = rewriter.create<Torch::AtenIntBoolOp>(
|
||||||
|
binder.getLoc(), gtZeroCount);
|
||||||
|
Value gtZeroCountFloat =
|
||||||
|
rewriter.create<Torch::AtenFloatScalarOp>(binder.getLoc(),
|
||||||
|
gtZeroCount);
|
||||||
|
multiplier = gtZeroCountFloat;
|
||||||
|
}
|
||||||
|
countFloat = rewriter.create<Torch::AtenMulFloatOp>(
|
||||||
|
binder.getLoc(), multiplier, constWeight);
|
||||||
|
}
|
||||||
Value dataList = rewriter.create<Torch::PrimListConstructOp>(
|
Value dataList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
binder.getLoc(),
|
binder.getLoc(),
|
||||||
rewriter.getType<Torch::ListType>(
|
rewriter.getType<Torch::ListType>(
|
||||||
|
|
Loading…
Reference in New Issue