From b38585e0773c78e05567e96afc6315733466016e Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Thu, 11 Jul 2024 08:46:40 +0800
Subject: [PATCH] [Torch Dialect] fix aten.nan_to_num's decomposition when
 inf=None (#3530)

also add shape infer in decomposition, see
https://github.com/llvm/torch-mlir/issues/3312
---
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 53 ++++++++++++-------
 projects/pt1/e2e_testing/xfail_sets.py        |  2 +
 .../test_suite/elementwise.py                 | 31 +++++++++--
 3 files changed, 62 insertions(+), 24 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 24a79cb0d..33809cce5 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3906,37 +3906,50 @@ public:
   LogicalResult matchAndRewrite(AtenNanToNumOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    mlir::FloatType f64Type = rewriter.getF64Type();
     Value nan = op.getNan();
     Value posinf = op.getPosinf();
     Value neginf = op.getNeginf();
-    auto baseType =
-        ValueTensorType::getWithLeastStaticInformation(op.getContext());
-    if (dyn_cast_or_null<ConstantNoneOp>(nan.getDefiningOp()))
-      nan = rewriter.create<ConstantFloatOp>(
-          loc, rewriter.getFloatAttr(
-                   f64Type, APFloat::getZero(f64Type.getFloatSemantics())));
-    if (dyn_cast_or_null<ConstantNoneOp>(posinf.getDefiningOp()))
+    auto outputType = cast<BaseTensorType>(op.getResult().getType());
+    if (!outputType.hasDtype() ||
+        !isa<mlir::FloatType>(outputType.getDtype())) {
+      return rewriter.notifyMatchFailure(
+          op, "expect output type to have float dtype");
+    }
+    mlir::FloatType outputElementType =
+        cast<mlir::FloatType>(outputType.getDtype());
+
+    if (isa<Torch::NoneType>(nan.getType())) {
+      nan =
+          rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
+    }
+    if (isa<Torch::NoneType>(posinf.getType())) {
       posinf = rewriter.create<ConstantFloatOp>(
-          loc, rewriter.getFloatAttr(
-                   f64Type, APFloat::getInf(f64Type.getFloatSemantics())));
-    if (dyn_cast_or_null<ConstantNoneOp>(neginf.getDefiningOp()))
+          loc, rewriter.getF64FloatAttr(
+                   APFloat::getLargest(outputElementType.getFloatSemantics())
+                       .convertToDouble()));
+    }
+    if (isa<Torch::NoneType>(neginf.getType())) {
       neginf = rewriter.create<ConstantFloatOp>(
-          loc,
-          rewriter.getFloatAttr(
-              f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true)));
+          loc, rewriter.getF64FloatAttr(
+                   APFloat::getLargest(outputElementType.getFloatSemantics(),
+                                       /*Negative=*/true)
+                       .convertToDouble()));
+    }
+
+    auto compareType = outputType.getWithSizesAndDtype(
+        outputType.getOptionalSizes(), rewriter.getI1Type());
     Value isNan =
-        rewriter.create<AtenIsnanOp>(loc, baseType, op.getSelf());
+        rewriter.create<AtenIsnanOp>(loc, compareType, op.getSelf());
     Value where = rewriter.create<AtenWhereScalarSelfOp>(
-        loc, baseType, isNan, nan, op.getSelf());
+        loc, outputType, isNan, nan, op.getSelf());
     Value isposinf =
-        rewriter.create<AtenIsposinfOp>(loc, baseType, where);
+        rewriter.create<AtenIsposinfOp>(loc, compareType, where);
     where = rewriter.create<AtenWhereScalarSelfOp>(
-        loc, baseType, isposinf, posinf, where);
+        loc, outputType, isposinf, posinf, where);
     Value isneginf =
-        rewriter.create<AtenIsneginfOp>(loc, baseType, where);
+        rewriter.create<AtenIsneginfOp>(loc, compareType, where);
     rewriter.replaceOpWithNewOp<AtenWhereScalarSelfOp>(
-        op, op.getType(), isneginf, neginf, where);
+        op, outputType, isneginf, neginf, where);
     return success();
   }
 };
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index c500120a1..504c7ca9d 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1029,6 +1029,7 @@ STABLEHLO_PASS_SET = {
     "ElementwiseLog2Module_basic",
     "ElementwiseLog10IntModule_basic",
     "ElementwiseLog2IntModule_basic",
"ElementwiseNanToNumWithNoneModule_Basic", "ElementwiseNanToNumModule_Basic", "ElementwiseNeFloatTensorStaticModule_basic", "ElementwiseNeIntTensorStaticModule_basic", @@ -1761,6 +1762,7 @@ TOSA_PASS_SET = { "ElementwiseUnaryModule_basic", "ElementwiseUnsqueezeBroadcastModule_basic", "ElementwiseWhereScalarModule_basic", + "ElementwiseNanToNumWithNoneModule_Basic", "ElementwiseNanToNumModule_Basic", "EmbeddingModule1DIndices_basic", "EmbeddingModuleI32Static_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index b448bbaa4..7002cee43 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -610,6 +610,29 @@ def ElementwiseWhereScalarSelfStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseNanToNumWithNoneModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([3, 4], torch.float32, True)]) + def forward(self, a): + return torch.ops.aten.nan_to_num(a) + + +@register_test_case(module_factory=lambda: ElementwiseNanToNumWithNoneModule()) +def ElementwiseNanToNumWithNoneModule_Basic(module, tu: TestUtils): + module.forward( + torch.tensor( + [ + [float("nan"), 0.0, float("nan"), 1.0], + [float("inf"), 2.0, float("inf"), 3.0], + [float("-inf"), -1.0, float("-inf"), 4.0], + ] + ) + ) + + class ElementwiseNanToNumModule(torch.nn.Module): def __init__(self): super().__init__() @@ -617,7 +640,7 @@ class ElementwiseNanToNumModule(torch.nn.Module): @export @annotate_args([None, ([3, 4], torch.float32, True)]) def forward(self, a): - return torch.ops.aten.nan_to_num(a, 0.0, 1.0, -1.0) + return torch.ops.aten.nan_to_num(a, 0.1, 1.0, -1.0) @register_test_case(module_factory=lambda: ElementwiseNanToNumModule()) @@ -625,9 +648,9 @@ def ElementwiseNanToNumModule_Basic(module, tu: TestUtils): module.forward( torch.tensor( [ - [float("nan"), 0.0, float("nan"), 0.0], - [float("inf"), 0.0, float("inf"), 0.0], - [float("-inf"), 0.0, float("-inf"), 0.0], + [float("nan"), 0.0, float("nan"), 1.0], + [float("inf"), 2.0, float("inf"), 3.0], + [float("-inf"), -1.0, float("-inf"), 4.0], ] ) )