[stablehlo] add aten.fmod.Tensor op conversion support (#3198)

pull/2944/merge
penguin_wwy 2024-04-21 08:39:36 +08:00 committed by GitHub
parent ea0ecb67be
commit b6b01602d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 3 deletions

View File

@ -1833,6 +1833,37 @@ LogicalResult ConvertAtenOp<AtenRemainderTensorOp>::matchAndRewrite(
return success(); return success();
} }
// AtenFmodTensorOp
// torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b
template <>
LogicalResult ConvertAtenOp<AtenFmodTensorOp>::matchAndRewrite(
AtenFmodTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op->getLoc();
Value lhs = adaptor.getSelf();
Value rhs = adaptor.getOther();
auto resultType =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
stablehlo::MulOp mul;
auto div = rewriter.create<stablehlo::DivOp>(loc, lhs, rhs);
if (isa<mlir::FloatType>(resultType.getElementType())) {
// rounding mode is trunc
auto sign = rewriter.create<stablehlo::SignOp>(loc, div);
auto abs = rewriter.create<stablehlo::AbsOp>(loc, div);
auto floor = rewriter.create<stablehlo::FloorOp>(loc, abs);
auto trunc = rewriter.create<stablehlo::MulOp>(loc, sign, floor);
mul = rewriter.create<stablehlo::MulOp>(loc, trunc, rhs);
} else {
mul = rewriter.create<stablehlo::MulOp>(loc, div, rhs);
}
rewriter.replaceOpWithNewOp<stablehlo::SubtractOp>(op, lhs, mul);
return success();
}
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) { ConversionTarget &target, const TorchToStablehloOptions &options) {
@ -1976,6 +2007,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenFillScalarOp); INSERT_ATENOP_PATTERN(AtenFillScalarOp);
INSERT_ATENOP_PATTERN(AtenFlipOp); INSERT_ATENOP_PATTERN(AtenFlipOp);
INSERT_ATENOP_PATTERN(AtenRemainderTensorOp); INSERT_ATENOP_PATTERN(AtenRemainderTensorOp);
INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \ #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \

View File

@ -650,9 +650,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ElementwiseErfIntModule_basic", "ElementwiseErfIntModule_basic",
"ElementwiseExpm1IntModule_basic", "ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic", "ElementwiseExpm1Module_basic",
"ElementwiseFmodTensor_Float_basic",
"ElementwiseFmodTensor_Int_Float_basic",
"ElementwiseFmodTensor_Int_basic",
"ElementwiseLog10IntModule_basic", "ElementwiseLog10IntModule_basic",
"ElementwiseLog10Module_basic", "ElementwiseLog10Module_basic",
"ElementwiseLog2IntModule_basic", "ElementwiseLog2IntModule_basic",
@ -1056,6 +1053,9 @@ STABLEHLO_PASS_SET = {
"ElementwiseExpModule_basic", "ElementwiseExpModule_basic",
"ElementwiseFloorIntModule_basic", "ElementwiseFloorIntModule_basic",
"ElementwiseFloorModule_basic", "ElementwiseFloorModule_basic",
"ElementwiseFmodTensor_Float_basic",
"ElementwiseFmodTensor_Int_Float_basic",
"ElementwiseFmodTensor_Int_basic",
"ElementwiseGeluApproximateTanhModule_basic", "ElementwiseGeluApproximateTanhModule_basic",
"ElementwiseGeluModule_basic", "ElementwiseGeluModule_basic",
"ElementwiseLeakyReluStaticModule_basic", "ElementwiseLeakyReluStaticModule_basic",