pull/3839/head
yyp0 2024-10-31 12:13:48 +08:00
parent 21005a7d63
commit 92e0098a00
1 changed files with 6 additions and 0 deletions

View File

@ -8896,6 +8896,12 @@ public:
op, "Unimplemented: unranked target tensor");
unsigned targetRank = maybeRank.value();
Value reduction = op.getReduction();
int64_t reductionInt;
if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
return rewriter.notifyMatchFailure(op,
"reduction should be a constant int!");
}
// When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d
// of the form [minibatch] the cross entropy loss decomposes to the
// combination of softmax and nll loss as follows: