From 1f8123b5f0df37216c96a80f21f8d1a38a38513b Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Tue, 23 Apr 2024 17:57:12 +0800 Subject: [PATCH] [Stablehlo] support unary ops which promote to floating point (#3209) * promote input to output element-type when lowering to stablehlo, so that it could satisfy stablehlo's type constraints. * split promote-to-fp unary ops from fp-only unary ops. --- lib/Conversion/TorchToStablehlo/Basic.cpp | 61 ++++++++++++++++++----- projects/pt1/e2e_testing/xfail_sets.py | 21 ++++---- 2 files changed, 60 insertions(+), 22 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 89afc8081..1c4bc2753 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -217,6 +217,37 @@ public: }; } // namespace +// These legalizations are for unary ops with promoting to floating point +// datatypes. +namespace { +template +class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value self = adaptor.getSelf(); + auto selfTy = self.getType().cast(); + if (!selfTy) + return op.emitError("only Tensor types supported in StableHLO"); + auto resultTy = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + if (resultTy.getElementType().template isa()) { + Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy); + rewriter.replaceOpWithNewOp(op, resultTy, src); + return success(); + } else { + return op.emitError( + "only result to be floating-point datatype legalization supported"); + } + } +}; +} // namespace + // aten.ones & aten.zeros // Ref: Error checking based on the Torch to TOSA lowering namespace { @@ -1888,23 +1919,29 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( target.addIllegalOp(); \ patterns.add>(typeConverter, \ context) - INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp); - INSERT_UNARY_FPONLY_PATTERN(AtenLog1pOp, stablehlo::Log1pOp); - INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp); - INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp); - INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp); - INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp); INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp); INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp); INSERT_UNARY_FPONLY_PATTERN(AtenRoundOp, stablehlo::RoundNearestEvenOp); - INSERT_UNARY_FPONLY_PATTERN(AtenAsinOp, chlo::AsinOp); - INSERT_UNARY_FPONLY_PATTERN(AtenAcosOp, chlo::AcosOp); - INSERT_UNARY_FPONLY_PATTERN(AtenAtanOp, chlo::AtanOp); #undef INSERT_UNARY_FPONLY_PATTERN +#define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, StablehloOp) \ + target.addIllegalOp(); \ + patterns.add>( \ + typeConverter, context) + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, stablehlo::LogOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLog1pOp, stablehlo::Log1pOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, stablehlo::ExpOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSqrtOp, stablehlo::SqrtOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanhOp, stablehlo::TanhOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinOp, stablehlo::SineOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCosOp, stablehlo::CosineOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinOp, chlo::AsinOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcosOp, chlo::AcosOp); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanOp, chlo::AtanOp); +#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN + #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index a11e36060..80ab03566 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -620,17 +620,14 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "DiagonalModule_with_offset", "DivFloatModule_basic", "DivIntModule_basic", - "ElementwiseAcosIntModule_basic", "ElementwiseAcoshIntModule_basic", "ElementwiseAcoshModule_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", - "ElementwiseAsinIntModule_basic", "ElementwiseAsinhIntModule_basic", "ElementwiseAsinhModule_basic", "ElementwiseAtan2FloatIntModule_basic", "ElementwiseAtan2TensorFloatModule_basic", "ElementwiseAtan2TensorIntModule_basic", - "ElementwiseAtanTensorIntModule_basic", "ElementwiseAtanhIntModule_basic", "ElementwiseAtanhModule_basic", "ElementwiseBitwiseLeftShiftInt32Module_basic", @@ -639,7 +636,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "ElementwiseBitwiseRightShiftInt32Module_basic", "ElementwiseBitwiseRightShiftInt64Module_basic", "ElementwiseBitwiseRightShiftInt8Module_basic", - "ElementwiseCosIntModule_basic", "ElementwiseCoshIntModule_basic", "ElementwiseCoshModule_basic", "ElementwiseDequantizePerChannelModule_basic", @@ -649,22 +645,16 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "ElementwiseLog10Module_basic", "ElementwiseLog2IntModule_basic", "ElementwiseLog2Module_basic", - "ElementwiseLogIntModule_basic", "ElementwiseLogitModule_basic", "ElementwiseMulTensorComplexModule_basic", "ElementwisePowScalarModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", - "ElementwiseRsqrtIntModule_basic", - "ElementwiseSigmoidIntModule_basic", - "ElementwiseSinIntModule_basic", - "ElementwiseSqrtIntModule_basic", "ElementwiseTanIntModule_basic", "ElementwiseTanModule_basic", "ElementwiseTernaryModule_basic", "ElementwiseToDtypeI64ToUI8Module_basic", - "ElementwiseUnaryIntModule_basic", "EmptyModule_uint8", "EqIntModule_basic", "FakeQuantizePerTensorAffineDynamicShapeModule_basic", @@ -1464,6 +1454,17 @@ STABLEHLO_PASS_SET = { "ElementwiseAcosModule_basic", "ElementwiseAsinModule_basic", "ElementwiseAtanTensorFloatModule_basic", + "ElementwiseAcosIntModule_basic", + "ElementwiseAsinIntModule_basic", + "ElementwiseAtanTensorIntModule_basic", + "ElementwiseCosIntModule_basic", + "ElementwiseExpIntModule_basic", + "ElementwiseLogIntModule_basic", + "ElementwiseRsqrtIntModule_basic", + "ElementwiseSigmoidIntModule_basic", + "ElementwiseSinIntModule_basic", + "ElementwiseSqrtIntModule_basic", + "ElementwiseUnaryIntModule_basic", } STABLEHLO_CRASHING_SET = {