[Stablehlo] add AtenLog10Op, AtenLog2Op lowering to stablehlo (#3208)

pull/3192/head
Xinyu Yang 2024-04-23 19:06:55 +08:00 committed by GitHub
parent 1f8123b5f0
commit c1967b607f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 4 deletions

View File

@ -1060,6 +1060,49 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
return success(); return success();
} }
// AtenLog2Op
template <>
LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
AtenLog2Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
if (!inputTy) {
return op.emitError("only ranked tensor type is supported.");
}
auto outTy = getTypeConverter()->convertType(op.getType()).cast<TensorType>();
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input);
auto log2Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), two);
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, logInputOp, log2Op);
return success();
}
// AtenLog10Op
template <>
LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
AtenLog10Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
if (!inputTy) {
return op.emitError("only ranked tensor type is supported.");
}
auto outTy = getTypeConverter()->convertType(op.getType()).cast<TensorType>();
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input);
auto log10Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), ten);
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, logInputOp, log10Op);
return success();
}
// AtenErfOp // AtenErfOp
template <> template <>
LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
@ -2028,6 +2071,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenReluOp); INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenGeluOp); INSERT_ATENOP_PATTERN(AtenGeluOp);
INSERT_ATENOP_PATTERN(AtenLog2Op);
INSERT_ATENOP_PATTERN(AtenLog10Op);
INSERT_ATENOP_PATTERN(AtenErfOp); INSERT_ATENOP_PATTERN(AtenErfOp);
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);

View File

@ -641,10 +641,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic", "ElementwiseDequantizePerTensorModule_basic",
"ElementwiseErfIntModule_basic", "ElementwiseErfIntModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog10Module_basic",
"ElementwiseLog2IntModule_basic",
"ElementwiseLog2Module_basic",
"ElementwiseLogitModule_basic", "ElementwiseLogitModule_basic",
"ElementwiseMulTensorComplexModule_basic", "ElementwiseMulTensorComplexModule_basic",
"ElementwisePowScalarModule_basic", "ElementwisePowScalarModule_basic",
@ -1046,6 +1042,10 @@ STABLEHLO_PASS_SET = {
"ElementwiseGeluModule_basic", "ElementwiseGeluModule_basic",
"ElementwiseLeakyReluStaticModule_basic", "ElementwiseLeakyReluStaticModule_basic",
"ElementwiseLogModule_basic", "ElementwiseLogModule_basic",
"ElementwiseLog10Module_basic",
"ElementwiseLog2Module_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog2IntModule_basic",
"ElementwiseNanToNumModule_Basic", "ElementwiseNanToNumModule_Basic",
"ElementwiseNeFloatTensorStaticModule_basic", "ElementwiseNeFloatTensorStaticModule_basic",
"ElementwiseNeIntTensorStaticModule_basic", "ElementwiseNeIntTensorStaticModule_basic",