From 306ed62eddd3b806386d6495fb248bf4a9849802 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 2 Aug 2024 09:00:56 -0700 Subject: [PATCH] [onnx][torch] Fix `onnx.SoftmaxCrossEntropyLoss` for ignore index (#3585) There were two issues related to `ignore_index` being set (1) the onnx-to-linalg pass as not reading the value correctly (2) the mean pass was not considering the `ignore_index` value For (2) when taking the mean we need to know how many of the values were considered in the sum and therefore we cannot divide by the total number of elements. Adding a summation across the total number should correct this issue. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 2 +- .../TorchToLinalg/Uncategorized.cpp | 22 ++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 9ef165e77..2fa57f18c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2772,7 +2772,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value scores, labels, weight; if (binder.tensorOperandAtIndex(scores, 0) || binder.tensorOperandAtIndex(labels, 1) || - binder.s64IntegerAttr(ignoreIndex, "ignore_index ", -100) || + binder.s64IntegerAttr(ignoreIndex, "ignore_index", -100) || binder.customOpNameStringAttr(reduction, "reduction", "mean") || binder.tensorResultTypeAtIndex(resultType, 0)) { return failure(); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 74f64f3d2..211b1045e 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1649,7 +1649,27 @@ public: if (reduction == torch_upstream::Reduction::Sum || reduction == torch_upstream::Reduction::Mean) { - Value numOfElems = getTensorSize(rewriter, loc, finalRes); + + Value zeroIVal = rewriter.create( + loc, rewriter.getZeroAttr(rewriter.getI32Type())); + auto countInfo = torch_to_linalg::ReductionOpInfo{false, target, dimSet}; + Value numOfElems = torch_to_linalg::createReductionLinalgGeneric( + rewriter, loc, countInfo, + /*initElem=*/zeroIVal, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value targetVal = args[0]; + Value indTarget = rewriter.create( + loc, rewriter.getIndexType(), targetVal); + Value cmpEq = rewriter.create( + loc, arith::CmpIPredicate::ne, indTarget, ignoreIndexVal); + cmpEq = rewriter.create(loc, rewriter.getI32Type(), + cmpEq); + Value add = rewriter.create(loc, args[1], cmpEq); + rewriter.create(loc, add); + }); + + numOfElems = rewriter.create( + loc, rewriter.getI32Type(), numOfElems, ArrayRef{}); numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType); auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet};