From 2f33f31724ef35e9323ef5e13167f52adab76603 Mon Sep 17 00:00:00 2001
From: yyp0
Date: Wed, 6 Nov 2024 11:34:48 +0800
Subject: [PATCH] [Torch] support AtenNllLossForwardOp decomposition (#3833)

---
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 247 +++++++++++++++++-
 projects/pt1/e2e_testing/xfail_sets.py        |   4 +
 .../test_suite/nll_loss.py                    |  99 +++++++
 3 files changed, 349 insertions(+), 1 deletion(-)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 46feee41b..769d8953a 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -9154,6 +9154,12 @@ public:
       return rewriter.notifyMatchFailure(
           op, "Unimplemented: unranked target tensor");
     unsigned targetRank = maybeRank.value();
+    Value reduction = op.getReduction();
+    int64_t reductionInt;
+    if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "reduction should be a constant int!");
+    }
 
     // When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d
     // of the form [minibatch] the cross entropy loss decomposes to the
@@ -9184,10 +9190,19 @@
         loc, rewriter.getI64IntegerAttr(1));
     Value logSoftmax = rewriter.create<AtenLogSoftmaxIntOp>(
         loc, self.getType(), self, dim, /*dtype=*/noneVal);
+
+    Type secondType;
+    if (reductionInt == 0) {
+      secondType = target.getType();
+    } else {
+      auto targetType = dyn_cast<BaseTensorType>(target.getType());
+      secondType = targetType.getWithSizesAndDtype({}, targetType.getDtype());
+    }
+
     Value nllLoss =
         rewriter
             .create<AtenNllLossForwardOp>(
-                loc, op.getType(), target.getType(), logSoftmax, target,
+                loc, op.getType(), secondType, logSoftmax, target,
                 op.getWeight(), op.getReduction(), op.getIgnoreIndex())
             ->getResult(0);
     rewriter.replaceOp(op, nllLoss);
@@ -9196,6 +9211,235 @@ public:
 };
 } // namespace
 
+namespace {
+// Decompose aten::nll_loss_forward according to:
+// torch/_decomp/decompositions.py and
+// https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html.
+// The (self, target) shapes can be:
+// 1. [N, C] and [N],
+// or
+// 2. [C] and [].
+// The weight must be None or a 1-D tensor whose numel matches the number of
+// classes.
+class DecomposeAtenNllLossForwardOp
+    : public OpRewritePattern<AtenNllLossForwardOp> {
+public:
+  using OpRewritePattern<AtenNllLossForwardOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNllLossForwardOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto ctx = op.getContext();
+
+    auto self = op.getSelf();
+    auto target = op.getTarget();
+
+    auto selfType = dyn_cast<BaseTensorType>(self.getType());
+    auto targetType = dyn_cast<BaseTensorType>(target.getType());
+
+    // Constraints:
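+    // The checks below reject inputs whose shapes or dtypes are unknown,
+    // mismatched batch sizes in the 2-D case, a non-None weight that does
+    // not cover every class, and a non-constant reduction.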
+    if (!selfType.hasSizes() || !targetType.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self and target to have sizes!");
+    }
+
+    if (!selfType.hasDtype() || !targetType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self and target to have dtype!");
+    }
+
+    auto selfSizes = selfType.getSizes();
+    auto targetSizes = targetType.getSizes();
+    int64_t selfRank = selfSizes.size();
+    int64_t targetRank = targetSizes.size();
+    if (selfRank <= 0 || selfRank > 2) {
+      return rewriter.notifyMatchFailure(op, "input tensor should be 1D or 2D");
+    }
+    if (targetRank > 1) {
+      return rewriter.notifyMatchFailure(op,
+                                         "target tensor should be 0D or 1D!");
+    }
+
+    if (selfRank != 1 || targetRank != 0) {
+      if (!(selfSizes[0] == kUnknownSize && targetSizes[0] == kUnknownSize) &&
+          selfSizes[0] != targetSizes[0]) {
+        return rewriter.notifyMatchFailure(
+            op,
+            "input tensor and target tensor should have the same batch size!");
+      }
+    }
+
+    int64_t numClasses = selfSizes.back();
+    auto weight = op.getWeight();
+    auto weightT = weight.getType();
+    if (!isa<Torch::NoneType>(weightT) && numClasses != kUnknownSize) {
+      auto weightType = dyn_cast<BaseTensorType>(weightT);
+      if (weightType.areAllSizesKnown()) {
+        auto weightSizes = weightType.getSizes();
+        int64_t weightNumel = 1;
+        for (size_t i = 0; i < weightSizes.size(); i++) {
+          weightNumel *= weightSizes[i];
+        }
+        if (weightNumel != numClasses) {
+          return rewriter.notifyMatchFailure(
+              op, "weight tensor should be defined either for all classes or "
+                  "no classes!");
+        }
+      }
+    }
+
+    Value reductionValue = op.getReduction();
+    int64_t reduction;
+    if (!matchPattern(reductionValue, m_TorchConstantInt(&reduction))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "reduction should be a constant int!");
+    }
+
+    // Decomposition:
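+    // In scalar terms (cf. torch/_decomp/decompositions.py):
+    //   result[i] = (target[i] == ignoreIndex)
+    //                   ? 0.0
+    //                   : -self[i][target[i]] * weight[target[i]]
+    //   totalWeight = sum(weight[target[i]] over non-ignored i)
+    // with weight treated as all-ones when it is None. Ignored targets are
+    // clamped to class 0 first so that the gather below stays in bounds.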
+    uint64_t channelDim = 1;
+    if (selfRank < 2) {
+      channelDim = 0;
+    }
+    Value channelDimValue = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(channelDim));
+
+    auto ignoreIndex = op.getIgnoreIndex();
+    Value w;
+    if (!isa<Torch::NoneType>(weightT)) {
+      if (selfRank > 1) {
+        auto weightType = dyn_cast<BaseTensorType>(weightT);
+        auto weightSizes = weightType.getSizes();
+        SmallVector<int64_t> newShapeList(selfRank, 1);
+        newShapeList[channelDim] = weightSizes[0];
+        SmallVector<Value> newShapeListValue;
+        for (size_t i = 0; i < newShapeList.size(); ++i) {
+          newShapeListValue.push_back(rewriter.create<ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(newShapeList[i])));
+        }
+        Value newShape = rewriter.create<PrimListConstructOp>(
+            loc,
+            rewriter.getType<Torch::ListType>(
+                rewriter.getType<Torch::IntType>()),
+            newShapeListValue);
+        auto newType = weightType.getWithSizesAndDtype(newShapeList,
+                                                       weightType.getDtype());
+        w = rewriter.create<AtenViewOp>(loc, newType, weight, newShape);
+      } else {
+        w = weight;
+      }
+
+      self = rewriter.create<AtenMulTensorOp>(loc, self.getType(), self, w);
+    }
+
+    SmallVector<int64_t> targetDimSizes(targetSizes);
+    Value zero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    auto condType =
+        ValueTensorType::get(ctx, targetDimSizes, rewriter.getI1Type());
+    auto unequalCond =
+        rewriter.create<AtenNeScalarOp>(loc, condType, target, ignoreIndex);
+    auto zeroTensorType =
+        ValueTensorType::get(ctx, {}, rewriter.getIntegerType(64, true));
+    Value zeroTensor =
+        rewriter.create<PrimNumToTensorScalarOp>(loc, zeroTensorType, zero);
+    auto safeTarget = rewriter.create<AtenWhereSelfOp>(
+        loc, target.getType(), unequalCond, target, zeroTensor);
+
+    SmallVector<int64_t> safeTargetShape;
+    for (size_t i = 0; i < targetSizes.size(); ++i) {
+      if (channelDim == i) {
+        safeTargetShape.push_back(1);
+      }
+      safeTargetShape.push_back(targetSizes[i]);
+    }
+    if (channelDim == safeTargetShape.size()) {
+      safeTargetShape.push_back(1);
+    }
+
+    auto gatherType =
+        ValueTensorType::get(ctx, safeTargetShape, targetType.getDtype());
+    auto safeTarget_ = rewriter.create<AtenUnsqueezeOp>(
+        loc, gatherType, safeTarget, channelDimValue);
+    auto falseValue =
+        rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));
+    auto none = rewriter.create<ConstantNoneOp>(loc);
+    auto _gather = rewriter.create<AtenGatherOp>(
+        loc, ValueTensorType::get(ctx, safeTargetShape, selfType.getDtype()),
+        self, channelDimValue, safeTarget_, falseValue);
+    Value gather = rewriter.create<AtenNegOp>(loc, _gather.getType(), _gather);
+    auto unequalCondType = cast<BaseTensorType>(unequalCond.getType());
+    auto result = rewriter.create<AtenWhereSelfOp>(
+        loc,
+        unequalCondType.getWithSizesAndDtype(unequalCondType.getSizes(),
+                                             selfType.getDtype()),
+        unequalCond,
+        rewriter.create<AtenSqueezeDimOp>(
+            loc, ValueTensorType::get(ctx, targetSizes, selfType.getDtype()),
+            gather, channelDimValue),
+        zeroTensor);
+
+    Value totalWeight;
+    if (reduction == 0 && selfRank > 1) {
+      auto zeroFloat =
+          rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
+      Value twSize = rewriter.create<PrimListConstructOp>(
+          loc,
+          rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
+          ValueRange({}));
+
+      totalWeight = rewriter.create<AtenNewFullOp>(
+          loc, op.getType(1), self, twSize, zeroFloat, none, none, none, none);
+      rewriter.replaceOp(op, {result, totalWeight});
+
+      return success();
+    }
+
+    if (!isa<Torch::NoneType>(weightT)) {
+      auto wType = cast<BaseTensorType>(w.getType());
+      auto newWType = wType.getWithSizesAndDtype(selfSizes, wType.getDtype());
+      SmallVector<Value> selfSizesValue;
+      for (size_t i = 0; i < selfSizes.size(); ++i) {
+        selfSizesValue.push_back(rewriter.create<ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(selfSizes[i])));
+      }
+      auto wSize = rewriter.create<PrimListConstructOp>(
+          loc,
+          rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
+          selfSizesValue);
+      w = rewriter.create<AtenExpandOp>(loc, newWType, w, wSize, falseValue);
+      auto wSumGather = rewriter.create<AtenGatherOp>(
+          loc,
+          ValueTensorType::get(ctx, safeTargetShape, wType.getDtype()), w,
+          channelDimValue, safeTarget_, falseValue);
+      auto wSumSq = rewriter.create<AtenSqueezeDimOp>(
+          loc, ValueTensorType::get(ctx, targetSizes, wType.getDtype()),
+          wSumGather, channelDimValue);
+      auto wSum = rewriter.create<AtenWhereSelfOp>(
+          loc,
+          ValueTensorType::get(ctx, unequalCondType.getSizes(),
+                               wType.getDtype()),
+          unequalCond, wSumSq, zeroTensor);
+
+      totalWeight = rewriter.create<AtenSumOp>(loc, op.getType(1), wSum, none);
+    } else {
+      totalWeight =
+          rewriter.create<AtenSumOp>(loc, op.getType(1), unequalCond, none);
+    }
+
+    auto resultSum =
+        rewriter.create<AtenSumOp>(loc, op.getType(0), result, none);
+    if (reduction == 1) {
+      auto resultMean = rewriter.create<AtenDivTensorOp>(
+          loc, op.getType(0), resultSum, totalWeight);
+      rewriter.replaceOp(op, {resultMean, totalWeight});
+
+      return success();
+    }
+
+    rewriter.replaceOp(op, {resultSum, totalWeight});
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenBinaryCrossEntropyWithLogitsOp
     : public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -10437,6 +10681,7 @@ public:
     addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 28948ad6b..226120302 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3022,9 +3022,13 @@ ONNX_XFAIL_SET = {
     "NllLossModuleBackward_ignore_index",
     "NllLossModule_1D_basic",
     "NllLossModule_basic",
+    "NllLossStaticModule_basic",
+    "NllLossStaticModule_weight_basic",
     "NllLossModule_ignore_index_out_of_bounds_basic",
     "NllLossModule_mean_basic",
+    "NllLossStaticModule_mean_basic",
     "NllLossModule_sum_basic",
+    "NllLossStaticModule_sum_basic",
     "NormScalarComplexModule_basic",
     "NormScalarModule_basic",
     "NormScalarOptDimKeepDimComplexModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py
index 675d04249..58c6dfdb9 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py
@@ -36,6 +36,57 @@ def NllLossModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
 
 
+class NllLossStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3], torch.float32, True),
+            ([2], torch.int64, True),
+        ]
+    )
+    # Class index 2 is ignored (ignore_index=2).
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(
+            x, target=y, weight=None, reduction=0, ignore_index=2
+        )
+
+
+@register_test_case(module_factory=lambda: NllLossStaticModule())
+def NllLossStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
+
+
+class NllLossStaticModule_weight(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3], torch.float32, True),
+            ([2], torch.int64, True),
+            ([3], torch.float32, True),
+        ]
+    )
+    # Class index 2 is ignored (ignore_index=2).
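+    # reduction follows the aten enum (0 = none, 1 = mean, 2 = sum), so this
+    # module exercises the sum-reduction path together with a class weight.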
+    def forward(self, x, y, z):
+        return torch.ops.aten.nll_loss_forward(
+            x, target=y, weight=z, reduction=2, ignore_index=2
+        )
+
+
+@register_test_case(module_factory=lambda: NllLossStaticModule_weight())
+def NllLossStaticModule_weight_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(2, 3), tu.randint(2, low=0, high=3), torch.tensor([0.3, 0.3, 0.4])
+    )
+
+
 class NllLossModule_mean(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -60,6 +111,30 @@ def NllLossModule_mean_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
 
 
+class NllLossStaticModule_mean(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3], torch.float32, True),
+            ([2], torch.int64, True),
+        ]
+    )
+    # Class index 2 is ignored (ignore_index=2).
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(
+            x, target=y, weight=None, reduction=1, ignore_index=2
+        )
+
+
+@register_test_case(module_factory=lambda: NllLossStaticModule_mean())
+def NllLossStaticModule_mean_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
+
+
 class NllLossModule_sum(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -84,6 +159,30 @@ def NllLossModule_sum_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
 
 
+class NllLossStaticModule_sum(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3], torch.float32, True),
+            ([2], torch.int64, True),
+        ]
+    )
+    # Class index 2 is ignored (ignore_index=2).
+    def forward(self, x, y):
+        return torch.ops.aten.nll_loss_forward(
+            x, target=y, weight=None, reduction=2, ignore_index=2
+        )
+
+
+@register_test_case(module_factory=lambda: NllLossStaticModule_sum())
+def NllLossStaticModule_sum_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3))
+
+
 class NllLossModule_1D(torch.nn.Module):
     def __init__(self):
         super().__init__()