From 089018b658869572aeb5e133e1c81d8c814cd3cc Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Tue, 7 Feb 2023 03:14:26 +0800 Subject: [PATCH] [MHLO] move AtenTanhOp to ConvertAtenUnaryFPOnlyPatten and add sin/cos/ceil/floor pattern (#1847) --- e2e_testing/xfail_sets.py | 4 ++++ lib/Conversion/TorchToStablehlo/Basic.cpp | 23 +++++------------------ 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index fd9784d6d..c6bb317b8 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -134,6 +134,10 @@ STABLEHLO_PASS_SET = { "ElementwiseRsqrtModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSqrtModule_basic", + "ElementwiseSinModule_basic", + "ElementwiseCosModule_basic", + "ElementwiseCeilModule_basic", + "ElementwiseFloorModule_basic", "ElementwiseUnaryModule_basic", "ElementwiseUnsqueezeBroadcastModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic", diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index d82826d71..585c6006f 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -720,23 +720,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -// AtenTanhOp -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenTanhOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); - if (selfTy && selfTy.getElementType().isa()) { - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), self); - return success(); - } else { - return op.emitError( - "only floating-point datatype legalization currently supported"); - } -} - // ValueTensorLiteralOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -1408,6 +1391,11 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp); INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp); INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp); + INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp); + INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp); + INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp); + INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp); + INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp); #undef INSERT_UNARY_FPONLY_PATTERN #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ @@ -1474,7 +1462,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenBroadcastToOp); INSERT_ATENOP_PATTERN(AtenPermuteOp); - INSERT_ATENOP_PATTERN(AtenTanhOp); INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); INSERT_ATENOP_PATTERN(AtenReciprocalOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);