mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for the total_weight for aten.nll_loss_forward op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/2188/head
parent 552887783a
commit 959f4f48d5
@@ -1258,13 +1258,14 @@ public:
              b.create<linalg::YieldOp>(loc, selectFinal);
            });

    llvm::iota_range<int64_t> dimsToReduce(0, targetRank,
                                           /*inclusive=*/false);
    DenseSet<int64_t> dimSet(dimsToReduce.begin(), dimsToReduce.end());

    if (reduction == torch_upstream::Reduction::Sum ||
        reduction == torch_upstream::Reduction::Mean) {
      Value numOfElems = getTensorSize(rewriter, loc, finalRes);
      numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType);
      llvm::iota_range<int64_t> dimsToReduce(0, targetRank,
                                             /*inclusive=*/false);
      DenseSet<int64_t> dimSet(dimsToReduce.begin(), dimsToReduce.end());

      auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet};
      finalRes = torch_to_linalg::createReductionLinalgGeneric(
@@ -1280,9 +1281,61 @@ public:
          });
    }

    // TODO: Update the second result tensor.
    Value weightUpdated = createZeroInitTensor(rewriter, loc, {}, elementType);
    rewriter.replaceOp(op, {finalRes, weightUpdated});
    // The implementation for the `total_weight` has been adapted from here:
    // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossNLL.cpp#L154-L294
    // As per the ref link, the `total_weight` value when the `weight` is
    // `None`, is equal to `total_weight = batch_size - num_ignored_index`,
    // where `batch_size` is equal to `target.shape[0]` when rank(target) > 0,
    // otherwise 1. The value `num_ignored_index` is the number of elements of
    // the `target` tensor that have been ignored.

    if (reduction == torch_upstream::Reduction::None && inputRank == 2) {
      Value totalWeight = createZeroInitTensor(rewriter, loc, {}, elementType);
      rewriter.replaceOp(op, {finalRes, totalWeight});
      return success();
    }

    Value numIgnoredIndex;
    if (targetRank == 0) {
      Value targetVal = rewriter.create<tensor::ExtractOp>(loc, target);
      numIgnoredIndex = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, targetVal, ignoreIndex);
      numIgnoredIndex = convertScalarToDtype(rewriter, loc, numIgnoredIndex,
                                             ignoreIndex.getType());
    } else {
      Value zeroCstInt = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(ignoreIndex.getType()));

      auto opInfo =
          torch_to_linalg::ReductionOpInfo{/*keepDim=*/false, target, dimSet};
      numIgnoredIndex = torch_to_linalg::createReductionLinalgGeneric(
          rewriter, loc, opInfo,
          /*initElem=*/zeroCstInt,
          [&](OpBuilder &b, Location loc, ValueRange args) {
            Value targetVal = args[0];
            Value accumulator = args[1];
            Value result = b.create<arith::CmpIOp>(
                loc, arith::CmpIPredicate::eq, targetVal, ignoreIndex);
            result = b.create<arith::AddIOp>(
                loc,
                convertScalarToDtype(rewriter, loc, result,
                                     ignoreIndex.getType()),
                accumulator);
            b.create<linalg::YieldOp>(loc, result);
          });

      numIgnoredIndex =
          rewriter.create<tensor::ExtractOp>(loc, numIgnoredIndex);
    }

    Value numtargetElems = getTensorSize(rewriter, loc, target);
    Value totalWeightVal =
        rewriter.create<arith::SubIOp>(loc, numtargetElems, numIgnoredIndex);
    Value totalWeight = createInitTensor(
        rewriter, loc, {}, elementType,
        convertScalarToDtype(rewriter, loc, totalWeightVal, elementType));

    rewriter.replaceOp(op, {finalRes, totalWeight});
    return success();
  }
};
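For reference, the semantics described in the comment above can be illustrated with a small PyTorch-style sketch. This is an expository check only, not part of the lowering; the helper name expected_total_weight and the explicit tensor ops are assumptions made for illustration.

# Illustrative sketch (assumed helper, not part of the patch): computes the
# reference total_weight for the weight=None case, following the comment above.
import torch

def expected_total_weight(target: torch.Tensor, ignore_index: int) -> torch.Tensor:
    # batch_size is target.shape[0] when rank(target) > 0, otherwise 1.
    batch_size = target.shape[0] if target.dim() > 0 else 1
    # num_ignored_index counts the target elements equal to ignore_index.
    num_ignored = (target == ignore_index).sum()
    # total_weight = batch_size - num_ignored_index, cast to the element type.
    return (batch_size - num_ignored).to(torch.float32)

# Example: two of the four targets hit ignore_index=2, so total_weight is 2.0.
print(expected_total_weight(torch.tensor([0, 2, 1, 2]), ignore_index=2))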
@@ -29,7 +29,7 @@ class NllLossModule(torch.nn.Module):
            target=y,
            weight=None,
            reduction=0,
            ignore_index=2)[0]
            ignore_index=2)


@register_test_case(module_factory=lambda: NllLossModule())
@@ -53,7 +53,7 @@ class NllLossModule_mean(torch.nn.Module):
            target=y,
            weight=None,
            reduction=1,
            ignore_index=2)[0]
            ignore_index=2)


@register_test_case(module_factory=lambda: NllLossModule_mean())
@@ -77,7 +77,7 @@ class NllLossModule_sum(torch.nn.Module):
            target=y,
            weight=None,
            reduction=2,
            ignore_index=2)[0]
            ignore_index=2)


@register_test_case(module_factory=lambda: NllLossModule_sum())
@@ -101,7 +101,7 @@ class NllLossModule_1D(torch.nn.Module):
            target=y,
            weight=None,
            reduction=0,
            ignore_index=2)[0]
            ignore_index=2)


@register_test_case(module_factory=lambda: NllLossModule_1D())
@@ -126,7 +126,7 @@ class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
            target=y,
            weight=None,
            reduction=0,
            ignore_index=10)[0]
            ignore_index=10)


@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
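On the test side, the only change in each module is dropping the trailing [0], so the e2e tests now exercise both results of aten.nll_loss_forward: the loss output and the newly computed total_weight. Below is a minimal sketch of an updated module; the class name and the forward signature are assumptions, since only the final line of each call appears in the hunks.

# Sketch of an updated test module; the surrounding method and full call are
# reconstructed assumptions, only the kwargs shown in the diff are verbatim.
import torch

class NllLossModuleSketch(torch.nn.Module):
    def forward(self, x, y):
        # Returns the (output, total_weight) tuple instead of indexing [0].
        return torch.ops.aten.nll_loss_forward(x,
                                                target=y,
                                                weight=None,
                                                reduction=0,
                                                ignore_index=2)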