mirror of https://github.com/llvm/torch-mlir
Fix unused-variables warnings about EmbeddingBag ops (#1220)
According to the documentation for `torch.embedding_bag` (https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding_bag.html), the default value for `scale_grad_by_freq` is False.pull/1168/head
parent
c935795086
commit
9d6ee48661
|
@ -218,11 +218,23 @@ public:
|
|||
Value weight = adaptor.weight();
|
||||
Value indices = adaptor.indices();
|
||||
Value offsets = adaptor.offsets();
|
||||
Value scaleGradByFreq = adaptor.scale_grad_by_freq();
|
||||
Value scaleGradByFreq = op.scale_grad_by_freq();
|
||||
Value mode = op.mode();
|
||||
Value sparse = op.sparse();
|
||||
Value includeLastOffset = op.include_last_offset();
|
||||
|
||||
bool scaleGradByFreqBool;
|
||||
if (!matchPattern(scaleGradByFreq,
|
||||
m_TorchConstantBool(&scaleGradByFreqBool))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "scale_grad_by_freq is expected to be a constant boolean value.");
|
||||
}
|
||||
|
||||
if (scaleGradByFreqBool) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: scale_grad_by_freq=True.");
|
||||
}
|
||||
|
||||
int64_t modeInt;
|
||||
if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
|
|
@ -2476,8 +2476,6 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Aten_EmbeddingBagOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value weight = op.weight();
|
||||
Value indices = op.indices();
|
||||
Value offsets = op.offsets();
|
||||
|
|
Loading…
Reference in New Issue