Fix unused-variable warnings about EmbeddingBag ops (#1220)

According to the documentation for
`torch.embedding_bag` (https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding_bag.html),
the default value for `scale_grad_by_freq` is False.
pull/1168/head
Ramiro Leal-Cavazos 2022-08-15 09:43:55 -07:00 committed by GitHub
parent c935795086
commit 9d6ee48661
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 3 deletions

View File

@ -218,11 +218,23 @@ public:
Value weight = adaptor.weight();
Value indices = adaptor.indices();
Value offsets = adaptor.offsets();
Value scaleGradByFreq = adaptor.scale_grad_by_freq();
Value scaleGradByFreq = op.scale_grad_by_freq();
Value mode = op.mode();
Value sparse = op.sparse();
Value includeLastOffset = op.include_last_offset();
bool scaleGradByFreqBool;
if (!matchPattern(scaleGradByFreq,
m_TorchConstantBool(&scaleGradByFreqBool))) {
return rewriter.notifyMatchFailure(
op, "scale_grad_by_freq is expected to be a constant boolean value.");
}
if (scaleGradByFreqBool) {
return rewriter.notifyMatchFailure(
op, "Unimplemented: scale_grad_by_freq=True.");
}
int64_t modeInt;
if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) {
return rewriter.notifyMatchFailure(

View File

@ -2476,8 +2476,6 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_EmbeddingBagOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value weight = op.weight();
Value indices = op.indices();
Value offsets = op.offsets();