[onnx][torch] Fix `onnx.SoftmaxCrossEntropyLoss` for ignore index (#3585)

There were two issues related to `ignore_index` being set

(1) the onnx-to-linalg pass as not reading the value correctly (2) the
mean pass was not considering the `ignore_index` value

For (2) when taking the mean we need to know how many of the values were
considered in the sum and therefore we cannot divide by the total number
of elements. Adding a summation across the total number should correct
this issue.
pull/3587/head
Rob Suderman 2024-08-02 09:00:56 -07:00 committed by GitHub
parent 22cd4441e7
commit 306ed62edd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 2 deletions

View File

@ -1649,7 +1649,27 @@ public:
if (reduction == torch_upstream::Reduction::Sum ||
reduction == torch_upstream::Reduction::Mean) {
Value numOfElems = getTensorSize(rewriter, loc, finalRes);
Value zeroIVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(rewriter.getI32Type()));
auto countInfo = torch_to_linalg::ReductionOpInfo{false, target, dimSet};
Value numOfElems = torch_to_linalg::createReductionLinalgGeneric(
rewriter, loc, countInfo,
/*initElem=*/zeroIVal,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value targetVal = args[0];
Value indTarget = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), targetVal);
Value cmpEq = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne, indTarget, ignoreIndexVal);
cmpEq = rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(),
cmpEq);
Value add = rewriter.create<arith::AddIOp>(loc, args[1], cmpEq);
rewriter.create<linalg::YieldOp>(loc, add);
});
numOfElems = rewriter.create<tensor::ExtractOp>(
loc, rewriter.getI32Type(), numOfElems, ArrayRef<Value>{});
numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType);
auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet};