Add canonicalize pattern for aten.mul.int and aten.floordiv.int (#3680)

This PR add `floordiv` to the `PY_BUILTIN_TO_TORCH_OP`. For
`aten.mul.int` and `aten.floordiv.int` ops, we add new Canonicalization
Patterns as follow:

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.mul.int %1, %const-6
```

Will be replaced by

`torch.aten.mul.int %input, %const-30`


And 

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.floordiv.int %1, %const-5
```
Will directly return `%input`


This PR also relaxes the `float` type constraint in TorchToTosa for the
`AtenRsubScalarOp` conversion.



To test:

`cmake --build build --target check-torch-mlir-all`
pull/3685/head
Ze Zhang 2024-09-03 09:13:59 -07:00 committed by GitHub
parent 70de04a873
commit b3942ff984
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 114 additions and 7 deletions

View File

@ -15078,6 +15078,7 @@ def Torch_AtenFloordivIntOp : Torch_Op<"aten.floordiv.int", [
} }
}]; }];
let hasFolder = 1; let hasFolder = 1;
let hasCanonicalizer = 1;
} }
def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [ def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [
@ -15226,6 +15227,7 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
} }
}]; }];
let hasFolder = 1; let hasFolder = 1;
let hasCanonicalizer = 1;
} }
def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [ def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [

View File

@ -1823,10 +1823,6 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Rsub"); op, "Only ranked tensor types supported in TOSA Rsub");
if (!isa<mlir::FloatType>(selfTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
Value otherTensor, alphaTensor; Value otherTensor, alphaTensor;
if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor, if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,

View File

@ -3434,6 +3434,44 @@ OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) {
[](int64_t a, int64_t b) { return std::floor(a / (double)b); }); [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
} }
void AtenFloordivIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenFloordivIntOp op, PatternRewriter &rewriter) {
int64_t lhs, rhs;
bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs));
if (lConstant && rConstant)
return failure();
if (lConstant || rConstant) {
int64_t firstConstant = lConstant ? lhs : rhs;
Value firstOperand = lConstant ? op.getB() : op.getA();
if (firstOperand.getDefiningOp() &&
firstOperand.getDefiningOp<AtenMulIntOp>()) {
auto prevMulIntOp = firstOperand.getDefiningOp<AtenMulIntOp>();
int64_t prevLhs, prevRhs;
bool prevLConstant =
matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
bool prevRConstant =
matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
if (prevLConstant && prevRConstant)
return failure();
if ((prevLConstant || prevRConstant) &&
prevMulIntOp->hasOneUse() == 1) {
int64_t secondConstant = prevLConstant ? prevLhs : prevRhs;
if (secondConstant == firstConstant) {
rewriter.replaceAllUsesWith(
op.getResult(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0));
rewriter.eraseOp(op);
rewriter.eraseOp(prevMulIntOp);
return success();
}
}
}
}
return failure();
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenRemainderIntOp // AtenRemainderIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -3697,6 +3735,45 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
void AtenMulIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenMulIntOp op, PatternRewriter &rewriter) {
int64_t lhs, rhs;
bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs));
if (lConstant && rConstant)
return failure();
if (lConstant || rConstant) {
int64_t firstConstant = lConstant ? lhs : rhs;
Value firstOperand = lConstant ? op.getB() : op.getA();
if (firstOperand.getDefiningOp() &&
firstOperand.getDefiningOp<AtenMulIntOp>()) {
auto prevMulIntOp = firstOperand.getDefiningOp<AtenMulIntOp>();
int64_t prevLhs, prevRhs;
bool prevLConstant =
matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
bool prevRConstant =
matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
if (prevLConstant && prevRConstant)
return failure();
if ((prevLConstant || prevRConstant) &&
prevMulIntOp->hasOneUse() == 1) {
auto newConstant = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getI64IntegerAttr(
prevLConstant ? prevLhs * firstConstant
: prevRhs * firstConstant));
rewriter.replaceOpWithNewOp<AtenMulIntOp>(
op, op.getType(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0),
newConstant);
rewriter.eraseOp(prevMulIntOp);
return success();
}
}
}
return failure();
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenMulFloatOp // AtenMulFloatOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1963,6 +1963,7 @@ TOSA_PASS_SET = {
"RsubFloatModule_basic", "RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic", "RsubFloatModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic", "RsubInt0d_NumToTensor_Module_basic",
"RsubIntModule_basic",
"ScalarTensorDefaultDtypeModule_basic", "ScalarTensorDefaultDtypeModule_basic",
"ScalarTensorFloat32Module_basic", "ScalarTensorFloat32Module_basic",
"ScalarTensorInt32Module_basic", "ScalarTensorInt32Module_basic",

View File

@ -1060,13 +1060,21 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::le.int : (int, int) -> (bool)", has_folder=True) emit("aten::le.int : (int, int) -> (bool)", has_folder=True)
emit("aten::ne.int : (int, int) -> (bool)", has_folder=True) emit("aten::ne.int : (int, int) -> (bool)", has_folder=True)
emit("aten::eq.int : (int, int) -> (bool)", has_folder=True) emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True) emit(
"aten::floordiv.int : (int, int) -> (int)",
has_folder=True,
has_canonicalizer=True,
)
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True) emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::add.int : (int, int) -> (int)", has_folder=True) emit("aten::add.int : (int, int) -> (int)", has_folder=True)
emit("aten::sub.int : (int, int) -> (int)", has_folder=True) emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
emit("aten::mul.int : (int, int) -> (int)", has_folder=True) emit(
"aten::mul.int : (int, int) -> (int)",
has_folder=True,
has_canonicalizer=True,
)
emit("aten::div.int : (int, int) -> (float)", has_folder=True) emit("aten::div.int : (int, int) -> (float)", has_folder=True)
emit("aten::neg.int : (int) -> (int)", has_folder=True) emit("aten::neg.int : (int) -> (int)", has_folder=True)
emit("aten::log.int : (int) -> (float)") emit("aten::log.int : (int) -> (float)")

View File

@ -279,6 +279,7 @@ PY_BUILTIN_TO_TORCH_OP = {
"gt": torch.ops.aten.gt, "gt": torch.ops.aten.gt,
"mod": torch.ops.aten.fmod, "mod": torch.ops.aten.fmod,
"eq": torch.ops.aten.eq, "eq": torch.ops.aten.eq,
"floordiv": torch.ops.aten.floordiv,
} }
# torch with cuda has a __version__ that looks like "2.1.0+cu113", # torch with cuda has a __version__ that looks like "2.1.0+cu113",

View File

@ -1168,6 +1168,19 @@ func.func @torch.aten.mul.int() -> !torch.int {
return %ret : !torch.int return %ret : !torch.int
} }
// CHECK-LABEL: func.func @torch.aten.mul.int$canonicalize(
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int {
// CHECK: %[[CST30:.*]] = torch.constant.int 30
// CHECK: %[[RET:.*]] = torch.aten.mul.int %[[ARG]], %[[CST30]] : !torch.int, !torch.int -> !torch.int
// CHECK: return %[[RET]] : !torch.int
func.func @torch.aten.mul.int$canonicalize(%arg0: !torch.int) -> !torch.int {
%cst6 = torch.constant.int 6
%cst5 = torch.constant.int 5
%1 = torch.aten.mul.int %arg0, %cst5: !torch.int, !torch.int -> !torch.int
%ret = torch.aten.mul.int %1, %cst6: !torch.int, !torch.int -> !torch.int
return %ret : !torch.int
}
// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float { // CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float {
// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01 // CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01
// CHECK: return %[[CST30]] : !torch.float // CHECK: return %[[CST30]] : !torch.float
@ -1207,6 +1220,16 @@ func.func @torch.aten.floordiv.int() -> !torch.int {
return %ret : !torch.int return %ret : !torch.int
} }
// CHECK-LABEL: func.func @torch.aten.floordiv.int$canonicalize(
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int {
// CHECK: return %[[ARG]] : !torch.int
func.func @torch.aten.floordiv.int$canonicalize(%arg0: !torch.int) -> !torch.int {
%cst6 = torch.constant.int 6
%1 = torch.aten.mul.int %arg0, %cst6: !torch.int, !torch.int -> !torch.int
%ret = torch.aten.floordiv.int %1, %cst6: !torch.int, !torch.int -> !torch.int
return %ret : !torch.int
}
// CHECK-LABEL: func.func @torch.aten.remainder.int() -> !torch.int { // CHECK-LABEL: func.func @torch.aten.remainder.int() -> !torch.int {
// CHECK: %[[CST3:.*]] = torch.constant.int 3 // CHECK: %[[CST3:.*]] = torch.constant.int 3
// CHECK: return %[[CST3]] : !torch.int // CHECK: return %[[CST3]] : !torch.int
@ -3122,7 +3145,6 @@ func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!
return %1 : !torch.tensor return %1 : !torch.tensor
} }
// ----- // -----
// CHECK-LABEL: @torch.symbolic_int$canonicalize( // CHECK-LABEL: @torch.symbolic_int$canonicalize(