[torch] Folders for `torch.aten.*.Tensor` operators [add, sub, mul] (#2878)

Adds simple folders for limited-size `aten` tensor operations. This is
primarily useful for folding shape computations, which unfortunately can be
expressed with `aten` operators; `add`, `sub`, and `mul` are common examples.
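
As a sketch of the effect (this mirrors the new tests added below; the SSA
value names are illustrative), a constant `aten.add.Tensor` over splat
literals now folds to a single literal:

%lhs = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%rhs = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%alpha = torch.constant.int 2
%0 = torch.aten.add.Tensor %lhs, %rhs, %alpha : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>

// After canonicalization %0 becomes 7 + 2 * 11 = 29 in every element:
%0 = torch.vtensor.literal(dense<29> : tensor<4xsi64>) : !torch.vtensor<[4],si64>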
Rob Suderman 2024-02-19 10:28:23 -08:00 committed by GitHub
parent cea51897a5
commit e80054a3cc
5 changed files with 364 additions and 6 deletions


@@ -3790,6 +3790,7 @@ def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}
@@ -3839,6 +3840,7 @@ def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}
@@ -3889,6 +3891,7 @@ def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}


@@ -1106,6 +1106,177 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
  return success();
}

//===----------------------------------------------------------------------===//
// NAry folder helpers
//===----------------------------------------------------------------------===//
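// Folding requires matching numeric categories: returns true only when every
// attribute is present and all operand element types are uniformly
// floating-point or uniformly integer (mixed widths are fine).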
static bool checkSameDTypes(llvm::ArrayRef<Attribute> attrs) {
  bool allFp = true;
  bool allInt = true;
  for (auto attr : attrs) {
    if (!attr)
      return false;

    Type attrty;
    if (auto dense = dyn_cast_or_null<ElementsAttr>(attr))
      attrty = dense.getType();
    if (auto fp = dyn_cast_or_null<mlir::FloatAttr>(attr))
      attrty = fp.getType();
    if (auto integer = dyn_cast_or_null<mlir::IntegerAttr>(attr))
      attrty = integer.getType();
    if (auto shaped = dyn_cast_or_null<ShapedType>(attrty))
      attrty = shaped.getElementType();
    allFp &= isa<mlir::FloatType>(attrty);
    allInt &= isa<mlir::IntegerType>(attrty);
  }

  return allFp || allInt;
}
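
// Returns true when every elements attribute is a splat; scalar attributes
// are trivially splat, so only non-splat dense tensors disqualify.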
static bool checkAllSplats(llvm::ArrayRef<Attribute> attrs) {
  for (auto attr : attrs) {
    if (auto dense = dyn_cast_or_null<ElementsAttr>(attr)) {
      if (!dense.isSplat())
        return false;
    }
  }

  return true;
}
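
// Extracts the value at `idx` from each attribute as a double, reading splats
// and scalar floats position-independently. Returns an empty vector when any
// attribute is not a floating-point constant.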
llvm::SmallVector<double> getFoldValueAtIndexFp(llvm::ArrayRef<Attribute> attrs,
                                                int64_t idx = 0) {
  llvm::SmallVector<double> splattrs;

  for (auto attr : attrs) {
    if (auto dense = dyn_cast<ElementsAttr>(attr)) {
      if (dense.isSplat()) {
        splattrs.push_back(dense.getSplatValue<APFloat>().convertToDouble());
      } else {
        splattrs.push_back(dense.getValues<APFloat>()[idx].convertToDouble());
      }
    } else if (auto fpattr = dyn_cast<FloatAttr>(attr)) {
      splattrs.push_back(fpattr.getValueAsDouble());
    } else {
      return {};
    }
  }

  return splattrs;
}
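
// Integer counterpart of the helper above: extracts the value at `idx` from
// each attribute as an APInt and widens narrower values to `bitwidth`,
// zero-extending unsigned element types and sign-extending the rest.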
llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
                                                int64_t bitwidth,
                                                int64_t idx = 0) {
  llvm::SmallVector<APInt> splattrs;

  for (auto attr : attrs) {
    bool isunsigned = false;
    if (auto dense = dyn_cast<ElementsAttr>(attr)) {
      isunsigned = dyn_cast<IntegerType>(dense.getElementType()).isUnsigned();
      if (dense.isSplat()) {
        splattrs.push_back(dense.getSplatValue<APInt>());
      } else {
        splattrs.push_back(dense.getValues<APInt>()[idx]);
      }
    } else if (auto intattr = dyn_cast<IntegerAttr>(attr)) {
      isunsigned = cast<IntegerType>(intattr.getType()).isUnsigned();
      splattrs.push_back(intattr.getValue());
    } else {
      return {};
    }

    auto &apint = splattrs.back();
    if (apint.getBitWidth() < bitwidth) {
      if (isunsigned) {
        apint = apint.zextOrTrunc(bitwidth);
      } else {
        apint = apint.sextOrTrunc(bitwidth);
      }
    }
  }

  return splattrs;
}
using NAryFoldFpOperator = std::function<double(ArrayRef<double>)>;
using NAryFoldIntOperator = std::function<APInt(ArrayRef<APInt>)>;
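
// Shared constant-folding driver: folds only when all operands are constants
// of matching numeric category and either every operand is a splat or the
// result is statically shaped with at most `maxFold` elements; applies
// `fpFolder` / `intFolder` elementwise to build the result literal.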
static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
                                     NAryFoldFpOperator fpFolder,
                                     NAryFoldIntOperator intFolder) {
  constexpr int64_t maxFold = 16;
  if (!checkSameDTypes(operands))
    return nullptr;

  auto resultTy = dyn_cast<ValueTensorType>(ty);
  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes())
    return nullptr;

  auto dty = resultTy.getDtype();
  auto resultBTy = resultTy.toBuiltinTensor().clone(dty);

  auto fpTy = dyn_cast<mlir::FloatType>(dty);
  auto intTy = dyn_cast<mlir::IntegerType>(dty);
  if (!fpTy && !intTy)
    return nullptr;

  bool allSplats = checkAllSplats(operands);
  bool withinMaxFold =
      resultBTy.hasStaticShape() && resultBTy.getNumElements() <= maxFold;

  if (!allSplats && !withinMaxFold)
    return nullptr;

  // We do not support broadcasting in the non-splat case so validate same
  // shaped inputs / outputs:
  if (!allSplats) {
    auto resultShape = resultBTy.getShape();
    for (int i = 0, s = operands.size(); i < s; ++i) {
      if (auto dense = dyn_cast<DenseElementsAttr>(operands[i])) {
        if (dense.isSplat())
          continue;

        auto operandShape = cast<ShapedType>(dense.getType()).getShape();
        if (operandShape.size() != resultShape.size())
          return nullptr;

        for (int i = 0, s = operandShape.size(); i < s; ++i)
          if (operandShape[i] != resultShape[i])
            return nullptr;
      }
    }
  }

  const int64_t numValues = allSplats ? 1 : resultBTy.getNumElements();

  if (fpTy) {
    llvm::SmallVector<APFloat> folded;
    for (int i = 0, s = numValues; i < s; ++i) {
      auto inputs = getFoldValueAtIndexFp(operands, i);
      double fold = fpFolder(inputs);

      APFloat val(fold);
      bool unused;
      val.convert(fpTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
                  &unused);
      folded.push_back(val);
    }
    return DenseElementsAttr::get(resultBTy, folded);
  }

  if (intTy) {
    llvm::SmallVector<APInt> folded;
    for (int i = 0, s = numValues; i < s; ++i) {
      auto inputs =
          getFoldValueAtIndexInt(operands, dty.getIntOrFloatBitWidth(), i);
      folded.push_back(intFolder(inputs));
    }
    return DenseElementsAttr::get(resultBTy, folded);
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenAddTensorOp
//===----------------------------------------------------------------------===//
@@ -1116,6 +1287,20 @@ void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  });
}
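
// aten.add.Tensor(self, other, alpha) computes self + alpha * other; the fold
// adaptor supplies the constant operands in that order.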
OpFoldResult AtenAddTensorOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] + (inputs[1] * inputs[2]);
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] + (inputs[1] * inputs[2]);
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenAddScalarOp
//===----------------------------------------------------------------------===//
@@ -1136,6 +1321,20 @@ void AtenSubTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  });
}
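
// aten.sub.Tensor(self, other, alpha) computes self - alpha * other, so the
// fold combines the constant operands in that order.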
OpFoldResult AtenSubTensorOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] - (inputs[1] * inputs[2]);
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 3);
    return inputs[0] - (inputs[1] * inputs[2]);
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenSubScalarOp
//===----------------------------------------------------------------------===//
@@ -1166,6 +1365,20 @@ void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  });
}
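
// aten.mul.Tensor is a plain elementwise product of its two operands.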
OpFoldResult AtenMulTensorOp::fold(FoldAdaptor adaptor) {
  auto fpFold = [](llvm::ArrayRef<double> inputs) {
    assert(inputs.size() == 2);
    return inputs[0] * inputs[1];
  };

  auto intFold = [](llvm::ArrayRef<APInt> inputs) {
    assert(inputs.size() == 2);
    return inputs[0] * inputs[1];
  };

  return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenEqTensorOp
//===----------------------------------------------------------------------===//


@@ -340,9 +340,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    # Elementwise tensor compute ops that don't have the standard mutating
    # variants.
    emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
-   emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
-   emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
-   emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
+   emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True)
+   emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
+   emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True)
    emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
    emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
    emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)


@@ -1916,9 +1916,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor
}

// CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
-// CHECK: %[[INT6]] = torch.constant.int 6
-// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
-// CHECK: return %[[PR0]] : !torch.vtensor<[],si64>
+// CHECK: %[[INT6:.+]] = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: return %[[INT6]] : !torch.vtensor<[],si64>
func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
  %0 = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64>
  %1 = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>


@@ -0,0 +1,143 @@
// RUN: torch-mlir-opt %s --split-input-file -canonicalize | FileCheck %s

// CHECK-LABEL: @fold_aten_add_splat_int
func.func @fold_aten_add_splat_int() -> !torch.vtensor<[4],si64> {
  // CHECK: torch.vtensor.literal(dense<29> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %int2 = torch.constant.int 2
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
  return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_add_splat_int_mismatch
func.func @fold_aten_add_splat_int_mismatch() -> !torch.vtensor<[4],si64> {
  // CHECK: torch.vtensor.literal(dense<29> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi32>) : !torch.vtensor<[4],si32>
  %int2 = torch.constant.int 2
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si32>, !torch.int -> !torch.vtensor<[4],si64>
  return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_add_splat_float
func.func @fold_aten_add_splat_float() -> !torch.vtensor<[4],f32> {
  // CHECK: torch.vtensor.literal(dense<2.900000e+01> : tensor<4xf32>)
  %int2 = torch.constant.float 2.0
  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_add_splat_float_mismatch
func.func @fold_aten_add_splat_float_mismatch() -> !torch.vtensor<[4],f32> {
  // CHECK: torch.vtensor.literal(dense<2.900000e+01> : tensor<4xf32>)
  %int2 = torch.constant.float 2.0
  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf64>) : !torch.vtensor<[4],f64>
  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f64>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}
// -----

// CHECK-LABEL: @fold_aten_add_arr0_int
func.func @fold_aten_add_arr0_int() -> !torch.vtensor<[4],si64> {
  // CHECK: torch.vtensor.literal(dense<[28, 29, 30, 31]> : tensor<4xsi64>)
  %cst_7 = torch.vtensor.literal(dense<[6,7,8,9]> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %int2 = torch.constant.int 2
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
  return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_add_arr1_int
func.func @fold_aten_add_arr1_int() -> !torch.vtensor<[4],si64> {
  // CHECK: torch.vtensor.literal(dense<[27, 29, 31, 33]> : tensor<4xsi64>)
  %int2 = torch.constant.int 2
  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_11 = torch.vtensor.literal(dense<[10,11,12,13]> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
  return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_add_arr0_float
func.func @fold_aten_add_arr0_float() -> !torch.vtensor<[4],f32> {
  // CHECK: torch.vtensor.literal(dense<[2.800000e+01, 2.900000e+01, 3.000000e+01, 3.100000e+01]> : tensor<4xf32>)
  %int2 = torch.constant.float 2.0
  %cst_7 = torch.vtensor.literal(dense<[6.0, 7.0, 8.0, 9.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: @fold_aten_add_arr1_float
func.func @fold_aten_add_arr1_float() -> !torch.vtensor<[4],f32> {
  // CHECK: torch.vtensor.literal(dense<[2.700000e+01, 2.900000e+01, 3.100000e+01, 3.300000e+01]> : tensor<4xf32>)
  %fp_2 = torch.constant.float 2.0
  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %cst_11 = torch.vtensor.literal(dense<[10.0,11.0,12.0,13.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %0 = torch.aten.add.Tensor %cst_7, %cst_11, %fp_2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}
// -----

// CHECK-LABEL: @fold_aten_sub_splat_int
func.func @fold_aten_sub_splat_int() -> !torch.vtensor<[4],si64> {
  // CHECK: torch.vtensor.literal(dense<-15> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %int_2 = torch.constant.int 2
  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %0 = torch.aten.sub.Tensor %cst_7, %cst_11, %int_2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
  return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_sub_splat_float
func.func @fold_aten_sub_splat_float() -> !torch.vtensor<[4],f32> {
  // CHECK: torch.vtensor.literal(dense<-1.500000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %fp_2 = torch.constant.float 2.0
  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %0 = torch.aten.sub.Tensor %cst_7, %cst_11, %fp_2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}
// -----

// CHECK-LABEL: @fold_aten_mul_splat_int
func.func @fold_aten_mul_splat_int() -> !torch.vtensor<[4],si64> {
  // CHECK: torch.vtensor.literal(dense<77> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64>
  return %0 : !torch.vtensor<[4],si64>
}

// -----

// CHECK-LABEL: @fold_aten_mul_splat_float
func.func @fold_aten_mul_splat_float() -> !torch.vtensor<[4],f32> {
  // CHECK: torch.vtensor.literal(dense<7.700000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}