mirror of https://github.com/llvm/torch-mlir
Add canonicalize pattern for aten.mul.int and aten.floordiv.int (#3680)
This PR add `floordiv` to the `PY_BUILTIN_TO_TORCH_OP`. For `aten.mul.int` and `aten.floordiv.int` ops, we add new Canonicalization Patterns as follow: ``` %1 = torch.aten.mul.int %input, %const-5 %2 = torch.aten.mul.int %1, %const-6 ``` Will be replaced by `torch.aten.mul.int %input, %const-30` And ``` %1 = torch.aten.mul.int %input, %const-5 %2 = torch.aten.floordiv.int %1, %const-5 ``` Will directly return `%input` This PR also relaxes the `float` type constraint in TorchToTosa for the `AtenRsubScalarOp` conversion. To test: `cmake --build build --target check-torch-mlir-all`pull/3685/head
parent
70de04a873
commit
b3942ff984
|
@ -15078,6 +15078,7 @@ def Torch_AtenFloordivIntOp : Torch_Op<"aten.floordiv.int", [
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [
|
def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [
|
||||||
|
@ -15226,6 +15227,7 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [
|
def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [
|
||||||
|
|
|
@ -1823,10 +1823,6 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only ranked tensor types supported in TOSA Rsub");
|
op, "Only ranked tensor types supported in TOSA Rsub");
|
||||||
|
|
||||||
if (!isa<mlir::FloatType>(selfTy.getElementType()))
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "Only floating-point datatype legalization supported");
|
|
||||||
|
|
||||||
Value otherTensor, alphaTensor;
|
Value otherTensor, alphaTensor;
|
||||||
|
|
||||||
if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
|
if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
|
||||||
|
|
|
@ -3434,6 +3434,44 @@ OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) {
|
||||||
[](int64_t a, int64_t b) { return std::floor(a / (double)b); });
|
[](int64_t a, int64_t b) { return std::floor(a / (double)b); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AtenFloordivIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
|
MLIRContext *context) {
|
||||||
|
patterns.add(+[](AtenFloordivIntOp op, PatternRewriter &rewriter) {
|
||||||
|
int64_t lhs, rhs;
|
||||||
|
bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs));
|
||||||
|
bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs));
|
||||||
|
if (lConstant && rConstant)
|
||||||
|
return failure();
|
||||||
|
if (lConstant || rConstant) {
|
||||||
|
int64_t firstConstant = lConstant ? lhs : rhs;
|
||||||
|
Value firstOperand = lConstant ? op.getB() : op.getA();
|
||||||
|
if (firstOperand.getDefiningOp() &&
|
||||||
|
firstOperand.getDefiningOp<AtenMulIntOp>()) {
|
||||||
|
auto prevMulIntOp = firstOperand.getDefiningOp<AtenMulIntOp>();
|
||||||
|
int64_t prevLhs, prevRhs;
|
||||||
|
bool prevLConstant =
|
||||||
|
matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
|
||||||
|
bool prevRConstant =
|
||||||
|
matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
|
||||||
|
if (prevLConstant && prevRConstant)
|
||||||
|
return failure();
|
||||||
|
if ((prevLConstant || prevRConstant) &&
|
||||||
|
prevMulIntOp->hasOneUse() == 1) {
|
||||||
|
int64_t secondConstant = prevLConstant ? prevLhs : prevRhs;
|
||||||
|
if (secondConstant == firstConstant) {
|
||||||
|
rewriter.replaceAllUsesWith(
|
||||||
|
op.getResult(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0));
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
rewriter.eraseOp(prevMulIntOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenRemainderIntOp
|
// AtenRemainderIntOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -3697,6 +3735,45 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AtenMulIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
|
MLIRContext *context) {
|
||||||
|
patterns.add(+[](AtenMulIntOp op, PatternRewriter &rewriter) {
|
||||||
|
int64_t lhs, rhs;
|
||||||
|
bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs));
|
||||||
|
bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs));
|
||||||
|
if (lConstant && rConstant)
|
||||||
|
return failure();
|
||||||
|
if (lConstant || rConstant) {
|
||||||
|
int64_t firstConstant = lConstant ? lhs : rhs;
|
||||||
|
Value firstOperand = lConstant ? op.getB() : op.getA();
|
||||||
|
if (firstOperand.getDefiningOp() &&
|
||||||
|
firstOperand.getDefiningOp<AtenMulIntOp>()) {
|
||||||
|
auto prevMulIntOp = firstOperand.getDefiningOp<AtenMulIntOp>();
|
||||||
|
int64_t prevLhs, prevRhs;
|
||||||
|
bool prevLConstant =
|
||||||
|
matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
|
||||||
|
bool prevRConstant =
|
||||||
|
matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
|
||||||
|
if (prevLConstant && prevRConstant)
|
||||||
|
return failure();
|
||||||
|
if ((prevLConstant || prevRConstant) &&
|
||||||
|
prevMulIntOp->hasOneUse() == 1) {
|
||||||
|
auto newConstant = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
op.getLoc(), rewriter.getI64IntegerAttr(
|
||||||
|
prevLConstant ? prevLhs * firstConstant
|
||||||
|
: prevRhs * firstConstant));
|
||||||
|
rewriter.replaceOpWithNewOp<AtenMulIntOp>(
|
||||||
|
op, op.getType(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0),
|
||||||
|
newConstant);
|
||||||
|
rewriter.eraseOp(prevMulIntOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenMulFloatOp
|
// AtenMulFloatOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1963,6 +1963,7 @@ TOSA_PASS_SET = {
|
||||||
"RsubFloatModule_basic",
|
"RsubFloatModule_basic",
|
||||||
"RsubFloatModule_noalpha_basic",
|
"RsubFloatModule_noalpha_basic",
|
||||||
"RsubInt0d_NumToTensor_Module_basic",
|
"RsubInt0d_NumToTensor_Module_basic",
|
||||||
|
"RsubIntModule_basic",
|
||||||
"ScalarTensorDefaultDtypeModule_basic",
|
"ScalarTensorDefaultDtypeModule_basic",
|
||||||
"ScalarTensorFloat32Module_basic",
|
"ScalarTensorFloat32Module_basic",
|
||||||
"ScalarTensorInt32Module_basic",
|
"ScalarTensorInt32Module_basic",
|
||||||
|
|
|
@ -1060,13 +1060,21 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::le.int : (int, int) -> (bool)", has_folder=True)
|
emit("aten::le.int : (int, int) -> (bool)", has_folder=True)
|
||||||
emit("aten::ne.int : (int, int) -> (bool)", has_folder=True)
|
emit("aten::ne.int : (int, int) -> (bool)", has_folder=True)
|
||||||
emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
|
emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
|
||||||
emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True)
|
emit(
|
||||||
|
"aten::floordiv.int : (int, int) -> (int)",
|
||||||
|
has_folder=True,
|
||||||
|
has_canonicalizer=True,
|
||||||
|
)
|
||||||
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
|
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
|
||||||
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
|
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||||
emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
|
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
|
||||||
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
|
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
|
||||||
emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
|
emit(
|
||||||
|
"aten::mul.int : (int, int) -> (int)",
|
||||||
|
has_folder=True,
|
||||||
|
has_canonicalizer=True,
|
||||||
|
)
|
||||||
emit("aten::div.int : (int, int) -> (float)", has_folder=True)
|
emit("aten::div.int : (int, int) -> (float)", has_folder=True)
|
||||||
emit("aten::neg.int : (int) -> (int)", has_folder=True)
|
emit("aten::neg.int : (int) -> (int)", has_folder=True)
|
||||||
emit("aten::log.int : (int) -> (float)")
|
emit("aten::log.int : (int) -> (float)")
|
||||||
|
|
|
@ -279,6 +279,7 @@ PY_BUILTIN_TO_TORCH_OP = {
|
||||||
"gt": torch.ops.aten.gt,
|
"gt": torch.ops.aten.gt,
|
||||||
"mod": torch.ops.aten.fmod,
|
"mod": torch.ops.aten.fmod,
|
||||||
"eq": torch.ops.aten.eq,
|
"eq": torch.ops.aten.eq,
|
||||||
|
"floordiv": torch.ops.aten.floordiv,
|
||||||
}
|
}
|
||||||
|
|
||||||
# torch with cuda has a __version__ that looks like "2.1.0+cu113",
|
# torch with cuda has a __version__ that looks like "2.1.0+cu113",
|
||||||
|
|
|
@ -1168,6 +1168,19 @@ func.func @torch.aten.mul.int() -> !torch.int {
|
||||||
return %ret : !torch.int
|
return %ret : !torch.int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.mul.int$canonicalize(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int {
|
||||||
|
// CHECK: %[[CST30:.*]] = torch.constant.int 30
|
||||||
|
// CHECK: %[[RET:.*]] = torch.aten.mul.int %[[ARG]], %[[CST30]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: return %[[RET]] : !torch.int
|
||||||
|
func.func @torch.aten.mul.int$canonicalize(%arg0: !torch.int) -> !torch.int {
|
||||||
|
%cst6 = torch.constant.int 6
|
||||||
|
%cst5 = torch.constant.int 5
|
||||||
|
%1 = torch.aten.mul.int %arg0, %cst5: !torch.int, !torch.int -> !torch.int
|
||||||
|
%ret = torch.aten.mul.int %1, %cst6: !torch.int, !torch.int -> !torch.int
|
||||||
|
return %ret : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float {
|
// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float {
|
||||||
// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01
|
// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01
|
||||||
// CHECK: return %[[CST30]] : !torch.float
|
// CHECK: return %[[CST30]] : !torch.float
|
||||||
|
@ -1207,6 +1220,16 @@ func.func @torch.aten.floordiv.int() -> !torch.int {
|
||||||
return %ret : !torch.int
|
return %ret : !torch.int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.floordiv.int$canonicalize(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int {
|
||||||
|
// CHECK: return %[[ARG]] : !torch.int
|
||||||
|
func.func @torch.aten.floordiv.int$canonicalize(%arg0: !torch.int) -> !torch.int {
|
||||||
|
%cst6 = torch.constant.int 6
|
||||||
|
%1 = torch.aten.mul.int %arg0, %cst6: !torch.int, !torch.int -> !torch.int
|
||||||
|
%ret = torch.aten.floordiv.int %1, %cst6: !torch.int, !torch.int -> !torch.int
|
||||||
|
return %ret : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.remainder.int() -> !torch.int {
|
// CHECK-LABEL: func.func @torch.aten.remainder.int() -> !torch.int {
|
||||||
// CHECK: %[[CST3:.*]] = torch.constant.int 3
|
// CHECK: %[[CST3:.*]] = torch.constant.int 3
|
||||||
// CHECK: return %[[CST3]] : !torch.int
|
// CHECK: return %[[CST3]] : !torch.int
|
||||||
|
@ -3122,7 +3145,6 @@ func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!
|
||||||
return %1 : !torch.tensor
|
return %1 : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @torch.symbolic_int$canonicalize(
|
// CHECK-LABEL: @torch.symbolic_int$canonicalize(
|
||||||
|
|
Loading…
Reference in New Issue