From e80054a3cca385bf50760ad43a6d8e8bb799001d Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Mon, 19 Feb 2024 10:28:23 -0800
Subject: [PATCH] [torch] Folders for `torch.aten.*.tensor` operators [add,
 sub, mul] (#2878)

Simple folders for limited-size aten tensor operations. These are primarily
useful for folding shape computations, which unfortunately can be expressed
with `aten` operators. Add, sub, and mul are common examples, so folders are
provided for them.
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |   3 +
 lib/Dialect/Torch/IR/TorchOps.cpp             | 213 ++++++++++++++++++
 .../build_tools/torch_ods_gen.py              |   6 +-
 test/Dialect/Torch/canonicalize.mlir          |   5 +-
 .../Torch/torch-nary-canonicalize.mlir        | 143 ++++++++++++
 5 files changed, 364 insertions(+), 6 deletions(-)
 create mode 100644 test/Dialect/Torch/torch-nary-canonicalize.mlir

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 99d00e287..5e4662369 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3790,6 +3790,7 @@ def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
@@ -3839,6 +3840,7 @@ def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [
       printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
@@ -3889,6 +3891,7 @@ def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [
       printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index d831b7076..18c8501df 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1106,6 +1106,177 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NAry folder helpers
+//===----------------------------------------------------------------------===//
+
+static bool checkSameDTypes(llvm::ArrayRef<Attribute> attrs) {
+  bool allFp = true;
+  bool allInt = true;
+
+  for (auto attr : attrs) {
+    if (!attr)
+      return false;
+
+    Type attrty;
+    if (auto dense = dyn_cast_or_null<DenseElementsAttr>(attr))
+      attrty = dense.getType();
+    if (auto fp = dyn_cast_or_null<FloatAttr>(attr))
+      attrty = fp.getType();
+    if (auto integer = dyn_cast_or_null<IntegerAttr>(attr))
+      attrty = integer.getType();
+    if (auto shaped = dyn_cast_or_null<ShapedType>(attrty))
+      attrty = shaped.getElementType();
+    allFp &= isa<mlir::FloatType>(attrty);
+    allInt &= isa<mlir::IntegerType>(attrty);
+  }
+
+  return allFp || allInt;
+}
+
+static bool checkAllSplats(llvm::ArrayRef<Attribute> attrs) {
+  for (auto attr : attrs) {
+    if (auto dense = dyn_cast_or_null<DenseElementsAttr>(attr)) {
+      if (!dense.isSplat())
+        return false;
+    }
+  }
+
+  return true;
+}
+
+llvm::SmallVector<double> getFoldValueAtIndexFp(llvm::ArrayRef<Attribute> attrs,
+                                                int64_t idx = 0) {
+  llvm::SmallVector<double> splattrs;
+
+  for (auto attr : attrs) {
+    if (auto dense = dyn_cast<DenseElementsAttr>(attr)) {
+      if (dense.isSplat()) {
+        splattrs.push_back(dense.getSplatValue<APFloat>().convertToDouble());
+      } else {
+        splattrs.push_back(dense.getValues<APFloat>()[idx].convertToDouble());
+      }
+    } else if (auto fpattr = dyn_cast<FloatAttr>(attr)) {
+      splattrs.push_back(fpattr.getValueAsDouble());
+    } else {
+      return {};
+    }
+  }
+
+  return splattrs;
+}
+
+llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
+                                                int64_t bitwidth,
+                                                int64_t idx = 0) {
+  llvm::SmallVector<APInt> splattrs;
+
+  for (auto attr : attrs) {
+    bool isunsigned = false;
+    if (auto dense = dyn_cast<DenseElementsAttr>(attr)) {
+      isunsigned = dyn_cast<IntegerType>(dense.getElementType()).isUnsigned();
+      if (dense.isSplat()) {
+        splattrs.push_back(dense.getSplatValue<APInt>());
+      } else {
+        splattrs.push_back(dense.getValues<APInt>()[idx]);
+      }
+    } else if (auto intattr = dyn_cast<IntegerAttr>(attr)) {
+      isunsigned = cast<IntegerType>(intattr.getType()).isUnsigned();
+      splattrs.push_back(intattr.getValue());
+    } else {
+      return {};
+    }
+
+    auto &apint = splattrs.back();
+    if (apint.getBitWidth() < bitwidth) {
+      if (isunsigned) {
+        apint = apint.zextOrTrunc(bitwidth);
+      } else {
+        apint = apint.sextOrTrunc(bitwidth);
+      }
+    }
+  }
+
+  return splattrs;
+}
+
+using NAryFoldFpOperator = std::function<double(ArrayRef<double>)>;
+using NAryFoldIntOperator = std::function<APInt(ArrayRef<APInt>)>;
+
+static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
+                                     NAryFoldFpOperator fpFolder,
+                                     NAryFoldIntOperator intFolder) {
+  constexpr int64_t maxFold = 16;
+  if (!checkSameDTypes(operands))
+    return nullptr;
+
+  auto resultTy = dyn_cast<ValueTensorType>(ty);
+  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes())
+    return nullptr;
+
+  auto dty = resultTy.getDtype();
+  auto resultBTy = resultTy.toBuiltinTensor().clone(dty);
+
+  auto fpTy = dyn_cast<mlir::FloatType>(dty);
+  auto intTy = dyn_cast<mlir::IntegerType>(dty);
+  if (!fpTy && !intTy)
+    return nullptr;
+
+  bool allSplats = checkAllSplats(operands);
+  bool withinMaxFold =
+      resultBTy.hasStaticShape() && resultBTy.getNumElements() <= maxFold;
+
+  if (!allSplats && !withinMaxFold)
+    return nullptr;
+
+  // We do not support broadcasting in the non-splat case so validate same
+  // shaped inputs / outputs:
+  if (!allSplats) {
+    auto resultShape = resultBTy.getShape();
+    for (int i = 0, s = operands.size(); i < s; ++i) {
+      if (auto dense = dyn_cast<DenseElementsAttr>(operands[i])) {
+        if (dense.isSplat())
+          continue;
+        auto operandShape = cast<ShapedType>(dense.getType()).getShape();
+        if (operandShape.size() != resultShape.size())
+          return nullptr;
+        for (int i = 0, s = operandShape.size(); i < s; ++i)
+          if (operandShape[i] != resultShape[i])
+            return nullptr;
+      }
+    }
+  }
+
+  const int64_t numValues = allSplats ? 1 : resultBTy.getNumElements();
+
+  if (fpTy) {
+    llvm::SmallVector<APFloat> folded;
+    for (int i = 0, s = numValues; i < s; ++i) {
+      auto inputs = getFoldValueAtIndexFp(operands, i);
+      double fold = fpFolder(inputs);
+
+      APFloat val(fold);
+      bool unused;
+      val.convert(fpTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
+                  &unused);
+      folded.push_back(val);
+    }
+    return DenseElementsAttr::get(resultBTy, folded);
+  }
+
+  if (intTy) {
+    llvm::SmallVector<APInt> folded;
+    for (int i = 0, s = numValues; i < s; ++i) {
+      auto inputs =
+          getFoldValueAtIndexInt(operands, dty.getIntOrFloatBitWidth(), i);
+      folded.push_back(intFolder(inputs));
+    }
+    return DenseElementsAttr::get(resultBTy, folded);
+  }
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenAddTensorOp
 //===----------------------------------------------------------------------===//
@@ -1116,6 +1287,20 @@ void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+OpFoldResult AtenAddTensorOp::fold(FoldAdaptor adaptor) {
+  auto fpFold = [](llvm::ArrayRef<double> inputs) {
+    assert(inputs.size() == 3);
+    return inputs[0] + (inputs[1] * inputs[2]);
+  };
+
+  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
+    assert(inputs.size() == 3);
+    return inputs[0] + (inputs[1] * inputs[2]);
+  };
+
+  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenAddScalarOp
 //===----------------------------------------------------------------------===//
@@ -1136,6 +1321,20 @@ void AtenSubTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+OpFoldResult AtenSubTensorOp::fold(FoldAdaptor adaptor) {
+  auto fpFold = [](llvm::ArrayRef<double> inputs) {
+    assert(inputs.size() == 3);
+    return inputs[0] - (inputs[1] * inputs[2]);
+  };
+
+  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
+    assert(inputs.size() == 3);
+    return inputs[0] - (inputs[1] * inputs[2]);
+  };
+
+  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSubScalarOp
 //===----------------------------------------------------------------------===//
@@ -1166,6 +1365,20 @@ void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+OpFoldResult AtenMulTensorOp::fold(FoldAdaptor adaptor) {
+  auto fpFold = [](llvm::ArrayRef<double> inputs) {
+    assert(inputs.size() == 2);
+    return inputs[0] * inputs[1];
+  };
+
+  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
+    assert(inputs.size() == 2);
+    return inputs[0] * inputs[1];
+  };
+
+  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenEqTensorOp
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 1dc8585d7..bfbebf86b 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -340,9 +340,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     # Elementwise tensor compute ops that don't have the standard mutating
     # variants.
     emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
-    emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
-    emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
-    emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
+    emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True)
+    emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
+    emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
     emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
     emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
     emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 03eeaaeb5..bb5713507 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1916,9 +1916,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor
 }
 
 // CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT6]] = torch.constant.int 6
-// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR0]] : !torch.vtensor<[],si64>
+// CHECK: %[[INT6:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[INT6]] : !torch.vtensor<[],si64>
 func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
   %0 = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
   %1 = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
diff --git a/test/Dialect/Torch/torch-nary-canonicalize.mlir b/test/Dialect/Torch/torch-nary-canonicalize.mlir
new file mode 100644
index 000000000..b0d22e35d
--- /dev/null
+++ b/test/Dialect/Torch/torch-nary-canonicalize.mlir
@@ -0,0 +1,143 @@
+// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s
+
+// CHECK-LABEL: @fold_aten_add_splat_int
+func.func @fold_aten_add_splat_int() -> !torch.vtensor<[4],si64> {
+  // CHECK: torch.vtensor.literal(dense<29> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
+  return %0 : !torch.vtensor<[4],si64>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_add_splat_int_mismatch
+func.func @fold_aten_add_splat_int_mismatch() -> !torch.vtensor<[4],si64> {
+  // CHECK: torch.vtensor.literal(dense<29> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi32>) : !torch.vtensor<[4],si32>
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si32>, !torch.int -> !torch.vtensor<[4],si64>
+  return %0 : !torch.vtensor<[4],si64>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_add_splat_float
+func.func @fold_aten_add_splat_float() -> !torch.vtensor<[4],f32> {
+  // CHECK: torch.vtensor.literal(dense<2.900000e+01> : tensor<4xf32>)
+  %int2 = torch.constant.float 2.0
+  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
+  return %0 : !torch.vtensor<[4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_add_splat_float_mismatch
+func.func @fold_aten_add_splat_float_mismatch() -> !torch.vtensor<[4],f32> {
+  // CHECK: torch.vtensor.literal(dense<2.900000e+01> : tensor<4xf32>)
+  %int2 = torch.constant.float 2.0
+  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf64>) : !torch.vtensor<[4],f64>
+  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f64>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
+  return %0 : !torch.vtensor<[4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_add_arr0_int
+func.func @fold_aten_add_arr0_int() -> !torch.vtensor<[4],si64> {
+  // CHECK: torch.vtensor.literal(dense<[28, 29, 30, 31]> : tensor<4xsi64>)
+  %cst_7 = torch.vtensor.literal(dense<[6,7,8,9]> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
+  return %0 : !torch.vtensor<[4],si64>
+}
+
+
+// -----
+
+// CHECK-LABEL: @fold_aten_add_arr1_int
+func.func @fold_aten_add_arr1_int() -> !torch.vtensor<[4],si64> {
+  // CHECK: torch.vtensor.literal(dense<[27, 29, 31, 33]> : tensor<4xsi64>)
+  %int2 = torch.constant.int 2
+  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_11 = torch.vtensor.literal(dense<[10,11,12,13]> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
+  return %0 : !torch.vtensor<[4],si64>
+}
+
+
+// -----
+
+// CHECK-LABEL: @fold_aten_add_arr0_float
+func.func @fold_aten_add_arr0_float() -> !torch.vtensor<[4],f32> {
+  // CHECK: torch.vtensor.literal(dense<[2.800000e+01, 2.900000e+01, 3.000000e+01, 3.100000e+01]> : tensor<4xf32>)
+  %int2 = torch.constant.float 2.0
+  %cst_7 = torch.vtensor.literal(dense<[6.0, 7.0, 8.0, 9.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
+  return %0 : !torch.vtensor<[4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_add_arr1_float
+func.func @fold_aten_add_arr1_float() -> !torch.vtensor<[4],f32> {
+  // CHECK: torch.vtensor.literal(dense<[2.700000e+01, 2.900000e+01, 3.100000e+01, 3.300000e+01]> : tensor<4xf32>)
+  %fp_2 = torch.constant.float 2.0
+  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %cst_11 = torch.vtensor.literal(dense<[10.0,11.0,12.0,13.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %fp_2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
+  return %0 : !torch.vtensor<[4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_sub_splat_int
+func.func @fold_aten_sub_splat_int() -> !torch.vtensor<[4],si64> {
+  // CHECK: torch.vtensor.literal(dense<-15> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %int_2 = torch.constant.int 2
+  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %0 = torch.aten.sub.Tensor %cst_7, %cst_11, %int_2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
+  return %0 : !torch.vtensor<[4],si64>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_sub_splat_float
+func.func @fold_aten_sub_splat_float() -> !torch.vtensor<[4],f32> {
+  // CHECK: torch.vtensor.literal(dense<-1.500000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %fp_2 = torch.constant.float 2.0
+  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %0 = torch.aten.sub.Tensor %cst_7, %cst_11, %fp_2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
+  return %0 : !torch.vtensor<[4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_mul_splat_int
+func.func @fold_aten_mul_splat_int() -> !torch.vtensor<[4],si64> {
+  // CHECK: torch.vtensor.literal(dense<77> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64>
+  return %0 : !torch.vtensor<[4],si64>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_aten_mul_splat_float
+func.func @fold_aten_mul_splat_float() -> !torch.vtensor<[4],f32> {
+  // CHECK: torch.vtensor.literal(dense<7.700000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32>
+  return %0 : !torch.vtensor<[4],f32>
+}