Decompose aten.fmod into aten.mul,sub,div etc. (#3689)

As titled, create a new decomposition for `aten.fmod.Tensor` to
`aten.div`, `aten.trunc`, `aten.mul` and `aten.sub`. Note that we only
use `aten.trunc` for floating point operations. This further gets
decomposed to `aten.where` etc. by other existing decompositions.

This decomposition now makes TOSA pass for a simple model with
`aten.fmod` while it makes `stablehlo` fail. For now, we disallow this
decomposition for `stablehlo`

---------

Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com>
pull/3654/merge
Srinath Avadhanula 2024-09-09 12:00:11 -04:00 committed by GitHub
parent df6098e43d
commit 0a788e0467
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 86 additions and 3 deletions

View File

@ -7545,6 +7545,44 @@ class DecomposeAtenTruncOp : public OpRewritePattern<AtenTruncOp> {
};
} // namespace
namespace {
// decompose `fmod(x, y)` to `x - trunc(x/y) * y`
class DecomposeAtenFmodTensorOp : public OpRewritePattern<AtenFmodTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenFmodTensorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
Value other = op.getOther();
auto resultTy = dyn_cast<ValueTensorType>(op.getType());
if (!resultTy || !resultTy.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result must have dtype");
}
if (isa<mlir::IntegerType>(resultTy.getDtype())) {
Value div = rewriter.create<AtenDivTensorOp>(loc, resultTy, self, other);
Value mul = rewriter.create<AtenMulTensorOp>(loc, resultTy, div, other);
Value alpha =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, resultTy, self, mul,
alpha);
return success();
} else if (isa<mlir::FloatType>(resultTy.getDtype())) {
Value div = rewriter.create<AtenDivTensorOp>(loc, resultTy, self, other);
Value trunc = rewriter.create<AtenTruncOp>(loc, resultTy, div);
Value mul = rewriter.create<AtenMulTensorOp>(loc, resultTy, trunc, other);
Value alpha =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, resultTy, self, mul,
alpha);
return success();
}
return failure();
}
};
} // namespace
namespace {
// Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and
// `aten.add.Tensor` op.
@ -9661,6 +9699,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenRad2degOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTruncOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFmodTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideScalarOp>(patterns);

View File

@ -1778,6 +1778,9 @@ TOSA_PASS_SET = {
"ElementwiseFloorModule_basic",
"ElementwiseFmaxModule_basic",
"ElementwiseFminModule_basic",
"ElementwiseFmodTensor_Float_basic",
"ElementwiseFmodTensor_Int_Float_basic",
"ElementwiseFmodTensor_Int_basic",
"ElementwiseGeFloatIntScalarModule_basic",
"ElementwiseGeFloatScalarModule_basic",
"ElementwiseGeIntScalarModule_basic",
@ -3253,9 +3256,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseExpIntModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseFmodTensor_Float_basic",
"ElementwiseFmodTensor_Int_Float_basic",
"ElementwiseFmodTensor_Int_basic",
"ElementwiseGeFloatTensorModule_basic",
"ElementwiseGeIntTensorModule_basic",
"ElementwiseGeluApproximateTanhModule_basic",

View File

@ -170,6 +170,7 @@ BACKEND_LEGAL_OPS = {
"aten.amin",
"aten.randn.generator",
"aten.normal_functional",
"aten.fmod.Tensor",
],
}

View File

@ -128,3 +128,46 @@ func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.
%1 = torch.aten.einsum %str, %0, %none_0 : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[],f64>
return %1 : !torch.vtensor<[],f64>
}
// -----
// CHECK: func.func @torch.aten.fmod_int(%[[ARG0:.+]]: !torch.vtensor<[?],si32>, %[[ARG1:.+]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[?],si32> {
// CHECK: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK: %[[V0:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32>
// CHECK: %[[V1:.+]] = torch.aten.mul.Tensor %[[V0]], %[[ARG1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32>
// CHECK: %[[V2:.+]] = torch.aten.sub.Tensor %[[ARG0]], %[[V1]], %[[FLOAT1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si32>, !torch.float -> !torch.vtensor<[?],si32>
// CHECK: return %[[V2]] : !torch.vtensor<[?],si32>
func.func @torch.aten.fmod_int(%arg0: !torch.vtensor<[?],si32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[?],si32> {
%0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32>
return %0 : !torch.vtensor<[?],si32>
}
// -----
// CHECK: func.func @torch.aten.fmod_float(%[[ARG0:.+]]: !torch.vtensor<[?],f16>, %[[ARG1:.+]]: !torch.vtensor<[1],f16>) -> !torch.vtensor<[?],f16> {
// CHECK: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK: %[[V0:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[V1:.+]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[INT5:.+]] = torch.constant.int 5
// CHECK: %[[V2:.+]] = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[V3:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V4:.+]] = torch.aten.gt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1>
// CHECK: %[[V5:.+]] = torch.aten.lt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1>
// CHECK: %[[V6:.+]] = torch.aten.to.dtype %[[V2]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
// CHECK: %[[V7:.+]] = torch.aten.to.dtype %[[V1]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
// CHECK: %[[V8:.+]] = torch.aten.where.self %[[V4]], %[[V6]], %[[V7]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V9:.+]] = torch.aten.to.dtype %[[V0]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
// CHECK: %[[V10:.+]] = torch.aten.where.self %[[V5]], %[[V9]], %[[V8]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V11:.+]] = torch.aten.abs %[[V3]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V12:.+]] = torch.aten.floor %[[V11]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V13:.+]] = torch.aten.mul.Tensor %[[V10]], %[[V12]] : !torch.vtensor<[?],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V14:.+]] = torch.aten.mul.Tensor %[[V13]], %[[ARG1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V15:.+]] = torch.aten.sub.Tensor %[[ARG0]], %[[V14]], %[[FLOAT1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[?],f16>, !torch.float -> !torch.vtensor<[?],f16>
// CHECK: return %[[V15]] : !torch.vtensor<[?],f16>
func.func @torch.aten.fmod_float(%arg0: !torch.vtensor<[?],f16>, %arg1: !torch.vtensor<[1],f16>) -> !torch.vtensor<[?],f16> {
%0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
return %0 : !torch.vtensor<[?],f16>
}