From abb9282524e203471ddbcf3461b41a0df3bc1360 Mon Sep 17 00:00:00 2001
From: Ze Zhang
Date: Tue, 3 Sep 2024 09:13:59 -0700
Subject: [PATCH] Add canonicalize pattern for aten.mul.int and
 aten.floordiv.int (#3680)

This PR adds `floordiv` to the `PY_BUILTIN_TO_TORCH_OP` mapping.

For the `aten.mul.int` and `aten.floordiv.int` ops, we add new canonicalization patterns as follows:

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.mul.int %1, %const-6
```

will be replaced by

`torch.aten.mul.int %input, %const-30`

and

```
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.floordiv.int %1, %const-5
```

will directly return `%input`.
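For context, these integer ops typically come from Python shape arithmetic: `//` applied to a symbolic dimension is recorded in the exported FX graph under the builtin name `floordiv`, which the importer can now translate through `PY_BUILTIN_TO_TORCH_OP`. A minimal repro sketch, not part of this patch; the module and dimension names are made up, and the exact nodes depend on the torch version and on how much the symbolic tracer simplifies:

```python
# Illustrative only: Python `//` on a dynamic (symbolic) dim surfaces in the
# exported FX graph as the builtin "floordiv", which fx_importer now maps to
# torch.ops.aten.floordiv.
import torch
from torch.export import Dim, export


class ShapeArith(torch.nn.Module):
    def forward(self, x):
        d = x.size(0)                  # SymInt once dim 0 is marked dynamic
        return x.narrow(0, 0, d // 2)  # Python `//` applied to the SymInt


ep = export(ShapeArith(), (torch.randn(8, 4),),
            dynamic_shapes={"x": {0: Dim("d")}})
print(ep)  # expect sym_size / floordiv style nodes in the printed graph
```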
This PR also relaxes the `float`-only type constraint in TorchToTosa for the `AtenRsubScalarOp` conversion, so integer inputs are no longer rejected up front.
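As a rough illustration of what the relaxed pattern now accepts (a hypothetical snippet in the spirit of the existing `RsubIntModule_basic` e2e test, not code taken from this patch):

```python
# Illustrative only: an all-integer rsub, i.e. other - alpha * self, which the
# removed float-only check in the AtenRsubScalarOp TOSA lowering used to reject.
import torch

x = torch.randint(0, 10, (3, 4), dtype=torch.int32)
y = torch.ops.aten.rsub.Scalar(x, 2, 3)  # computes 2 - 3 * x
print(y.dtype)  # torch.int32
```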
To test:
`cmake --build build --target check-torch-mlir-all`

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td  |  2 +
 lib/Conversion/TorchToTosa/TorchToTosa.cpp |  4 -
 lib/Dialect/Torch/IR/TorchOps.cpp          | 77 +++++++++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py     |  1 +
 .../build_tools/torch_ods_gen.py           | 12 ++-
 python/torch_mlir/extras/fx_importer.py    |  1 +
 test/Dialect/Torch/canonicalize.mlir       | 24 +++++-
 7 files changed, 114 insertions(+), 7 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 60d450685..d6eb5a734 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -15492,6 +15492,7 @@ def Torch_AtenFloordivIntOp : Torch_Op<"aten.floordiv.int", [
     }
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [
@@ -15641,6 +15642,7 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
     }
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 60f3f3422..5449495d6 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1823,10 +1823,6 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Only ranked tensor types supported in TOSA Rsub");
 
-  if (!isa<mlir::FloatType>(selfTy.getElementType()))
-    return rewriter.notifyMatchFailure(
-        op, "Only floating-point datatype legalization supported");
-
   Value otherTensor, alphaTensor;
 
   if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 23bfaaf3d..3b58baa9e 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3508,6 +3508,44 @@ OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) {
       [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
 }
 
+void AtenFloordivIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                    MLIRContext *context) {
+  patterns.add(+[](AtenFloordivIntOp op, PatternRewriter &rewriter) {
+    int64_t lhs, rhs;
+    bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs));
+    bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs));
+    if (lConstant && rConstant)
+      return failure();
+    if (lConstant || rConstant) {
+      int64_t firstConstant = lConstant ? lhs : rhs;
+      Value firstOperand = lConstant ? op.getB() : op.getA();
+      if (firstOperand.getDefiningOp() &&
+          firstOperand.getDefiningOp<AtenMulIntOp>()) {
+        auto prevMulIntOp = firstOperand.getDefiningOp<AtenMulIntOp>();
+        int64_t prevLhs, prevRhs;
+        bool prevLConstant =
+            matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
+        bool prevRConstant =
+            matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
+        if (prevLConstant && prevRConstant)
+          return failure();
+        if ((prevLConstant || prevRConstant) &&
+            prevMulIntOp->hasOneUse() == 1) {
+          int64_t secondConstant = prevLConstant ? prevLhs : prevRhs;
+          if (secondConstant == firstConstant) {
+            rewriter.replaceAllUsesWith(
+                op.getResult(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0));
+            rewriter.eraseOp(op);
+            rewriter.eraseOp(prevMulIntOp);
+            return success();
+          }
+        }
+      }
+    }
+    return failure();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenRemainderIntOp
 //===----------------------------------------------------------------------===//
@@ -3799,6 +3837,45 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+void AtenMulIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                               MLIRContext *context) {
+  patterns.add(+[](AtenMulIntOp op, PatternRewriter &rewriter) {
+    int64_t lhs, rhs;
+    bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs));
+    bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs));
+    if (lConstant && rConstant)
+      return failure();
+    if (lConstant || rConstant) {
+      int64_t firstConstant = lConstant ? lhs : rhs;
+      Value firstOperand = lConstant ? op.getB() : op.getA();
+      if (firstOperand.getDefiningOp() &&
+          firstOperand.getDefiningOp<AtenMulIntOp>()) {
+        auto prevMulIntOp = firstOperand.getDefiningOp<AtenMulIntOp>();
+        int64_t prevLhs, prevRhs;
+        bool prevLConstant =
+            matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
+        bool prevRConstant =
+            matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
+        if (prevLConstant && prevRConstant)
+          return failure();
+        if ((prevLConstant || prevRConstant) &&
+            prevMulIntOp->hasOneUse() == 1) {
+          auto newConstant = rewriter.create<Torch::ConstantIntOp>(
+              op.getLoc(), rewriter.getI64IntegerAttr(
+                               prevLConstant ? prevLhs * firstConstant
+                                             : prevRhs * firstConstant));
+          rewriter.replaceOpWithNewOp<AtenMulIntOp>(
+              op, op.getType(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0),
+              newConstant);
+          rewriter.eraseOp(prevMulIntOp);
+          return success();
+        }
+      }
+    }
+    return failure();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenMulFloatOp
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index c2e7d205a..48d87d9c8 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2003,6 +2003,7 @@ TOSA_PASS_SET = {
     "RsubFloatModule_basic",
     "RsubFloatModule_noalpha_basic",
     "RsubInt0d_NumToTensor_Module_basic",
+    "RsubIntModule_basic",
     "ScalarTensorDefaultDtypeModule_basic",
     "ScalarTensorFloat32Module_basic",
     "ScalarTensorInt32Module_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 1d9e8cc3b..fc8f01052 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1086,13 +1086,21 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::le.int : (int, int) -> (bool)", has_folder=True)
     emit("aten::ne.int : (int, int) -> (bool)", has_folder=True)
     emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
-    emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True)
+    emit(
+        "aten::floordiv.int : (int, int) -> (int)",
+        has_folder=True,
+        has_canonicalizer=True,
+    )
     emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
     emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
     emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
     emit("aten::add.int : (int, int) -> (int)", has_folder=True)
     emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
-    emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
+    emit(
+        "aten::mul.int : (int, int) -> (int)",
+        has_folder=True,
+        has_canonicalizer=True,
+    )
     emit("aten::div.int : (int, int) -> (float)", has_folder=True)
     emit("aten::neg.int : (int) -> (int)", has_folder=True)
     emit("aten::log.int : (int) -> (float)")
diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 99c8d3cfd..4692d0490 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -267,6 +267,7 @@ PY_BUILTIN_TO_TORCH_OP = {
     "gt": torch.ops.aten.gt,
     "mod": torch.ops.aten.fmod,
     "eq": torch.ops.aten.eq,
+    "floordiv": torch.ops.aten.floordiv,
 }
 
 # torch with cuda has a __version__ that looks like "2.1.0+cu113",
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index a37371428..f13bf60cb 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1168,6 +1168,19 @@ func.func @torch.aten.mul.int() -> !torch.int {
   return %ret : !torch.int
 }
 
+// CHECK-LABEL:   func.func @torch.aten.mul.int$canonicalize(
+// CHECK-SAME:        %[[ARG:.*]]: !torch.int) -> !torch.int {
+// CHECK:           %[[CST30:.*]] = torch.constant.int 30
+// CHECK:           %[[RET:.*]] = torch.aten.mul.int %[[ARG]], %[[CST30]] : !torch.int, !torch.int -> !torch.int
+// CHECK:           return %[[RET]] : !torch.int
+func.func @torch.aten.mul.int$canonicalize(%arg0: !torch.int) -> !torch.int {
+  %cst6 = torch.constant.int 6
+  %cst5 = torch.constant.int 5
+  %1 = torch.aten.mul.int %arg0, %cst5: !torch.int, !torch.int -> !torch.int
+  %ret = torch.aten.mul.int %1, %cst6: !torch.int, !torch.int -> !torch.int
+  return %ret : !torch.int
+}
+
 // CHECK-LABEL:   func.func @torch.aten.mul.float() -> !torch.float {
 // CHECK:           %[[CST30:.*]] = torch.constant.float 3.000000e+01
 // CHECK:           return %[[CST30]] : !torch.float
@@ -1207,6 +1220,16 @@ func.func @torch.aten.floordiv.int() -> !torch.int {
   return %ret : !torch.int
 }
 
+// CHECK-LABEL:   func.func @torch.aten.floordiv.int$canonicalize(
+// CHECK-SAME:        %[[ARG:.*]]: !torch.int) -> !torch.int {
+// CHECK:           return %[[ARG]] : !torch.int
+func.func @torch.aten.floordiv.int$canonicalize(%arg0: !torch.int) -> !torch.int {
+  %cst6 = torch.constant.int 6
+  %1 = torch.aten.mul.int %arg0, %cst6: !torch.int, !torch.int -> !torch.int
+  %ret = torch.aten.floordiv.int %1, %cst6: !torch.int, !torch.int -> !torch.int
+  return %ret : !torch.int
+}
+
 // CHECK-LABEL:   func.func @torch.aten.remainder.int() -> !torch.int {
 // CHECK:           %[[CST3:.*]] = torch.constant.int 3
 // CHECK:           return %[[CST3]] : !torch.int
@@ -3122,7 +3145,6 @@ func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!torch.tensor) {
   return %1 : !torch.tensor
 }
 
-
 // -----
 
 // CHECK-LABEL:   @torch.symbolic_int$canonicalize(