From 977b1b03ea7dfb68f0971f1f7313ec6856236380 Mon Sep 17 00:00:00 2001
From: Prashant Kumar
Date: Tue, 30 Nov 2021 18:35:33 +0530
Subject: [PATCH] Add aten::nll_loss_forward op lowering.

The op lowering has been added as a part of the `convert-torch-to-linalg`
pass. It takes care of `ignore_index`, but the `weight` and `reduction`
operands are still to be accounted for.

Signed-off-by: Prashant Kumar
---
 e2e_testing/torchscript/main.py               |  1 +
 e2e_testing/torchscript/nll_loss.py           | 62 ++++++++++++
 .../Dialect/Torch/IR/GeneratedAtenOps.td      | 19 ++++
 .../TorchToLinalg/TorchToLinalg.cpp           | 98 +++++++++++++++++++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  | 39 ++++++++
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 6 files changed, 220 insertions(+)
 create mode 100644 e2e_testing/torchscript/nll_loss.py

diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py
index 39ea40243..61b9ded03 100644
--- a/e2e_testing/torchscript/main.py
+++ b/e2e_testing/torchscript/main.py
@@ -43,6 +43,7 @@ from . import view
 from . import scalar
 from . import squeeze
 from . import slice_like
+from . import nll_loss
 
 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
diff --git a/e2e_testing/torchscript/nll_loss.py b/e2e_testing/torchscript/nll_loss.py
new file mode 100644
index 000000000..6afcb5481
--- /dev/null
+++ b/e2e_testing/torchscript/nll_loss.py
@@ -0,0 +1,62 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+
+class NllLossModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+    ])
+    # Target elements equal to 2 (the `ignore_index`) are ignored.
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(self=x,
+                                               target=y,
+                                               weight=None,
+                                               reduction=0,
+                                               ignore_index=2)[0]
+
+
+@register_test_case(module_factory=lambda: NllLossModule())
+def NllLossModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
+
+
+class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+    ])
+    # No element is ignored here, since the `ignore_index` is out of bounds.
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(self=x,
+                                               target=y,
+                                               weight=None,
+                                               reduction=0,
+                                               ignore_index=10)[0]
+
+
+@register_test_case(
+    module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
+def NllLossModule_ignore_index(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index 69b5f7697..d7858a606 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1573,6 +1573,25 @@ def Torch_AtenMeanOp : Torch_Op<"aten.mean", [
   let assemblyFormat = "$self `,` $dtype attr-dict `:` type($self) `,` type($dtype) `->` type($result)";
 }
 
+def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$target,
+    AnyTorchOptionalTensorType:$weight,
+    Torch_IntType:$reduction,
+    Torch_IntType:$ignore_index
+  );
+  let results = (outs
+    AnyTorchTensorType:$output,
+    AnyTorchTensorType:$total_weight
+  );
+  let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` type($self) `,` type($target) `,` type($weight) `,` type($reduction) `,` type($ignore_index) `->` type($output) `,` type($total_weight)";
+}
+
 def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
     AllowsTypeRefinement
   ]> {
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 451598901..d485fe7a1 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1167,6 +1167,102 @@ public:
 };
 } // namespace
 
+// Given `input` and `target`, `nll_loss_forward` is given by:
+//   for i in range(0, len(target)):
+//     indi = target[i];
+//     nll_loss_forward[i] = -(input[i][indi]);
+// Entries where `target[i] == ignore_index` contribute 0 instead.
+// TODO: `weight` and `reduction` operands are still to be taken care of.
+namespace {
+class ConvertAtenNllLossForwardOp
+    : public OpConversionPattern<AtenNllLossForwardOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenNllLossForwardOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op->getLoc();
+    Value input = adaptor.self();
+    Value target = adaptor.target();
+    Value weight = adaptor.weight();
+
+    int64_t reduction;
+    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
+      return rewriter.notifyMatchFailure(op,
+                                         "reduction must be a constant int");
+
+    // TODO: Handle reduction.
+    if (reduction != 0)
+      return rewriter.notifyMatchFailure(
+          op, "only reduction type 'none' (0) is supported");
+
+    // TODO: Incorporate the weight argument.
+    if (!weight.getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: the weight operand is not incorporated");
+
+    Value ignoreIndex = adaptor.ignore_index();
+    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
+
+    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
+    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
+
+    // TODO: Handle cases with targetRank != 1, where `Mean` reduction is
+    // required.
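+    // Only the per-element (`reduction == 0`) form is lowered for now, so
+    // `input` must be rank 2 (N x C) and `target` rank 1 (N).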
+    if (inputRank != 2 || targetRank != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected input and target to be rank 2 and 1 respectively");
+    }
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .cast<RankedTensorType>();
+
+    Type elementType = resultType.getElementType();
+
+    Value targetDim = getDimOp(rewriter, loc, target, 0);
+    Value initTensor0 =
+        createZeroInitTensor(rewriter, loc, {targetDim}, elementType);
+    Value zeroVal = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
+
+    SmallVector<AffineExpr> targetExpr;
+    targetExpr.push_back(rewriter.getAffineDimExpr(0));
+    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName()};
+    auto indexingMaps = AffineMap::inferFromExprList({targetExpr, targetExpr});
+    Value finalRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultType, ValueRange{target}, initTensor0,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value indTarget = rewriter.create<arith::IndexCastOp>(
+                      loc, rewriter.getIndexType(), args[0]);
+                  Value indI = rewriter.create<linalg::IndexOp>(loc, 0);
+
+                  // The final result is given by:
+                  // final_res = (indTarget == ignoreIndexVal) ? 0 :
+                  //             -input[indI][indTarget]
+                  Value cmpEq = rewriter.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::eq, indTarget, ignoreIndexVal);
+                  Value result = rewriter.create<tensor::ExtractOp>(
+                      loc, input, ValueRange{indI, indTarget});
+                  Value negate =
+                      rewriter.create<arith::NegFOp>(loc, elementType, result);
+                  Value selectFinal =
+                      rewriter.create<SelectOp>(loc, cmpEq, zeroVal, negate);
+                  b.create<linalg::YieldOp>(loc, selectFinal);
+                })
+            .getResult(0);
+
+    // TODO: Update the second result tensor.
+    Value weightUpdated =
+        createZeroInitTensor(rewriter, loc, {}, elementType);
+    rewriter.replaceOp(op, {finalRes, weightUpdated});
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // See comments at in convertMmOp and the heading for this section for general
 // considerations. This function needs to be auto-generated.
@@ -3372,6 +3468,8 @@ public:
     patterns.add(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
+    target.addIllegalOp<AtenNllLossForwardOp>();
+    patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 6a33c91d4..40e9ab31b 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -454,8 +454,11 @@ public:
       return visitAtenAddCLikeOp(op, operands);
     } else if (auto scalarOp = dyn_cast(op)) {
       return visitBinaryScalarOp(scalarOp);
+    } else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
+      return visitAtenNllLossForwardOp(nllForwardOp, operands);
     }
+
     // Otherwise, this is an unknown operation. Just mark all results as
     // having reached a pessimistic fixpoint.
     return markAllPessimisticFixpoint(op->getResults());
@@ -580,6 +583,10 @@ private:
   ChangeResult
   visitAten_SoftmaxOp(Aten_SoftmaxOp op,
                       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+
+  ChangeResult
+  visitAtenNllLossForwardOp(AtenNllLossForwardOp op,
+                            ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 };
 } // namespace
 
@@ -927,6 +934,38 @@ ChangeResult TypeAnalyzer::visitAtenSqueezeOp(
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
+    AtenNllLossForwardOp op,
+    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto self = operands[0]->getValue();
+  auto outputKnowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+
+  // Contains knowledge of shape and dtype for the first result.
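+  // With `reduction == 0` the first result keeps the batch dimension of
+  // `self` (rank reduced by one); `Mean` reduction drops it to a scalar.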
+  outputKnowledge.dtype = self.dtype;
+  int64_t reduction;
+  unsigned resultRank = self.sizes.size();
+
+  // Contains knowledge of shape and dtype for the second result.
+  auto totalWeightKnowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+  totalWeightKnowledge.dtype = self.dtype;
+  totalWeightKnowledge.sizes.resize(0, kUnknownSize);
+  totalWeightKnowledge.hasSizes = true;
+
+  if (self.hasSizes &&
+      matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
+    // reduction == 1 (`Mean`) collapses the result by one more dimension.
+    resultRank = reduction == 1 ? resultRank - 1 : resultRank;
+  }
+  outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
+  outputKnowledge.hasSizes = true;
+  auto resultLattice = getLatticeElement(op.getResult(0)).join(outputKnowledge);
+  resultLattice |=
+      getLatticeElement(op.getResult(1)).join(totalWeightKnowledge);
+  return resultLattice;
+}
+
 ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
     AtenUnsqueezeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto operand = operands[0]->getValue();
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 419be419c..67f4dd8a1 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -523,6 +523,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::sqrt : (Tensor) -> (Tensor)")
     emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
     emit("aten::mean : (Tensor, int?) -> (Tensor)")
+    emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
 
     # Misc tensor ops.
     emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
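
Note on semantics: below is a minimal Python sketch of what the new lowering
computes for `reduction=0` with no `weight` (the only case handled by this
patch). The helper name `nll_loss_forward_ref` is hypothetical and used only
for illustration; it is not part of the patch.

import torch

def nll_loss_forward_ref(x, target, ignore_index):
    # out[i] = 0 if target[i] == ignore_index else -x[i][target[i]]
    out = torch.empty(target.shape[0], dtype=x.dtype)
    for i in range(target.shape[0]):
        t = int(target[i])
        out[i] = 0.0 if t == ignore_index else -x[i, t]
    return out

# The first result of aten::nll_loss_forward should match the reference.
x = torch.rand(2, 3)
target = torch.tensor([0, 1])
expected = torch.ops.aten.nll_loss_forward(x, target, None, 0, 2)[0]
assert torch.allclose(nll_loss_forward_ref(x, target, 2), expected)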