mirror of https://github.com/llvm/torch-mlir
[Stablehlo] add AtenLog10Op, AtenLog2Op lowering to stablehlo (#3208)
parent
1f8123b5f0
commit
c1967b607f
|
@ -1060,6 +1060,49 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// AtenLog2Op
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
|
||||
AtenLog2Op op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return op.emitError("only ranked tensor type is supported.");
|
||||
}
|
||||
auto outTy = getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||
|
||||
auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input);
|
||||
auto log2Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), two);
|
||||
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);
|
||||
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, logInputOp, log2Op);
|
||||
return success();
|
||||
}
|
||||
|
||||
// AtenLog10Op
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
|
||||
AtenLog10Op op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return op.emitError("only ranked tensor type is supported.");
|
||||
}
|
||||
|
||||
auto outTy = getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||
|
||||
auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input);
|
||||
auto log10Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), ten);
|
||||
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);
|
||||
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, logInputOp, log10Op);
|
||||
return success();
|
||||
}
|
||||
|
||||
// AtenErfOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
|
||||
|
@ -2028,6 +2071,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
|||
|
||||
INSERT_ATENOP_PATTERN(AtenReluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGeluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenLog2Op);
|
||||
INSERT_ATENOP_PATTERN(AtenLog10Op);
|
||||
INSERT_ATENOP_PATTERN(AtenErfOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
|
||||
|
||||
|
|
|
@ -641,10 +641,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"ElementwiseDequantizePerChannelModule_basic",
|
||||
"ElementwiseDequantizePerTensorModule_basic",
|
||||
"ElementwiseErfIntModule_basic",
|
||||
"ElementwiseLog10IntModule_basic",
|
||||
"ElementwiseLog10Module_basic",
|
||||
"ElementwiseLog2IntModule_basic",
|
||||
"ElementwiseLog2Module_basic",
|
||||
"ElementwiseLogitModule_basic",
|
||||
"ElementwiseMulTensorComplexModule_basic",
|
||||
"ElementwisePowScalarModule_basic",
|
||||
|
@ -1046,6 +1042,10 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseGeluModule_basic",
|
||||
"ElementwiseLeakyReluStaticModule_basic",
|
||||
"ElementwiseLogModule_basic",
|
||||
"ElementwiseLog10Module_basic",
|
||||
"ElementwiseLog2Module_basic",
|
||||
"ElementwiseLog10IntModule_basic",
|
||||
"ElementwiseLog2IntModule_basic",
|
||||
"ElementwiseNanToNumModule_Basic",
|
||||
"ElementwiseNeFloatTensorStaticModule_basic",
|
||||
"ElementwiseNeIntTensorStaticModule_basic",
|
||||
|
|
Loading…
Reference in New Issue