From 5620fe030e75215f65bca2b2ac5a72b87a3adf9d Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Mon, 4 Apr 2022 10:57:49 -0700
Subject: [PATCH] Add 1D, weight, and reduction support to nll_loss_backward (#729)

This commit adds the following support to the op `nll_loss_backward`:
- `input` tensor can be rank-1
- `weight` parameter
- `reduction` parameter
- `target`, `grad_output`, `total_weight` can be rank-0
- Checks that input tensors are of the expected type
---
 .../TorchToLinalg/Uncategorized.cpp           | 198 +++++++-----
 .../test_suite/nll_loss.py                    | 326 ++++++++++++++++++
 2 files changed, 450 insertions(+), 74 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 87f99a78c..920d9710c 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -28,6 +28,13 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
+// Check if a ranked tensor has the specified element type.
+template <typename elementType> static bool hasElementType(Value tensor) {
+  auto tensorType = tensor.getType().cast<RankedTensorType>();
+  Type tensorElementType = tensorType.getElementType();
+  return tensorElementType.isa<elementType>();
+}
+
 static Value createElementwiseLinalgGeneric(
     OpBuilder &b, Location loc, ValueRange tensorOperands,
     Type resultElementType,
@@ -1441,11 +1448,6 @@ public:
 };
 } // namespace
 
-// Given `grad_output`, `input`, `target`, `nll_loss_backward` is given by:
-//   for i in range(0, len(input[0])):
-//     for j in range(0, len(input[1])):
-//       nll_loss_backward[i][j] = (j == target[i]) ? -grad_output[i] : 0
-// TODO: `weight` and `reduction` operands are still to be taken care of.
 namespace {
 class ConvertAtenNllLossBackwardOp
     : public OpConversionPattern<AtenNllLossBackwardOp> {
@@ -1456,89 +1458,137 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
+    Location loc = op->getLoc();
+    Value gradOutput = adaptor.grad_output();
     Value input = adaptor.self();
     Value target = adaptor.target();
     Value weight = adaptor.weight();
-    Value gradOutput = adaptor.grad_output();
+    bool weightIsNone = op.weight().getType().isa<Torch::NoneType>();
+    Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.ignore_index());
+    Value totalWeight = adaptor.total_weight();
+
+    auto inputType = input.getType().cast<RankedTensorType>();
+    int inputRank = inputType.getRank();
+    auto gradOutputType = gradOutput.getType().cast<RankedTensorType>();
+    Type resultElementType = gradOutputType.getElementType();
 
     int64_t reduction;
     if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
       return rewriter.notifyMatchFailure(op, "dim must be constant");
 
-    // TODO: Handle reduction.
-    if (reduction != torch_upstream::Reduction::None)
+    if (!hasElementType<mlir::FloatType>(gradOutput) ||
+        !hasElementType<mlir::FloatType>(totalWeight) ||
+        (!weightIsNone && !hasElementType<mlir::FloatType>(weight))) {
       return rewriter.notifyMatchFailure(
-          op, "reduction along dimensions is not supported.");
-
-    // TODO: Incorporate the weight argument.
-    if (!weight.getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(
-          op, "Unimplemented, the weight operand is not incorporated.");
-
-    Value ignoreIndex = adaptor.ignore_index();
-    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
-
-    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
-    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
-
-    // TODO: Cases with targetRank != 1 where `Mean` or `Sum` reduction is
-    // required.
-    if (inputRank != 2 || targetRank != 1) {
-      return rewriter.notifyMatchFailure(
-          op, "expected input and target to be rank 2 and 1 respectively");
+          op, "`gradOutput`, `weight`, and `totalWeight` must be tensors of "
+              "type float");
     }
+
+    if (!hasElementType<mlir::IntegerType>(target)) {
+      return rewriter.notifyMatchFailure(
+          op, "`target` must be a tensor of integer type");
+    }
+
+    auto outputSize = getTensorSizes(rewriter, loc, input);
+    Value gradInputTensor =
+        createZeroInitTensor(rewriter, loc, outputSize, resultElementType);
+
+    auto getAffineMapForSingleElementTensor = [&](Value tensor) {
+      auto tensorType = tensor.getType().cast<RankedTensorType>();
+      SmallVector<AffineExpr> affineExprs(tensorType.getRank(),
+                                          rewriter.getAffineConstantExpr(0));
+      return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs,
+                            op->getContext());
+    };
+
+    AffineMap gradOutMap = AffineMap::get(inputRank, /*symbolCount=*/0,
+                                          rewriter.getAffineDimExpr(0));
+    if (reduction != torch_upstream::Reduction::None || inputRank == 1)
+      gradOutMap = getAffineMapForSingleElementTensor(gradOutput);
+    AffineMap targetMap = AffineMap::get(inputRank, /*symbolCount=*/0,
+                                         rewriter.getAffineDimExpr(0));
+    if (inputRank == 1)
+      targetMap = getAffineMapForSingleElementTensor(target);
+    AffineMap totalWeightMap = getAffineMapForSingleElementTensor(totalWeight);
+    AffineMap resultMap = rewriter.getMultiDimIdentityMap(inputRank);
+
+    SmallVector<AffineMap> indexingMaps{gradOutMap, targetMap, totalWeightMap,
+                                        resultMap};
+    SmallVector<StringRef> iteratorTypes(inputRank,
+                                         getParallelIteratorTypeName());
+
+    // The code generation is equivalent to the following pseudo-code:
+    //
+    // for batch_index in range(input.size(0)):
+    //   for class_index in range(input.size(1)):
+    //     target_elem = target[batch_index]
+    //
+    //     if reduction == None:
+    //       grad_out_elem = grad_output[batch_index]
+    //     else:
+    //       grad_out_elem = grad_output[0]
+    //
+    //     if reduction == Mean:
+    //       total_weight_elem = total_weight[0]
+    //       grad_out_elem /= total_weight_elem
+    //
+    //     weight_elem = weight[target_elem] if weight != None else 1
+    //
+    //     if target_elem != class_index or target_elem == ignore_index:
+    //       grad_input_elem = 0
+    //     else:
+    //       grad_input_elem = -weight_elem * grad_out_elem
+    //     grad_input[batch_index, class_index] = grad_input_elem
+    //
+    // NOTE: In the case of no batch dimension, `batch_index` is effectively
+    // always zero.
+    Value gradInput =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, gradInputTensor.getType(),
+                ValueRange{gradOutput, target, totalWeight}, gradInputTensor,
+                indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value gradOutElem = args[0];
+                  Value targetElem = castIntToIndex(b, loc, args[1]);
+                  Value totalWeightElem = args[2];
+                  Value classIndex =
+                      b.create<linalg::IndexOp>(loc, inputRank - 1);
+
+                  if (reduction == torch_upstream::Reduction::Mean) {
+                    gradOutElem = b.create<arith::DivFOp>(loc, gradOutElem,
+                                                          totalWeightElem);
+                  }
+
+                  Value negGradOutElem =
+                      b.create<arith::NegFOp>(loc, gradOutElem);
+                  Value weightElem = getConstant(b, loc, 1, resultElementType);
+                  if (!weightIsNone) {
+                    weightElem =
+                        b.create<tensor::ExtractOp>(loc, weight, targetElem);
+                  }
+                  Value weightedNegGradOutElem =
+                      b.create<arith::MulFOp>(loc, weightElem, negGradOutElem);
+
+                  Value targetNeqClassIndex = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::ne, targetElem, classIndex);
+                  Value targetEqIgnoreIndex = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::eq, targetElem, ignoreIndex);
+                  Value gradInputIsZero = b.create<arith::OrIOp>(
+                      loc, targetNeqClassIndex, targetEqIgnoreIndex);
+
+                  Value zero = getConstant(b, loc, 0, resultElementType);
+                  Value gradInElem = b.create<arith::SelectOp>(
+                      loc, gradInputIsZero, zero, weightedNegGradOutElem);
+                  b.create<linalg::YieldOp>(loc, gradInElem);
+                })
+            ->getResult(0);
+
     RankedTensorType resultType = getTypeConverter()
                                       ->convertType(op->getResult(0).getType())
                                       .cast<RankedTensorType>();
-
-    Type elementType = resultType.getElementType();
-
-    // Given there is no reduction `grad_input` size is equal to `input` size.
-    auto outputSize = getTensorSizes(rewriter, loc, input);
-    Value initTensor0 =
-        createZeroInitTensor(rewriter, loc, outputSize, elementType);
-    Value zeroVal = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(elementType));
-
-    SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr(0)};
-    SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr(0),
-                                       rewriter.getAffineDimExpr(1)};
-    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName(),
-                                         getParallelIteratorTypeName()};
-    auto indexingMaps =
-        AffineMap::inferFromExprList({targetExpr, targetExpr, resultExpr});
-    Value finalRes =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, initTensor0.getType(), ValueRange{target, gradOutput},
-                initTensor0,
-                /*indexingMaps=*/indexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value indTarget = rewriter.create<arith::IndexCastOp>(
-                      loc, rewriter.getIndexType(), args[0]);
-                  Value indJ = rewriter.create<linalg::IndexOp>(loc, 1);
-
-                  // The final result is given by:
-                  // grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0
-                  Value cmpEq = rewriter.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::eq, indJ, indTarget);
-
-                  // The target index shouldn't be equal to `ignoreIndex`.
-                  Value cmpNe = rewriter.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::ne, ignoreIndexVal, indTarget);
-                  Value finalPredicate =
-                      rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNe);
-                  Value negate =
-                      rewriter.create<arith::NegFOp>(loc, elementType, args[1]);
-                  Value selectFinal = rewriter.create<arith::SelectOp>(
-                      loc, finalPredicate, negate, zeroVal);
-                  b.create<linalg::YieldOp>(loc, selectFinal);
-                })
-            .getResult(0);
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, gradInput);
     return success();
   }
 };
diff --git a/python/torch_mlir_e2e_test/test_suite/nll_loss.py b/python/torch_mlir_e2e_test/test_suite/nll_loss.py
index 4558fdb68..edbb8f444 100644
--- a/python/torch_mlir_e2e_test/test_suite/nll_loss.py
+++ b/python/torch_mlir_e2e_test/test_suite/nll_loss.py
@@ -162,6 +162,37 @@ def NllLossModuleBackward_basic(module, tu: TestUtils):
                    torch.tensor(3.))
 
 
+class NllLossModule_backwardWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardWeight())
+def NllLossModuleBackwardWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.rand(4), torch.tensor(3.))
+
+
+
 class NllLossModule_backward_ignore_index(torch.nn.Module):
 
     def __init__(self):
@@ -190,3 +221,298 @@ class NllLossModule_backward_ignore_index(torch.nn.Module):
 def NllLossModuleBackward_ignore_index(module, tu: TestUtils):
     module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
                    torch.tensor(3.))
+
+
+class NllLossModule_backwardMean(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardMean())
+def NllLossModuleBackwardMean_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backwardMeanWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardMeanWeight())
+def NllLossModuleBackwardMeanWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.rand(4), torch.tensor(3.))
+
+
+class NllLossModule_backwardSum(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardSum())
+def NllLossModuleBackwardSum_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backwardSumWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backwardSumWeight())
+def NllLossModuleBackwardSumWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.rand(4), torch.tensor(3.))
+
+
+class NllLossModule_backward1D(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1D())
+def NllLossModuleBackward1D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward1DWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DWeight())
+def NllLossModuleBackward1DWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.rand(3), torch.tensor(3.))
+
+
+class NllLossModule_backward1DMean(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DMean())
+def NllLossModuleBackward1DMean_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward1DMeanWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=1,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DMeanWeight())
+def NllLossModuleBackward1DMeanWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.rand(3), torch.tensor(3.))
+
+
+class NllLossModule_backward1DSum(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DSum())
+def NllLossModuleBackward1DSum_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward1DSumWeight(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, weight, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output,
+                                                input,
+                                                target=target,
+                                                weight=weight,
+                                                reduction=2,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward1DSumWeight())
+def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]),
+                   torch.rand(3), torch.tensor(3.))
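
For reference, the pseudo-code comment in the new lowering corresponds to the plain-NumPy sketch below. This is an illustrative assumption of this note rather than code from the patch; the helper name nll_loss_backward_ref does not exist in torch-mlir or PyTorch, and `reduction` follows the PyTorch encoding (0 = none, 1 = mean, 2 = sum).

    # Illustrative reference only -- not part of the patch. Assumes a 2-D
    # `input` of shape (batch, classes); the rank-1 case behaves as batch == 1.
    import numpy as np

    def nll_loss_backward_ref(grad_output, input, target, weight,
                              reduction, ignore_index, total_weight):
        grad_input = np.zeros_like(input)
        for batch_index in range(input.shape[0]):
            target_elem = int(target[batch_index])
            if target_elem == ignore_index:
                continue
            # With a reduction, grad_output carries a single element.
            grad_out_elem = (grad_output[batch_index] if reduction == 0
                             else grad_output.ravel()[0])
            if reduction == 1:  # mean
                grad_out_elem = grad_out_elem / float(total_weight)
            weight_elem = weight[target_elem] if weight is not None else 1.0
            # Only the class column selected by `target` receives a gradient.
            grad_input[batch_index, target_elem] = -weight_elem * grad_out_elem
        return grad_input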
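The new e2e tests run through the torch-mlir test harness, which compares the compiled module against eager PyTorch. As a standalone sanity check of the same operator semantics (illustrative only; the tensor values are arbitrary and mirror NllLossModuleBackwardWeight_basic), torch.ops.aten.nll_loss_backward can be evaluated directly in eager mode and compared with the hand-written formula for the no-reduction case.

    import torch

    grad_output = torch.rand(3)
    input = torch.rand(3, 4)
    target = torch.tensor([2, 3, 0])
    weight = torch.rand(4)
    total_weight = torch.tensor(3.)

    grad_input = torch.ops.aten.nll_loss_backward(grad_output,
                                                  input,
                                                  target=target,
                                                  weight=weight,
                                                  reduction=0,
                                                  ignore_index=10,
                                                  total_weight=total_weight)

    # With reduction=0 and no target equal to ignore_index, the gradient is
    # nonzero only at (i, target[i]), where it equals
    # -weight[target[i]] * grad_output[i].
    expected = torch.zeros_like(input)
    for i, t in enumerate(target):
        expected[i, t] = -weight[t] * grad_output[i]

    print(torch.allclose(grad_input, expected))  # expected to print True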