[MLIR][TORCH] Add support for the total_weight for aten.nll_loss_forward op

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
pull/2188/head
Vivek Khandelwal 2023-05-23 18:14:57 +05:30
parent 552887783a
commit 959f4f48d5
2 changed files with 64 additions and 11 deletions

View File

@ -1258,13 +1258,14 @@ public:
b.create<linalg::YieldOp>(loc, selectFinal);
});
llvm::iota_range<int64_t> dimsToReduce(0, targetRank,
/*inclusive=*/false);
DenseSet<int64_t> dimSet(dimsToReduce.begin(), dimsToReduce.end());
if (reduction == torch_upstream::Reduction::Sum ||
reduction == torch_upstream::Reduction::Mean) {
Value numOfElems = getTensorSize(rewriter, loc, finalRes);
numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType);
llvm::iota_range<int64_t> dimsToReduce(0, targetRank,
/*inclusive=*/false);
DenseSet<int64_t> dimSet(dimsToReduce.begin(), dimsToReduce.end());
auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet};
finalRes = torch_to_linalg::createReductionLinalgGeneric(
@ -1280,9 +1281,61 @@ public:
});
}
// TODO: Update the second result tensor.
Value weightUpdated = createZeroInitTensor(rewriter, loc, {}, elementType);
rewriter.replaceOp(op, {finalRes, weightUpdated});
// The implementation for the `total_weight` has been adopted from here:
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossNLL.cpp#L154-L294
// As per the ref link, the `total_weight` value when the `weight` is
// `None`, is equal to `total_weight = batch_size - num_ignored_index`,
// where `batch_size` is equal to `target.shape[0]` when rank(target) > 0,
// otherwise 1. The value `num_ignored_index` is the number of elements of
// the `target` tensors that have been ignored.
if (reduction == torch_upstream::Reduction::None && inputRank == 2) {
Value totalWeight = createZeroInitTensor(rewriter, loc, {}, elementType);
rewriter.replaceOp(op, {finalRes, totalWeight});
return success();
}
Value numIgnoredIndex;
if (targetRank == 0) {
Value targetVal = rewriter.create<tensor::ExtractOp>(loc, target);
numIgnoredIndex = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, targetVal, ignoreIndex);
numIgnoredIndex = convertScalarToDtype(rewriter, loc, numIgnoredIndex,
ignoreIndex.getType());
} else {
Value zeroCstInt = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(ignoreIndex.getType()));
auto opInfo =
torch_to_linalg::ReductionOpInfo{/*keepDim=*/false, target, dimSet};
numIgnoredIndex = torch_to_linalg::createReductionLinalgGeneric(
rewriter, loc, opInfo,
/*initElem=*/zeroCstInt,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value targetVal = args[0];
Value accumulator = args[1];
Value result = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, targetVal, ignoreIndex);
result = b.create<arith::AddIOp>(
loc,
convertScalarToDtype(rewriter, loc, result,
ignoreIndex.getType()),
accumulator);
b.create<linalg::YieldOp>(loc, result);
});
numIgnoredIndex =
rewriter.create<tensor::ExtractOp>(loc, numIgnoredIndex);
}
Value numtargetElems = getTensorSize(rewriter, loc, target);
Value totalWeightVal =
rewriter.create<arith::SubIOp>(loc, numtargetElems, numIgnoredIndex);
Value totalWeight = createInitTensor(
rewriter, loc, {}, elementType,
convertScalarToDtype(rewriter, loc, totalWeightVal, elementType));
rewriter.replaceOp(op, {finalRes, totalWeight});
return success();
}
};

View File

@ -29,7 +29,7 @@ class NllLossModule(torch.nn.Module):
target=y,
weight=None,
reduction=0,
ignore_index=2)[0]
ignore_index=2)
@register_test_case(module_factory=lambda: NllLossModule())
@ -53,7 +53,7 @@ class NllLossModule_mean(torch.nn.Module):
target=y,
weight=None,
reduction=1,
ignore_index=2)[0]
ignore_index=2)
@register_test_case(module_factory=lambda: NllLossModule_mean())
@ -77,7 +77,7 @@ class NllLossModule_sum(torch.nn.Module):
target=y,
weight=None,
reduction=2,
ignore_index=2)[0]
ignore_index=2)
@register_test_case(module_factory=lambda: NllLossModule_sum())
@ -101,7 +101,7 @@ class NllLossModule_1D(torch.nn.Module):
target=y,
weight=None,
reduction=0,
ignore_index=2)[0]
ignore_index=2)
@register_test_case(module_factory=lambda: NllLossModule_1D())
@ -126,7 +126,7 @@ class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
target=y,
weight=None,
reduction=0,
ignore_index=10)[0]
ignore_index=10)
@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())