[MHLO] move AtenTanhOp to ConvertAtenUnaryFPOnlyPatten and add sin/cos/ceil/floor pattern (#1847)

pull/1857/head
Yuanqiang Liu 2023-02-07 03:14:26 +08:00 committed by GitHub
parent c957cebd03
commit 089018b658
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 18 deletions

View File

@ -134,6 +134,10 @@ STABLEHLO_PASS_SET = {
"ElementwiseRsqrtModule_basic",
"ElementwiseSigmoidModule_basic",
"ElementwiseSqrtModule_basic",
"ElementwiseSinModule_basic",
"ElementwiseCosModule_basic",
"ElementwiseCeilModule_basic",
"ElementwiseFloorModule_basic",
"ElementwiseUnaryModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic",

View File

@ -720,23 +720,6 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
return success();
}
// AtenTanhOp
template <>
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
AtenTanhOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>();
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<stablehlo::TanhOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
} else {
return op.emitError(
"only floating-point datatype legalization currently supported");
}
}
// ValueTensorLiteralOp
template <>
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
@ -1408,6 +1391,11 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp);
INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp);
INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp);
INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp);
#undef INSERT_UNARY_FPONLY_PATTERN
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
@ -1474,7 +1462,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenPermuteOp);
INSERT_ATENOP_PATTERN(AtenTanhOp);
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);