From c1967b607fa567990b2658a8b6db8ded65109613 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Tue, 23 Apr 2024 19:06:55 +0800 Subject: [PATCH] [Stablehlo] add AtenLog10Op, AtenLog2Op lowering to stablehlo (#3208) --- lib/Conversion/TorchToStablehlo/Basic.cpp | 45 +++++++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 8 ++-- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 1c4bc2753..0c3cc85b7 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1060,6 +1060,49 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenLog2Op +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenLog2Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = input.getType().template dyn_cast(); + if (!inputTy) { + return op.emitError("only ranked tensor type is supported."); + } + auto outTy = getTypeConverter()->convertType(op.getType()).cast(); + input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); + + auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input); + auto log2Op = rewriter.create(op.getLoc(), two); + auto logInputOp = rewriter.create(op.getLoc(), input); + + rewriter.replaceOpWithNewOp(op, outTy, logInputOp, log2Op); + return success(); +} + +// AtenLog10Op +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenLog10Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = input.getType().template dyn_cast(); + if (!inputTy) { + return op.emitError("only ranked tensor type is supported."); + } + + auto outTy = getTypeConverter()->convertType(op.getType()).cast(); + input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); + + auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input); + auto log10Op = rewriter.create(op.getLoc(), ten); + auto logInputOp = rewriter.create(op.getLoc(), input); + + rewriter.replaceOpWithNewOp(op, outTy, logInputOp, log10Op); + return success(); +} + // AtenErfOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -2028,6 +2071,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenReluOp); INSERT_ATENOP_PATTERN(AtenGeluOp); + INSERT_ATENOP_PATTERN(AtenLog2Op); + INSERT_ATENOP_PATTERN(AtenLog10Op); INSERT_ATENOP_PATTERN(AtenErfOp); INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 80ab03566..11f3d5a83 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -641,10 +641,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseErfIntModule_basic", - "ElementwiseLog10IntModule_basic", - "ElementwiseLog10Module_basic", - "ElementwiseLog2IntModule_basic", - "ElementwiseLog2Module_basic", "ElementwiseLogitModule_basic", "ElementwiseMulTensorComplexModule_basic", "ElementwisePowScalarModule_basic", @@ -1046,6 +1042,10 @@ STABLEHLO_PASS_SET = { "ElementwiseGeluModule_basic", "ElementwiseLeakyReluStaticModule_basic", "ElementwiseLogModule_basic", + "ElementwiseLog10Module_basic", + "ElementwiseLog2Module_basic", + "ElementwiseLog10IntModule_basic", + "ElementwiseLog2IntModule_basic", "ElementwiseNanToNumModule_Basic", "ElementwiseNeFloatTensorStaticModule_basic", "ElementwiseNeIntTensorStaticModule_basic",