From dc7a1ff7d9134758128a637dca976f72c2366e59 Mon Sep 17 00:00:00 2001 From: yyp0 Date: Wed, 16 Oct 2024 16:00:58 +0800 Subject: [PATCH] [Torch] add fold logic for some ops (#3794) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 4 + lib/Dialect/Torch/IR/TorchOps.cpp | 134 ++++++++++++++++++ .../build_tools/torch_ods_gen.py | 8 +- .../Torch/torch-nary-canonicalize.mlir | 110 ++++++++++++++ 4 files changed, 254 insertions(+), 2 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index b1a670b6d..3ba71e4e3 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -3425,6 +3425,7 @@ def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -4902,6 +4903,7 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -12641,6 +12643,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -15334,6 +15337,7 @@ def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_AtenRemainderTensorOp : Torch_Op<"aten.remainder.Tensor", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 47e77c11f..88e909c14 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1535,6 +1535,24 @@ void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +// ===----------------------------------------------------------------------===// +// AtenRSubScalarOp +// ===----------------------------------------------------------------------===// + +OpFoldResult AtenRsubScalarOp::fold(FoldAdaptor adaptor) { + auto fpFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 3); + return inputs[1] - inputs[0] * inputs[2]; + }; + + auto intFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 3); + return inputs[1] - inputs[0] * inputs[2]; + }; + + return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold); +} + //===----------------------------------------------------------------------===// // AtenMulTensorOp //===----------------------------------------------------------------------===// @@ -1979,6 +1997,58 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns( }); } +// ===----------------------------------------------------------------------===// +// AtenDivTensorModeOp +// ===----------------------------------------------------------------------===// + +OpFoldResult AtenDivTensorModeOp::fold(FoldAdaptor adaptor) { + auto resultTy = dyn_cast_or_null(getType()); + if (!resultTy || !resultTy.hasDtype()) { + return nullptr; + } + std::function)> fpFold; + std::function)> intFold; + + auto roundMode = dyn_cast_or_null(adaptor.getRoundingMode()); + auto unsign = false; + if (isa(resultTy.getDtype())) { + unsign = cast(resultTy.getDtype()).isUnsigned(); + } + + fpFold = [roundMode](llvm::ArrayRef inputs) { + assert(inputs.size() == 2); + if (!roundMode) { + return (double)inputs[0] / inputs[1]; + } else if (roundMode.getValue().str() == "floor") { + return std::floor((double)inputs[0] / inputs[1]); + } else { + return std::trunc((double)inputs[0] / inputs[1]); + } + }; + + intFold = [unsign, roundMode](llvm::ArrayRef inputs) { + assert(inputs.size() == 2); + auto lhs = unsign ? inputs[0].getZExtValue() : inputs[0].getSExtValue(); + auto rhs = unsign ? inputs[1].getZExtValue() : inputs[1].getSExtValue(); + int64_t bits = std::max(inputs[0].getBitWidth(), inputs[1].getBitWidth()); + int64_t res; + if (roundMode.getValue().str() == "floor") { + res = std::floor(lhs / rhs); + } else { + res = std::trunc(lhs / rhs); + } + return APInt(bits, res); + }; + + if (!roundMode) { + return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(), + fpFold, std::nullopt); + } + + return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(), + fpFold, intFold); +} + //===----------------------------------------------------------------------===// // AtenDivScalarModeOp //===----------------------------------------------------------------------===// @@ -3597,6 +3667,34 @@ OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) { adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; }); } +// ===----------------------------------------------------------------------===// +// AtenRemainderScalarOp +// ===----------------------------------------------------------------------===// + +OpFoldResult AtenRemainderScalarOp::fold(FoldAdaptor adaptor) { + auto resultTy = dyn_cast_or_null(getType()); + if (!resultTy || !resultTy.hasDtype()) { + return nullptr; + } + + auto unsign = false; + if (isa(resultTy.getDtype())) { + unsign = cast(resultTy.getDtype()).isUnsigned(); + } + auto fpFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 2); + return std::fmod(inputs[0], inputs[1]); + }; + + auto intFold = [unsign](llvm::ArrayRef inputs) { + assert(inputs.size() == 2); + auto ret = unsign ? inputs[0].urem(inputs[1]) : inputs[0].srem(inputs[1]); + return ret; + }; + + return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold); +} + //===----------------------------------------------------------------------===// // AtenAddIntOp //===----------------------------------------------------------------------===// @@ -4229,6 +4327,42 @@ void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenIntTensorOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) { + auto value = adaptor.getA(); + auto dense = dyn_cast_or_null(value); + if (!dense || !dense.isSplat()) { + return nullptr; + } + + auto splat = dense.getSplatValue(); + if (auto intAttr = dyn_cast(splat)) { + auto type = getType(); + if (!isa(type)) { + return nullptr; + } + + if (type.isSignlessInteger()) { + return getI64IntegerAttr(getContext(), intAttr.getInt()); + } else if (type.isSignedInteger()) { + return getI64IntegerAttr(getContext(), intAttr.getSInt()); + } else { + return getI64IntegerAttr(getContext(), intAttr.getUInt()); + } + } + + if (auto floatAttr = dyn_cast(splat)) { + return getI64IntegerAttr( + getContext(), + static_cast(floatAttr.getValue().convertToDouble())); + } + + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenFloatTensorOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index ba56f10fb..84e4f7f15 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -379,6 +379,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): # variants. emit_with_mutating_variants( "aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", + has_folder=True, has_canonicalizer=True, ) emit_with_mutating_variants( @@ -481,6 +482,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)") emit( "aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", + has_folder=True, has_canonicalizer=True, ) emit("aten::gelu : (Tensor, str) -> (Tensor)") @@ -928,7 +930,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit( "aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True ) - emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True) + emit( + "aten::Int.Tensor : (Tensor) -> (int)", has_folder=True, has_canonicalizer=True + ) emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True) emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)") emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)") @@ -1080,7 +1084,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): has_canonicalizer=True, ) emit("aten::remainder.int : (int, int) -> (int)", has_folder=True) - emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)") + emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::add.int : (int, int) -> (int)", has_folder=True) emit("aten::sub.int : (int, int) -> (int)", has_folder=True) diff --git a/test/Dialect/Torch/torch-nary-canonicalize.mlir b/test/Dialect/Torch/torch-nary-canonicalize.mlir index b0d22e35d..9fb5bac1f 100644 --- a/test/Dialect/Torch/torch-nary-canonicalize.mlir +++ b/test/Dialect/Torch/torch-nary-canonicalize.mlir @@ -141,3 +141,113 @@ func.func @fold_aten_mul_splat_float() -> !torch.vtensor<[4],f32> { %0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32> return %0 : !torch.vtensor<[4],f32> } + +// ----- + +// CHECK-LABEL: @fold_aten_rsub_scalar_int +func.func @fold_aten_rsub_scalar_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<-4> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_2 = torch.constant.int 2 + %cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2: !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_rsub_scalar_float +func.func @fold_aten_rsub_scalar_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<-4.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_2 = torch.constant.float 2.0 + %cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2: !torch.vtensor<[4],f32>, !torch.float, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_remainder_scalar_int +func.func @fold_aten_remainder_scalar_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_2 = torch.constant.int 2 + %cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_remainder_scalar_float +func.func @fold_aten_remainder_scalar_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<1.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_2 = torch.constant.float 2.0 + %cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_int_tensor_int +func.func @fold_aten_int_tensor_int() -> !torch.int { + // CHECK: %int3 = torch.constant.int 3 + %cst_3 = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> + %0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],si64> -> !torch.int + return %0 : !torch.int +} + +// ----- + +// CHECK-LABEL: @fold_aten_int_tensor_bool +func.func @fold_aten_int_tensor_bool() -> !torch.int { + // CHECK: %int1 = torch.constant.int 1 + %cst_false = torch.vtensor.literal(dense : tensor) : !torch.vtensor<[],i1> + %0 = torch.aten.Int.Tensor %cst_false : !torch.vtensor<[],i1> -> !torch.int + return %0 : !torch.int +} + +// ----- + +// CHECK-LABEL: @fold_aten_int_tensor_float +func.func @fold_aten_int_tensor_float() -> !torch.int { + // CHECK: %int3 = torch.constant.int 3 + %cst_3 = torch.vtensor.literal(dense<3.1> : tensor) : !torch.vtensor<[],f32> + %0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],f32> -> !torch.int + return %0 : !torch.int +} + +// ----- + +// CHECK-LABEL: @fold_aten_div_tensor_mode_int +func.func @fold_aten_div_tensor_mode_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<4> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_2 = torch.vtensor.literal(dense<2> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %trunc = torch.constant.str "trunc" + %0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %trunc : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.str -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_div_tensor_mode_float +func.func @fold_aten_div_tensor_mode_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<3.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_8 = torch.vtensor.literal(dense<8.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_2 = torch.vtensor.literal(dense<2.1> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %floor = torch.constant.str "floor" + %0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %floor : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.str -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_div_tensor_mode_none +func.func @fold_aten_div_tensor_mode_none() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<2.66666675> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %none = torch.constant.none + %0 = torch.aten.div.Tensor_mode %cst_8, %cst_3, %none : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.none -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +}