diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index fdb9e5fec..7a8718846 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -8896,6 +8896,12 @@ public: op, "Unimplemented: unranked target tensor"); unsigned targetRank = maybeRank.value(); + Value reduction = op.getReduction(); + int64_t reductionInt; + if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) { + return rewriter.notifyMatchFailure(op, + "reduction should be a constant int!"); + } // When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d // of the form [minibatch] the cross entropy loss decomposes to the // combination of softmax and nll loss as follows: