mirror of https://github.com/llvm/torch-mlir
[stablehlo] add aten.fmod.Tensor op conversion support (#3198)
parent
ea0ecb67be
commit
b6b01602d3
|
@ -1833,6 +1833,37 @@ LogicalResult ConvertAtenOp<AtenRemainderTensorOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AtenFmodTensorOp
|
||||||
|
// torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenFmodTensorOp>::matchAndRewrite(
|
||||||
|
AtenFmodTensorOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
Value lhs = adaptor.getSelf();
|
||||||
|
Value rhs = adaptor.getOther();
|
||||||
|
|
||||||
|
auto resultType =
|
||||||
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
|
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
|
||||||
|
rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
|
||||||
|
|
||||||
|
stablehlo::MulOp mul;
|
||||||
|
auto div = rewriter.create<stablehlo::DivOp>(loc, lhs, rhs);
|
||||||
|
if (isa<mlir::FloatType>(resultType.getElementType())) {
|
||||||
|
// rounding mode is trunc
|
||||||
|
auto sign = rewriter.create<stablehlo::SignOp>(loc, div);
|
||||||
|
auto abs = rewriter.create<stablehlo::AbsOp>(loc, div);
|
||||||
|
auto floor = rewriter.create<stablehlo::FloorOp>(loc, abs);
|
||||||
|
auto trunc = rewriter.create<stablehlo::MulOp>(loc, sign, floor);
|
||||||
|
mul = rewriter.create<stablehlo::MulOp>(loc, trunc, rhs);
|
||||||
|
} else {
|
||||||
|
mul = rewriter.create<stablehlo::MulOp>(loc, div, rhs);
|
||||||
|
}
|
||||||
|
rewriter.replaceOpWithNewOp<stablehlo::SubtractOp>(op, lhs, mul);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||||
|
@ -1976,6 +2007,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||||
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
|
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenRemainderTensorOp);
|
INSERT_ATENOP_PATTERN(AtenRemainderTensorOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
||||||
|
|
|
@ -650,9 +650,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"ElementwiseErfIntModule_basic",
|
"ElementwiseErfIntModule_basic",
|
||||||
"ElementwiseExpm1IntModule_basic",
|
"ElementwiseExpm1IntModule_basic",
|
||||||
"ElementwiseExpm1Module_basic",
|
"ElementwiseExpm1Module_basic",
|
||||||
"ElementwiseFmodTensor_Float_basic",
|
|
||||||
"ElementwiseFmodTensor_Int_Float_basic",
|
|
||||||
"ElementwiseFmodTensor_Int_basic",
|
|
||||||
"ElementwiseLog10IntModule_basic",
|
"ElementwiseLog10IntModule_basic",
|
||||||
"ElementwiseLog10Module_basic",
|
"ElementwiseLog10Module_basic",
|
||||||
"ElementwiseLog2IntModule_basic",
|
"ElementwiseLog2IntModule_basic",
|
||||||
|
@ -1056,6 +1053,9 @@ STABLEHLO_PASS_SET = {
|
||||||
"ElementwiseExpModule_basic",
|
"ElementwiseExpModule_basic",
|
||||||
"ElementwiseFloorIntModule_basic",
|
"ElementwiseFloorIntModule_basic",
|
||||||
"ElementwiseFloorModule_basic",
|
"ElementwiseFloorModule_basic",
|
||||||
|
"ElementwiseFmodTensor_Float_basic",
|
||||||
|
"ElementwiseFmodTensor_Int_Float_basic",
|
||||||
|
"ElementwiseFmodTensor_Int_basic",
|
||||||
"ElementwiseGeluApproximateTanhModule_basic",
|
"ElementwiseGeluApproximateTanhModule_basic",
|
||||||
"ElementwiseGeluModule_basic",
|
"ElementwiseGeluModule_basic",
|
||||||
"ElementwiseLeakyReluStaticModule_basic",
|
"ElementwiseLeakyReluStaticModule_basic",
|
||||||
|
|
Loading…
Reference in New Issue