From 0a788e0467627bff9990a2bff23320c7d829e5e7 Mon Sep 17 00:00:00 2001
From: Srinath Avadhanula
Date: Mon, 9 Sep 2024 12:00:11 -0400
Subject: [PATCH] Decompose aten.fmod into aten.mul,sub,div etc. (#3689)

As titled, this adds a new decomposition of `aten.fmod.Tensor` into
`aten.div`, `aten.trunc`, `aten.mul`, and `aten.sub`, i.e.
`fmod(x, y) = x - trunc(x / y) * y`. Note that `aten.trunc` is only emitted
for floating-point operands; it is in turn decomposed into `aten.where` etc.
by other existing decompositions. With this change a simple model containing
`aten.fmod` now passes for TOSA but fails for `stablehlo`, so for now the
decomposition is disabled for `stablehlo` by marking `aten.fmod.Tensor` as
backend-legal there. (A short PyTorch sketch of the decomposition identity
follows the patch.)

---------

Co-authored-by: Srinath Avadhanula
---
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 39 +++++++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        |  6 +--
 projects/pt1/python/torch_mlir/torchscript.py |  1 +
 test/Dialect/Torch/decompose-complex-ops.mlir | 43 +++++++++++++++++++
 4 files changed, 86 insertions(+), 3 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index db5b7f246..f354374fe 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -7545,6 +7545,44 @@ class DecomposeAtenTruncOp : public OpRewritePattern<AtenTruncOp> {
 };
 } // namespace
 
+namespace {
+// decompose `fmod(x, y)` to `x - trunc(x/y) * y`
+class DecomposeAtenFmodTensorOp : public OpRewritePattern<AtenFmodTensorOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenFmodTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value other = op.getOther();
+
+    auto resultTy = dyn_cast<ValueTensorType>(op.getType());
+    if (!resultTy || !resultTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result must have dtype");
+    }
+
+    if (isa<mlir::IntegerType>(resultTy.getDtype())) {
+      Value div = rewriter.create<AtenDivTensorOp>(loc, resultTy, self, other);
+      Value mul = rewriter.create<AtenMulTensorOp>(loc, resultTy, div, other);
+      Value alpha =
+          rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
+      rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, resultTy, self, mul,
+                                                   alpha);
+      return success();
+    } else if (isa<mlir::FloatType>(resultTy.getDtype())) {
+      Value div = rewriter.create<AtenDivTensorOp>(loc, resultTy, self, other);
+      Value trunc = rewriter.create<AtenTruncOp>(loc, resultTy, div);
+      Value mul = rewriter.create<AtenMulTensorOp>(loc, resultTy, trunc, other);
+      Value alpha =
+          rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
+      rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, resultTy, self, mul,
+                                                   alpha);
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and
 // `aten.add.Tensor` op.
@@ -9661,6 +9699,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenFmodTensorOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index b9f038146..80831d8ea 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1778,6 +1778,9 @@ TOSA_PASS_SET = {
     "ElementwiseFloorModule_basic",
     "ElementwiseFmaxModule_basic",
     "ElementwiseFminModule_basic",
+    "ElementwiseFmodTensor_Float_basic",
+    "ElementwiseFmodTensor_Int_Float_basic",
+    "ElementwiseFmodTensor_Int_basic",
     "ElementwiseGeFloatIntScalarModule_basic",
     "ElementwiseGeFloatScalarModule_basic",
     "ElementwiseGeIntScalarModule_basic",
@@ -3253,9 +3256,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "ElementwiseExpIntModule_basic",
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
-    "ElementwiseFmodTensor_Float_basic",
-    "ElementwiseFmodTensor_Int_Float_basic",
-    "ElementwiseFmodTensor_Int_basic",
     "ElementwiseGeFloatTensorModule_basic",
     "ElementwiseGeIntTensorModule_basic",
     "ElementwiseGeluApproximateTanhModule_basic",
diff --git a/projects/pt1/python/torch_mlir/torchscript.py b/projects/pt1/python/torch_mlir/torchscript.py
index 585fa94d0..561b4fc2b 100644
--- a/projects/pt1/python/torch_mlir/torchscript.py
+++ b/projects/pt1/python/torch_mlir/torchscript.py
@@ -170,6 +170,7 @@ BACKEND_LEGAL_OPS = {
         "aten.amin",
         "aten.randn.generator",
         "aten.normal_functional",
+        "aten.fmod.Tensor",
     ],
 }
 
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 86c0a07ad..f938a2637 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -128,3 +128,46 @@ func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[5],f64>) -> !torch.vtensor<[],f64> {
   %1 = torch.aten.einsum %str, %0, %none_0 : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[],f64>
   return %1 : !torch.vtensor<[],f64>
 }
+
+// -----
+
+// CHECK: func.func @torch.aten.fmod_int(%[[ARG0:.+]]: !torch.vtensor<[?],si32>, %[[ARG1:.+]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[?],si32> {
+// CHECK: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
+// CHECK: %[[V0:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32>
+// CHECK: %[[V1:.+]] = torch.aten.mul.Tensor %[[V0]], %[[ARG1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32>
+// CHECK: %[[V2:.+]] = torch.aten.sub.Tensor %[[ARG0]], %[[V1]], %[[FLOAT1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si32>, !torch.float -> !torch.vtensor<[?],si32>
+// CHECK: return %[[V2]] : !torch.vtensor<[?],si32>
+func.func @torch.aten.fmod_int(%arg0: !torch.vtensor<[?],si32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[?],si32> {
+  %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32>
+  return %0 : !torch.vtensor<[?],si32>
+}
+
+// -----
+
+// CHECK: func.func @torch.aten.fmod_float(%[[ARG0:.+]]: !torch.vtensor<[?],f16>, %[[ARG1:.+]]: !torch.vtensor<[1],f16>) -> !torch.vtensor<[?],f16> {
+// CHECK: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
+// CHECK: %[[V0:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: %[[V1:.+]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: %[[NONE:.+]] = torch.constant.none
+// CHECK: %[[FALSE:.+]] = torch.constant.bool false
+// CHECK: %[[INT5:.+]] = torch.constant.int 5
+// CHECK: %[[V2:.+]] = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: %[[INT0:.+]] = torch.constant.int 0
+// CHECK: %[[V3:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
+// CHECK: %[[V4:.+]] = torch.aten.gt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1>
+// CHECK: %[[V5:.+]] = torch.aten.lt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1>
+// CHECK: %[[V6:.+]] = torch.aten.to.dtype %[[V2]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
+// CHECK: %[[V7:.+]] = torch.aten.to.dtype %[[V1]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
+// CHECK: %[[V8:.+]] = torch.aten.where.self %[[V4]], %[[V6]], %[[V7]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[],f16> -> !torch.vtensor<[?],f16>
+// CHECK: %[[V9:.+]] = torch.aten.to.dtype %[[V0]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
+// CHECK: %[[V10:.+]] = torch.aten.where.self %[[V5]], %[[V9]], %[[V8]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
+// CHECK: %[[V11:.+]] = torch.aten.abs %[[V3]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
+// CHECK: %[[V12:.+]] = torch.aten.floor %[[V11]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
+// CHECK: %[[V13:.+]] = torch.aten.mul.Tensor %[[V10]], %[[V12]] : !torch.vtensor<[?],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
+// CHECK: %[[V14:.+]] = torch.aten.mul.Tensor %[[V13]], %[[ARG1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
+// CHECK: %[[V15:.+]] = torch.aten.sub.Tensor %[[ARG0]], %[[V14]], %[[FLOAT1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[?],f16>, !torch.float -> !torch.vtensor<[?],f16>
+// CHECK: return %[[V15]] : !torch.vtensor<[?],f16>
+func.func @torch.aten.fmod_float(%arg0: !torch.vtensor<[?],f16>, %arg1: !torch.vtensor<[1],f16>) -> !torch.vtensor<[?],f16> {
+  %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
+  return %0 : !torch.vtensor<[?],f16>
+}
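
A minimal PyTorch sketch of the identity the decomposition relies on,
fmod(x, y) == x - trunc(x / y) * y. It is not part of the patch; the tensor
values are illustrative only, and the integer path assumes the integer-typed
div truncates toward zero, mimicked here with rounding_mode="trunc":

import torch

# Floating-point path: aten.div -> aten.trunc -> aten.mul -> aten.sub.
x = torch.tensor([5.3, -5.3, 7.0])
y = torch.tensor([2.0, 2.0, -3.0])
assert torch.allclose(x - torch.trunc(x / y) * y, torch.fmod(x, y))

# Integer path: no explicit aten.trunc is emitted; the decomposition relies on
# the integer-typed aten.div.Tensor result truncating toward zero.
xi = torch.tensor([5, -5, 7], dtype=torch.int32)
yi = torch.tensor([2, 2, -3], dtype=torch.int32)
assert torch.equal(xi - torch.div(xi, yi, rounding_mode="trunc") * yi,
                   torch.fmod(xi, yi))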