mirror of https://github.com/llvm/torch-mlir
[MHLO] move AtenTanhOp to ConvertAtenUnaryFPOnlyPatten and add sin/cos/ceil/floor pattern (#1847)
parent
c957cebd03
commit
089018b658
|
@ -134,6 +134,10 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseRsqrtModule_basic",
|
||||
"ElementwiseSigmoidModule_basic",
|
||||
"ElementwiseSqrtModule_basic",
|
||||
"ElementwiseSinModule_basic",
|
||||
"ElementwiseCosModule_basic",
|
||||
"ElementwiseCeilModule_basic",
|
||||
"ElementwiseFloorModule_basic",
|
||||
"ElementwiseUnaryModule_basic",
|
||||
"ElementwiseUnsqueezeBroadcastModule_basic",
|
||||
"ElementwiseUnsqueezeNegDimsModule_basic",
|
||||
|
|
|
@ -720,23 +720,6 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// AtenTanhOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
|
||||
AtenTanhOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<stablehlo::TanhOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
} else {
|
||||
return op.emitError(
|
||||
"only floating-point datatype legalization currently supported");
|
||||
}
|
||||
}
|
||||
|
||||
// ValueTensorLiteralOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
||||
|
@ -1408,6 +1391,11 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp);
|
||||
#undef INSERT_UNARY_FPONLY_PATTERN
|
||||
|
||||
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
||||
|
@ -1474,7 +1462,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
||||
INSERT_ATENOP_PATTERN(AtenPermuteOp);
|
||||
|
||||
INSERT_ATENOP_PATTERN(AtenTanhOp);
|
||||
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
|
||||
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
|
||||
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
|
||||
|
|
Loading…
Reference in New Issue