mirror of https://github.com/llvm/torch-mlir
[stablehlo] add aten.fmod.Tensor op conversion support (#3198)
parent
ea0ecb67be
commit
b6b01602d3
|
@ -1833,6 +1833,37 @@ LogicalResult ConvertAtenOp<AtenRemainderTensorOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// AtenFmodTensorOp
|
||||
// torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenFmodTensorOp>::matchAndRewrite(
|
||||
AtenFmodTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op->getLoc();
|
||||
Value lhs = adaptor.getSelf();
|
||||
Value rhs = adaptor.getOther();
|
||||
|
||||
auto resultType =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
|
||||
rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
|
||||
|
||||
stablehlo::MulOp mul;
|
||||
auto div = rewriter.create<stablehlo::DivOp>(loc, lhs, rhs);
|
||||
if (isa<mlir::FloatType>(resultType.getElementType())) {
|
||||
// rounding mode is trunc
|
||||
auto sign = rewriter.create<stablehlo::SignOp>(loc, div);
|
||||
auto abs = rewriter.create<stablehlo::AbsOp>(loc, div);
|
||||
auto floor = rewriter.create<stablehlo::FloorOp>(loc, abs);
|
||||
auto trunc = rewriter.create<stablehlo::MulOp>(loc, sign, floor);
|
||||
mul = rewriter.create<stablehlo::MulOp>(loc, trunc, rhs);
|
||||
} else {
|
||||
mul = rewriter.create<stablehlo::MulOp>(loc, div, rhs);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<stablehlo::SubtractOp>(op, lhs, mul);
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
|
@ -1976,6 +2007,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
||||
INSERT_ATENOP_PATTERN(AtenRemainderTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
||||
|
|
|
@ -650,9 +650,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"ElementwiseErfIntModule_basic",
|
||||
"ElementwiseExpm1IntModule_basic",
|
||||
"ElementwiseExpm1Module_basic",
|
||||
"ElementwiseFmodTensor_Float_basic",
|
||||
"ElementwiseFmodTensor_Int_Float_basic",
|
||||
"ElementwiseFmodTensor_Int_basic",
|
||||
"ElementwiseLog10IntModule_basic",
|
||||
"ElementwiseLog10Module_basic",
|
||||
"ElementwiseLog2IntModule_basic",
|
||||
|
@ -1056,6 +1053,9 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseExpModule_basic",
|
||||
"ElementwiseFloorIntModule_basic",
|
||||
"ElementwiseFloorModule_basic",
|
||||
"ElementwiseFmodTensor_Float_basic",
|
||||
"ElementwiseFmodTensor_Int_Float_basic",
|
||||
"ElementwiseFmodTensor_Int_basic",
|
||||
"ElementwiseGeluApproximateTanhModule_basic",
|
||||
"ElementwiseGeluModule_basic",
|
||||
"ElementwiseLeakyReluStaticModule_basic",
|
||||
|
|
Loading…
Reference in New Issue