[Stablehlo] support unary ops which promote to floating point (#3209)

* promote input to output element-type when lowering to stablehlo, so
that it could satisfy stablehlo's type constraints.
* split promote-to-fp unary ops from fp-only unary ops.
pull/3192/head
Yuanqiang Liu 2024-04-23 17:57:12 +08:00 committed by GitHub
parent 797e4cd395
commit 1f8123b5f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 60 additions and 22 deletions

View File

@ -217,6 +217,37 @@ public:
};
} // namespace
// These legalizations are for unary ops with promoting to floating point
// datatypes.
namespace {
template <typename AtenOpT, typename StablehloOpT>
class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>();
if (!selfTy)
return op.emitError("only Tensor types supported in StableHLO");
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
if (resultTy.getElementType().template isa<mlir::FloatType>()) {
Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy, src);
return success();
} else {
return op.emitError(
"only result to be floating-point datatype legalization supported");
}
}
};
} // namespace
// aten.ones & aten.zeros
// Ref: Error checking based on the Torch to TOSA lowering
namespace {
@ -1888,23 +1919,29 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, StablehloOp>>(typeConverter, \
context)
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp);
INSERT_UNARY_FPONLY_PATTERN(AtenLog1pOp, stablehlo::Log1pOp);
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp);
INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp);
INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp);
INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp);
INSERT_UNARY_FPONLY_PATTERN(AtenRoundOp, stablehlo::RoundNearestEvenOp);
INSERT_UNARY_FPONLY_PATTERN(AtenAsinOp, chlo::AsinOp);
INSERT_UNARY_FPONLY_PATTERN(AtenAcosOp, chlo::AcosOp);
INSERT_UNARY_FPONLY_PATTERN(AtenAtanOp, chlo::AtanOp);
#undef INSERT_UNARY_FPONLY_PATTERN
#define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, StablehloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryPromoteToFPOp<AtenOp, StablehloOp>>( \
typeConverter, context)
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, stablehlo::LogOp);
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLog1pOp, stablehlo::Log1pOp);
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, stablehlo::ExpOp);
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanhOp, stablehlo::TanhOp);
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinOp, stablehlo::SineOp);
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCosOp, stablehlo::CosineOp);
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinOp, chlo::AsinOp);
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcosOp, chlo::AcosOp);
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanOp, chlo::AtanOp);
#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \

View File

@ -620,17 +620,14 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"DiagonalModule_with_offset",
"DivFloatModule_basic",
"DivIntModule_basic",
"ElementwiseAcosIntModule_basic",
"ElementwiseAcoshIntModule_basic",
"ElementwiseAcoshModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseAsinIntModule_basic",
"ElementwiseAsinhIntModule_basic",
"ElementwiseAsinhModule_basic",
"ElementwiseAtan2FloatIntModule_basic",
"ElementwiseAtan2TensorFloatModule_basic",
"ElementwiseAtan2TensorIntModule_basic",
"ElementwiseAtanTensorIntModule_basic",
"ElementwiseAtanhIntModule_basic",
"ElementwiseAtanhModule_basic",
"ElementwiseBitwiseLeftShiftInt32Module_basic",
@ -639,7 +636,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ElementwiseBitwiseRightShiftInt32Module_basic",
"ElementwiseBitwiseRightShiftInt64Module_basic",
"ElementwiseBitwiseRightShiftInt8Module_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseCoshIntModule_basic",
"ElementwiseCoshModule_basic",
"ElementwiseDequantizePerChannelModule_basic",
@ -649,22 +645,16 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ElementwiseLog10Module_basic",
"ElementwiseLog2IntModule_basic",
"ElementwiseLog2Module_basic",
"ElementwiseLogIntModule_basic",
"ElementwiseLogitModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwisePowScalarModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
"ElementwiseRsqrtIntModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseSinIntModule_basic",
"ElementwiseSqrtIntModule_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic",
"ElementwiseTernaryModule_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseUnaryIntModule_basic",
"EmptyModule_uint8",
"EqIntModule_basic",
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
@ -1464,6 +1454,17 @@ STABLEHLO_PASS_SET = {
"ElementwiseAcosModule_basic",
"ElementwiseAsinModule_basic",
"ElementwiseAtanTensorFloatModule_basic",
"ElementwiseAcosIntModule_basic",
"ElementwiseAsinIntModule_basic",
"ElementwiseAtanTensorIntModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseExpIntModule_basic",
"ElementwiseLogIntModule_basic",
"ElementwiseRsqrtIntModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseSinIntModule_basic",
"ElementwiseSqrtIntModule_basic",
"ElementwiseUnaryIntModule_basic",
}
STABLEHLO_CRASHING_SET = {