mirror of https://github.com/llvm/torch-mlir
Add canonicalize pattern for aten.mul.int and aten.floordiv.int (#3680)
This PR add `floordiv` to the `PY_BUILTIN_TO_TORCH_OP`. For `aten.mul.int` and `aten.floordiv.int` ops, we add new Canonicalization Patterns as follow: ``` %1 = torch.aten.mul.int %input, %const-5 %2 = torch.aten.mul.int %1, %const-6 ``` Will be replaced by `torch.aten.mul.int %input, %const-30` And ``` %1 = torch.aten.mul.int %input, %const-5 %2 = torch.aten.floordiv.int %1, %const-5 ``` Will directly return `%input` This PR also relaxes the `float` type constraint in TorchToTosa for the `AtenRsubScalarOp` conversion. To test: `cmake --build build --target check-torch-mlir-all`byteir
parent
16bbcb0bef
commit
abb9282524
|
@ -15492,6 +15492,7 @@ def Torch_AtenFloordivIntOp : Torch_Op<"aten.floordiv.int", [
|
|||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [
|
||||
|
@ -15641,6 +15642,7 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
|
|||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [
|
||||
|
|
|
@ -1823,10 +1823,6 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types supported in TOSA Rsub");
|
||||
|
||||
if (!isa<mlir::FloatType>(selfTy.getElementType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point datatype legalization supported");
|
||||
|
||||
Value otherTensor, alphaTensor;
|
||||
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
|
||||
|
|
|
@ -3508,6 +3508,44 @@ OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) {
|
|||
[](int64_t a, int64_t b) { return std::floor(a / (double)b); });
|
||||
}
|
||||
|
||||
void AtenFloordivIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](AtenFloordivIntOp op, PatternRewriter &rewriter) {
|
||||
int64_t lhs, rhs;
|
||||
bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs));
|
||||
bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs));
|
||||
if (lConstant && rConstant)
|
||||
return failure();
|
||||
if (lConstant || rConstant) {
|
||||
int64_t firstConstant = lConstant ? lhs : rhs;
|
||||
Value firstOperand = lConstant ? op.getB() : op.getA();
|
||||
if (firstOperand.getDefiningOp() &&
|
||||
firstOperand.getDefiningOp<AtenMulIntOp>()) {
|
||||
auto prevMulIntOp = firstOperand.getDefiningOp<AtenMulIntOp>();
|
||||
int64_t prevLhs, prevRhs;
|
||||
bool prevLConstant =
|
||||
matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
|
||||
bool prevRConstant =
|
||||
matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
|
||||
if (prevLConstant && prevRConstant)
|
||||
return failure();
|
||||
if ((prevLConstant || prevRConstant) &&
|
||||
prevMulIntOp->hasOneUse() == 1) {
|
||||
int64_t secondConstant = prevLConstant ? prevLhs : prevRhs;
|
||||
if (secondConstant == firstConstant) {
|
||||
rewriter.replaceAllUsesWith(
|
||||
op.getResult(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0));
|
||||
rewriter.eraseOp(op);
|
||||
rewriter.eraseOp(prevMulIntOp);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenRemainderIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -3799,6 +3837,45 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void AtenMulIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](AtenMulIntOp op, PatternRewriter &rewriter) {
|
||||
int64_t lhs, rhs;
|
||||
bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs));
|
||||
bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs));
|
||||
if (lConstant && rConstant)
|
||||
return failure();
|
||||
if (lConstant || rConstant) {
|
||||
int64_t firstConstant = lConstant ? lhs : rhs;
|
||||
Value firstOperand = lConstant ? op.getB() : op.getA();
|
||||
if (firstOperand.getDefiningOp() &&
|
||||
firstOperand.getDefiningOp<AtenMulIntOp>()) {
|
||||
auto prevMulIntOp = firstOperand.getDefiningOp<AtenMulIntOp>();
|
||||
int64_t prevLhs, prevRhs;
|
||||
bool prevLConstant =
|
||||
matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
|
||||
bool prevRConstant =
|
||||
matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
|
||||
if (prevLConstant && prevRConstant)
|
||||
return failure();
|
||||
if ((prevLConstant || prevRConstant) &&
|
||||
prevMulIntOp->hasOneUse() == 1) {
|
||||
auto newConstant = rewriter.create<Torch::ConstantIntOp>(
|
||||
op.getLoc(), rewriter.getI64IntegerAttr(
|
||||
prevLConstant ? prevLhs * firstConstant
|
||||
: prevRhs * firstConstant));
|
||||
rewriter.replaceOpWithNewOp<AtenMulIntOp>(
|
||||
op, op.getType(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0),
|
||||
newConstant);
|
||||
rewriter.eraseOp(prevMulIntOp);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenMulFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -2003,6 +2003,7 @@ TOSA_PASS_SET = {
|
|||
"RsubFloatModule_basic",
|
||||
"RsubFloatModule_noalpha_basic",
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"RsubIntModule_basic",
|
||||
"ScalarTensorDefaultDtypeModule_basic",
|
||||
"ScalarTensorFloat32Module_basic",
|
||||
"ScalarTensorInt32Module_basic",
|
||||
|
|
|
@ -1086,13 +1086,21 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::le.int : (int, int) -> (bool)", has_folder=True)
|
||||
emit("aten::ne.int : (int, int) -> (bool)", has_folder=True)
|
||||
emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
|
||||
emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True)
|
||||
emit(
|
||||
"aten::floordiv.int : (int, int) -> (int)",
|
||||
has_folder=True,
|
||||
has_canonicalizer=True,
|
||||
)
|
||||
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
|
||||
emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
|
||||
emit(
|
||||
"aten::mul.int : (int, int) -> (int)",
|
||||
has_folder=True,
|
||||
has_canonicalizer=True,
|
||||
)
|
||||
emit("aten::div.int : (int, int) -> (float)", has_folder=True)
|
||||
emit("aten::neg.int : (int) -> (int)", has_folder=True)
|
||||
emit("aten::log.int : (int) -> (float)")
|
||||
|
|
|
@ -267,6 +267,7 @@ PY_BUILTIN_TO_TORCH_OP = {
|
|||
"gt": torch.ops.aten.gt,
|
||||
"mod": torch.ops.aten.fmod,
|
||||
"eq": torch.ops.aten.eq,
|
||||
"floordiv": torch.ops.aten.floordiv,
|
||||
}
|
||||
|
||||
# torch with cuda has a __version__ that looks like "2.1.0+cu113",
|
||||
|
|
|
@ -1168,6 +1168,19 @@ func.func @torch.aten.mul.int() -> !torch.int {
|
|||
return %ret : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.mul.int$canonicalize(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int {
|
||||
// CHECK: %[[CST30:.*]] = torch.constant.int 30
|
||||
// CHECK: %[[RET:.*]] = torch.aten.mul.int %[[ARG]], %[[CST30]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: return %[[RET]] : !torch.int
|
||||
func.func @torch.aten.mul.int$canonicalize(%arg0: !torch.int) -> !torch.int {
|
||||
%cst6 = torch.constant.int 6
|
||||
%cst5 = torch.constant.int 5
|
||||
%1 = torch.aten.mul.int %arg0, %cst5: !torch.int, !torch.int -> !torch.int
|
||||
%ret = torch.aten.mul.int %1, %cst6: !torch.int, !torch.int -> !torch.int
|
||||
return %ret : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float {
|
||||
// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01
|
||||
// CHECK: return %[[CST30]] : !torch.float
|
||||
|
@ -1207,6 +1220,16 @@ func.func @torch.aten.floordiv.int() -> !torch.int {
|
|||
return %ret : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.floordiv.int$canonicalize(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int {
|
||||
// CHECK: return %[[ARG]] : !torch.int
|
||||
func.func @torch.aten.floordiv.int$canonicalize(%arg0: !torch.int) -> !torch.int {
|
||||
%cst6 = torch.constant.int 6
|
||||
%1 = torch.aten.mul.int %arg0, %cst6: !torch.int, !torch.int -> !torch.int
|
||||
%ret = torch.aten.floordiv.int %1, %cst6: !torch.int, !torch.int -> !torch.int
|
||||
return %ret : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.remainder.int() -> !torch.int {
|
||||
// CHECK: %[[CST3:.*]] = torch.constant.int 3
|
||||
// CHECK: return %[[CST3]] : !torch.int
|
||||
|
@ -3122,7 +3145,6 @@ func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!
|
|||
return %1 : !torch.tensor
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @torch.symbolic_int$canonicalize(
|
||||
|
|
Loading…
Reference in New Issue