mirror of https://github.com/llvm/torch-mlir
update
parent
21005a7d63
commit
92e0098a00
|
@ -8896,6 +8896,12 @@ public:
|
|||
op, "Unimplemented: unranked target tensor");
|
||||
unsigned targetRank = maybeRank.value();
|
||||
|
||||
Value reduction = op.getReduction();
|
||||
int64_t reductionInt;
|
||||
if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"reduction should be a constant int!");
|
||||
}
|
||||
// When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d
|
||||
// of the form [minibatch] the cross entropy loss decomposes to the
|
||||
// combination of softmax and nll loss as follows:
|
||||
|
|
Loading…
Reference in New Issue