mirror of https://github.com/llvm/torch-mlir
update
parent
21005a7d63
commit
92e0098a00
|
@ -8896,6 +8896,12 @@ public:
|
||||||
op, "Unimplemented: unranked target tensor");
|
op, "Unimplemented: unranked target tensor");
|
||||||
unsigned targetRank = maybeRank.value();
|
unsigned targetRank = maybeRank.value();
|
||||||
|
|
||||||
|
Value reduction = op.getReduction();
|
||||||
|
int64_t reductionInt;
|
||||||
|
if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"reduction should be a constant int!");
|
||||||
|
}
|
||||||
// When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d
|
// When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d
|
||||||
// of the form [minibatch] the cross entropy loss decomposes to the
|
// of the form [minibatch] the cross entropy loss decomposes to the
|
||||||
// combination of softmax and nll loss as follows:
|
// combination of softmax and nll loss as follows:
|
||||||
|
|
Loading…
Reference in New Issue