mirror of https://github.com/llvm/torch-mlir
[onnx][torch] Fix `onnx.SoftmaxCrossEntropyLoss` for ignore index (#3585)
There were two issues related to `ignore_index` being set: (1) the onnx-to-linalg pass was not reading the value correctly, and (2) the mean pass was not considering the `ignore_index` value. For (2), when taking the mean we need to know how many of the values were actually included in the sum, and therefore we cannot divide by the total number of elements. Adding a summation that counts only the elements whose target is not `ignore_index`, and dividing by that count, should correct this issue. pull/3587/head
parent
22cd4441e7
commit
306ed62edd
|
@ -2772,7 +2772,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
Value scores, labels, weight;
|
||||
if (binder.tensorOperandAtIndex(scores, 0) ||
|
||||
binder.tensorOperandAtIndex(labels, 1) ||
|
||||
binder.s64IntegerAttr(ignoreIndex, "ignore_index ", -100) ||
|
||||
binder.s64IntegerAttr(ignoreIndex, "ignore_index", -100) ||
|
||||
binder.customOpNameStringAttr(reduction, "reduction", "mean") ||
|
||||
binder.tensorResultTypeAtIndex(resultType, 0)) {
|
||||
return failure();
|
||||
|
|
|
@ -1649,7 +1649,27 @@ public:
|
|||
|
||||
if (reduction == torch_upstream::Reduction::Sum ||
|
||||
reduction == torch_upstream::Reduction::Mean) {
|
||||
Value numOfElems = getTensorSize(rewriter, loc, finalRes);
|
||||
|
||||
Value zeroIVal = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(rewriter.getI32Type()));
|
||||
auto countInfo = torch_to_linalg::ReductionOpInfo{false, target, dimSet};
|
||||
Value numOfElems = torch_to_linalg::createReductionLinalgGeneric(
|
||||
rewriter, loc, countInfo,
|
||||
/*initElem=*/zeroIVal,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value targetVal = args[0];
|
||||
Value indTarget = rewriter.create<arith::IndexCastOp>(
|
||||
loc, rewriter.getIndexType(), targetVal);
|
||||
Value cmpEq = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::ne, indTarget, ignoreIndexVal);
|
||||
cmpEq = rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(),
|
||||
cmpEq);
|
||||
Value add = rewriter.create<arith::AddIOp>(loc, args[1], cmpEq);
|
||||
rewriter.create<linalg::YieldOp>(loc, add);
|
||||
});
|
||||
|
||||
numOfElems = rewriter.create<tensor::ExtractOp>(
|
||||
loc, rewriter.getI32Type(), numOfElems, ArrayRef<Value>{});
|
||||
numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType);
|
||||
|
||||
auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet};
|
||||
|
|
Loading…
Reference in New Issue