mirror of https://github.com/llvm/torch-mlir
[Stablehlo] add AtenLog10Op, AtenLog2Op lowering to stablehlo (#3208)
parent
1f8123b5f0
commit
c1967b607f
|
@ -1060,6 +1060,49 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AtenLog2Op
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
|
||||||
|
AtenLog2Op op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value input = adaptor.getSelf();
|
||||||
|
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
||||||
|
if (!inputTy) {
|
||||||
|
return op.emitError("only ranked tensor type is supported.");
|
||||||
|
}
|
||||||
|
auto outTy = getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
||||||
|
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||||
|
|
||||||
|
auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input);
|
||||||
|
auto log2Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), two);
|
||||||
|
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, logInputOp, log2Op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// AtenLog10Op
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
|
||||||
|
AtenLog10Op op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value input = adaptor.getSelf();
|
||||||
|
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
||||||
|
if (!inputTy) {
|
||||||
|
return op.emitError("only ranked tensor type is supported.");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto outTy = getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
||||||
|
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||||
|
|
||||||
|
auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input);
|
||||||
|
auto log10Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), ten);
|
||||||
|
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, logInputOp, log10Op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// AtenErfOp
|
// AtenErfOp
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
|
||||||
|
@ -2028,6 +2071,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||||
|
|
||||||
INSERT_ATENOP_PATTERN(AtenReluOp);
|
INSERT_ATENOP_PATTERN(AtenReluOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenGeluOp);
|
INSERT_ATENOP_PATTERN(AtenGeluOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenLog2Op);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenLog10Op);
|
||||||
INSERT_ATENOP_PATTERN(AtenErfOp);
|
INSERT_ATENOP_PATTERN(AtenErfOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
|
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
|
||||||
|
|
||||||
|
|
|
@ -641,10 +641,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"ElementwiseDequantizePerChannelModule_basic",
|
"ElementwiseDequantizePerChannelModule_basic",
|
||||||
"ElementwiseDequantizePerTensorModule_basic",
|
"ElementwiseDequantizePerTensorModule_basic",
|
||||||
"ElementwiseErfIntModule_basic",
|
"ElementwiseErfIntModule_basic",
|
||||||
"ElementwiseLog10IntModule_basic",
|
|
||||||
"ElementwiseLog10Module_basic",
|
|
||||||
"ElementwiseLog2IntModule_basic",
|
|
||||||
"ElementwiseLog2Module_basic",
|
|
||||||
"ElementwiseLogitModule_basic",
|
"ElementwiseLogitModule_basic",
|
||||||
"ElementwiseMulTensorComplexModule_basic",
|
"ElementwiseMulTensorComplexModule_basic",
|
||||||
"ElementwisePowScalarModule_basic",
|
"ElementwisePowScalarModule_basic",
|
||||||
|
@ -1046,6 +1042,10 @@ STABLEHLO_PASS_SET = {
|
||||||
"ElementwiseGeluModule_basic",
|
"ElementwiseGeluModule_basic",
|
||||||
"ElementwiseLeakyReluStaticModule_basic",
|
"ElementwiseLeakyReluStaticModule_basic",
|
||||||
"ElementwiseLogModule_basic",
|
"ElementwiseLogModule_basic",
|
||||||
|
"ElementwiseLog10Module_basic",
|
||||||
|
"ElementwiseLog2Module_basic",
|
||||||
|
"ElementwiseLog10IntModule_basic",
|
||||||
|
"ElementwiseLog2IntModule_basic",
|
||||||
"ElementwiseNanToNumModule_Basic",
|
"ElementwiseNanToNumModule_Basic",
|
||||||
"ElementwiseNeFloatTensorStaticModule_basic",
|
"ElementwiseNeFloatTensorStaticModule_basic",
|
||||||
"ElementwiseNeIntTensorStaticModule_basic",
|
"ElementwiseNeIntTensorStaticModule_basic",
|
||||||
|
|
Loading…
Reference in New Issue