diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index d509bf037..fd81ad5fd 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -218,11 +218,23 @@ public: Value weight = adaptor.weight(); Value indices = adaptor.indices(); Value offsets = adaptor.offsets(); - Value scaleGradByFreq = adaptor.scale_grad_by_freq(); + Value scaleGradByFreq = op.scale_grad_by_freq(); Value mode = op.mode(); Value sparse = op.sparse(); Value includeLastOffset = op.include_last_offset(); + bool scaleGradByFreqBool; + if (!matchPattern(scaleGradByFreq, + m_TorchConstantBool(&scaleGradByFreqBool))) { + return rewriter.notifyMatchFailure( + op, "scale_grad_by_freq is expected to be a constant boolean value."); + } + + if (scaleGradByFreqBool) { + return rewriter.notifyMatchFailure( + op, "Unimplemented: scale_grad_by_freq=True."); + } + int64_t modeInt; if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) { return rewriter.notifyMatchFailure( diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 4a0ff065f..60cf7c5e7 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2476,8 +2476,6 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_EmbeddingBagOp op, PatternRewriter &rewriter) const override { - - Location loc = op.getLoc(); Value weight = op.weight(); Value indices = op.indices(); Value offsets = op.offsets();