diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 1cf4df932..f71deaff2 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -338,6 +338,31 @@ struct OpBinder { return failure(); } + ParseResult f32FloatArrayAttr(llvm::SmallVector &values, + StringRef nameSuffix, + ArrayRef defaults) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + values.append(defaults.begin(), defaults.end()); + return success(); + } + if (auto arrayAttr = dyn_cast(attr)) { + for (auto element : arrayAttr) { + auto floatAttr = dyn_cast(element); + if (!floatAttr) + return failure(); + FloatType t = cast(floatAttr.getType()); + if (t.getWidth() != 32) + return failure(); + values.push_back(floatAttr.getValue().convertToFloat()); + } + return success(); + } + return failure(); + } + ParseResult stringArrayAttr(llvm::SmallVector &values, StringRef nameSuffix) { SmallString<64> name("torch.onnx."); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index ea5156a0c..95413b080 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -4339,6 +4339,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( llvm::SmallVector ngram_counts; llvm::SmallVector ngram_indexes; llvm::SmallVector pool_int64s; + llvm::SmallVector weights; std::string mode; int64_t min_gram_length; int64_t max_gram_length; @@ -4356,9 +4357,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.tensorOperand(input) || binder.tensorResultType(resultType)) return failure(); - if (mode != "TF") - return rewriter.notifyMatchFailure(binder.op, - "TF mode supported only"); + llvm::SmallVector defaultWeights(ngram_indexes.size(), 1.0f); + if (binder.f32FloatArrayAttr(weights, "weights", defaultWeights)) + return failure(); + if (pool_int64s.size() == 0) return rewriter.notifyMatchFailure( binder.op, "pool_int64s empty, only integers supported"); @@ -4584,9 +4586,36 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.getLoc(), loopConditionTrue, ValueRange({count})); } count = skipLoop.getResult(0); - // insert count "tf" into output Value countFloat = rewriter.create( binder.getLoc(), count); + if (mode == "IDF" || mode == "TFIDF") { + // both IDF and TFIDF modes use weights + float weight = weights[ngram_i]; + Value constWeight = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(weight)); + + // TFIDF + Value multiplier = countFloat; + if (mode == "IDF") { + // All the counts larger than 1 would be truncated to 1 + // and the i-th element in weights would be used to scale + // (by multiplication) the count of the i-th n-gram in pool. + + Value intCount = rewriter.create( + binder.getLoc(), count); + // compare intCount > 0 + Value gtZeroCount = rewriter.create( + binder.getLoc(), intCount, zero); + gtZeroCount = rewriter.create( + binder.getLoc(), gtZeroCount); + Value gtZeroCountFloat = + rewriter.create(binder.getLoc(), + gtZeroCount); + multiplier = gtZeroCountFloat; + } + countFloat = rewriter.create( + binder.getLoc(), multiplier, constWeight); + } Value dataList = rewriter.create( binder.getLoc(), rewriter.getType(