mirror of https://github.com/llvm/torch-mlir
Decompose aten.fmod into aten.mul,sub,div etc. (#3689)
As titled, create a new decomposition for `aten.fmod.Tensor` to `aten.div`, `aten.trunc`, `aten.mul` and `aten.sub`. Note that we only use `aten.trunc` for floating point operations. This further gets decomposed to `aten.where` etc. by other existing decompositions. This decomposition now makes TOSA pass for a simple model with `aten.fmod` while it makes `stablehlo` fail. For now, we disallow this decomposition for `stablehlo`. --------- Co-authored-by: Srinath Avadhanula <srinath.avadhanula@getcruise.com> (pull/3654/merge)
parent
df6098e43d
commit
0a788e0467
|
@ -7545,6 +7545,44 @@ class DecomposeAtenTruncOp : public OpRewritePattern<AtenTruncOp> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// decompose `fmod(x, y)` to `x - trunc(x/y) * y`
|
||||
class DecomposeAtenFmodTensorOp : public OpRewritePattern<AtenFmodTensorOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenFmodTensorOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value self = op.getSelf();
|
||||
Value other = op.getOther();
|
||||
|
||||
auto resultTy = dyn_cast<ValueTensorType>(op.getType());
|
||||
if (!resultTy || !resultTy.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result must have dtype");
|
||||
}
|
||||
|
||||
if (isa<mlir::IntegerType>(resultTy.getDtype())) {
|
||||
Value div = rewriter.create<AtenDivTensorOp>(loc, resultTy, self, other);
|
||||
Value mul = rewriter.create<AtenMulTensorOp>(loc, resultTy, div, other);
|
||||
Value alpha =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
|
||||
rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, resultTy, self, mul,
|
||||
alpha);
|
||||
return success();
|
||||
} else if (isa<mlir::FloatType>(resultTy.getDtype())) {
|
||||
Value div = rewriter.create<AtenDivTensorOp>(loc, resultTy, self, other);
|
||||
Value trunc = rewriter.create<AtenTruncOp>(loc, resultTy, div);
|
||||
Value mul = rewriter.create<AtenMulTensorOp>(loc, resultTy, trunc, other);
|
||||
Value alpha =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
|
||||
rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, resultTy, self, mul,
|
||||
alpha);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and
|
||||
// `aten.add.Tensor` op.
|
||||
|
@ -9661,6 +9699,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenRad2degOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTruncOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFmodTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideScalarOp>(patterns);
|
||||
|
|
|
@ -1778,6 +1778,9 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseFloorModule_basic",
|
||||
"ElementwiseFmaxModule_basic",
|
||||
"ElementwiseFminModule_basic",
|
||||
"ElementwiseFmodTensor_Float_basic",
|
||||
"ElementwiseFmodTensor_Int_Float_basic",
|
||||
"ElementwiseFmodTensor_Int_basic",
|
||||
"ElementwiseGeFloatIntScalarModule_basic",
|
||||
"ElementwiseGeFloatScalarModule_basic",
|
||||
"ElementwiseGeIntScalarModule_basic",
|
||||
|
@ -3253,9 +3256,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseExpIntModule_basic",
|
||||
"ElementwiseExpm1IntModule_basic",
|
||||
"ElementwiseExpm1Module_basic",
|
||||
"ElementwiseFmodTensor_Float_basic",
|
||||
"ElementwiseFmodTensor_Int_Float_basic",
|
||||
"ElementwiseFmodTensor_Int_basic",
|
||||
"ElementwiseGeFloatTensorModule_basic",
|
||||
"ElementwiseGeIntTensorModule_basic",
|
||||
"ElementwiseGeluApproximateTanhModule_basic",
|
||||
|
|
|
@ -170,6 +170,7 @@ BACKEND_LEGAL_OPS = {
|
|||
"aten.amin",
|
||||
"aten.randn.generator",
|
||||
"aten.normal_functional",
|
||||
"aten.fmod.Tensor",
|
||||
],
|
||||
}
|
||||
|
||||
|
|
|
@ -128,3 +128,46 @@ func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.
|
|||
%1 = torch.aten.einsum %str, %0, %none_0 : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[],f64>
|
||||
return %1 : !torch.vtensor<[],f64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Verify the decomposition of `aten.fmod.Tensor` for an integer dtype:
// fmod(x, y) == x - (x / y) * y. No explicit trunc appears because integer
// division already truncates toward zero.
// CHECK: func.func @torch.aten.fmod_int(%[[ARG0:.+]]: !torch.vtensor<[?],si32>, %[[ARG1:.+]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[?],si32> {
// CHECK: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK: %[[V0:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32>
// CHECK: %[[V1:.+]] = torch.aten.mul.Tensor %[[V0]], %[[ARG1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32>
// CHECK: %[[V2:.+]] = torch.aten.sub.Tensor %[[ARG0]], %[[V1]], %[[FLOAT1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si32>, !torch.float -> !torch.vtensor<[?],si32>
// CHECK: return %[[V2]] : !torch.vtensor<[?],si32>
func.func @torch.aten.fmod_int(%arg0: !torch.vtensor<[?],si32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[?],si32> {
  %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32>
  return %0 : !torch.vtensor<[?],si32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Verify the decomposition of `aten.fmod.Tensor` for a float dtype:
// fmod(x, y) == x - trunc(x / y) * y. The trunc is itself decomposed by
// other patterns into sign(x/y) * floor(|x/y|), expressed below via
// gt/lt.Scalar, where.self, abs, and floor ops.
// CHECK: func.func @torch.aten.fmod_float(%[[ARG0:.+]]: !torch.vtensor<[?],f16>, %[[ARG1:.+]]: !torch.vtensor<[1],f16>) -> !torch.vtensor<[?],f16> {
// CHECK: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
// CHECK: %[[V0:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[V1:.+]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[INT5:.+]] = torch.constant.int 5
// CHECK: %[[V2:.+]] = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[V3:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V4:.+]] = torch.aten.gt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1>
// CHECK: %[[V5:.+]] = torch.aten.lt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1>
// CHECK: %[[V6:.+]] = torch.aten.to.dtype %[[V2]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
// CHECK: %[[V7:.+]] = torch.aten.to.dtype %[[V1]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
// CHECK: %[[V8:.+]] = torch.aten.where.self %[[V4]], %[[V6]], %[[V7]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V9:.+]] = torch.aten.to.dtype %[[V0]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
// CHECK: %[[V10:.+]] = torch.aten.where.self %[[V5]], %[[V9]], %[[V8]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V11:.+]] = torch.aten.abs %[[V3]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V12:.+]] = torch.aten.floor %[[V11]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V13:.+]] = torch.aten.mul.Tensor %[[V10]], %[[V12]] : !torch.vtensor<[?],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V14:.+]] = torch.aten.mul.Tensor %[[V13]], %[[ARG1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
// CHECK: %[[V15:.+]] = torch.aten.sub.Tensor %[[ARG0]], %[[V14]], %[[FLOAT1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[?],f16>, !torch.float -> !torch.vtensor<[?],f16>
// CHECK: return %[[V15]] : !torch.vtensor<[?],f16>
func.func @torch.aten.fmod_float(%arg0: !torch.vtensor<[?],f16>, %arg1: !torch.vtensor<[1],f16>) -> !torch.vtensor<[?],f16> {
  %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
  return %0 : !torch.vtensor<[?],f16>
}
|
||||
|
|
Loading…
Reference in New Issue