[Torch] enhance naryFolderHelper to support mixed dtypes (#3559)

* so that it can support mixed-dtype folds like `i64 + f64 => f64`.
* also unify `aten.log`'s folder code to use `naryFolderHelper`.
pull/3569/head
Yuanqiang Liu 2024-07-24 17:54:59 +08:00 committed by GitHub
parent aad1604046
commit 003b06dfa1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 107 deletions

View File

@ -1224,30 +1224,6 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
// NAry folder helpers
//===----------------------------------------------------------------------===//
// Returns true when every attribute in `attrs` is non-null and all of them
// share the same dtype *category* — i.e. all have float element types, or all
// have integer element types. Mixed float/int (or any unrecognized attribute
// kind) yields false, which makes the caller skip folding.
static bool checkSameDTypes(llvm::ArrayRef<Attribute> attrs) {
  bool allFp = true;
  bool allInt = true;
  for (auto attr : attrs) {
    if (!attr)
      return false;
    // Extract the type carried by the attribute; `attr` is known non-null
    // here, so plain dyn_cast suffices. The three kinds are mutually
    // exclusive, hence the else-if chain.
    Type attrty;
    if (auto dense = dyn_cast<ElementsAttr>(attr))
      attrty = dense.getType();
    else if (auto fp = dyn_cast<mlir::FloatAttr>(attr))
      attrty = fp.getType();
    else if (auto integer = dyn_cast<mlir::IntegerAttr>(attr))
      attrty = integer.getType();
    // Tensor-valued attributes report a ShapedType; drill into the element
    // type. `attrty` may still be null, so keep the _or_null variant here.
    if (auto shaped = dyn_cast_or_null<ShapedType>(attrty))
      attrty = shaped.getElementType();
    // An unrecognized attribute kind leaves `attrty` null; isa<> asserts on
    // a null Type, so bail out explicitly instead of crashing.
    if (!attrty)
      return false;
    allFp &= isa<mlir::FloatType>(attrty);
    allInt &= isa<mlir::IntegerType>(attrty);
  }
  return allFp || allInt;
}
static bool checkAllSplats(llvm::ArrayRef<Attribute> attrs) {
for (auto attr : attrs) {
if (auto dense = dyn_cast_or_null<ElementsAttr>(attr)) {
@ -1263,15 +1239,38 @@ llvm::SmallVector<double> getFoldValueAtIndexFp(llvm::ArrayRef<Attribute> attrs,
int64_t idx = 0) {
llvm::SmallVector<double> splattrs;
// Note that i1 is neither signed nor unsigned.
// But we should treat i1 as unsigned; otherwise
// APInt(1,1).getSExtValue() returns an all-ones 64-bit integer.
// So here we only single out signed integers.
auto convertAPIntToDouble = [](APInt value, bool isSigned) -> double {
if (isSigned)
return static_cast<double>(value.getSExtValue());
else
return static_cast<double>(value.getZExtValue());
};
for (auto attr : attrs) {
if (auto dense = dyn_cast<ElementsAttr>(attr)) {
if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
if (dense.isSplat()) {
splattrs.push_back(dense.getSplatValue<APFloat>().convertToDouble());
} else {
splattrs.push_back(dense.getValues<APFloat>()[idx].convertToDouble());
}
} else if (auto intattr = dyn_cast<FloatAttr>(attr)) {
splattrs.push_back(intattr.getValueAsDouble());
} else if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
bool isSigned = cast<IntegerType>(dense.getElementType()).isSigned();
if (dense.isSplat()) {
splattrs.push_back(
convertAPIntToDouble(dense.getSplatValue<APInt>(), isSigned));
} else {
splattrs.push_back(
convertAPIntToDouble(dense.getValues<APInt>()[idx], isSigned));
}
} else if (auto fpattr = dyn_cast<FloatAttr>(attr)) {
splattrs.push_back(fpattr.getValueAsDouble());
} else if (auto intattr = dyn_cast<IntegerAttr>(attr)) {
bool isSigned = cast<IntegerType>(intattr.getType()).isSigned();
splattrs.push_back(convertAPIntToDouble(intattr.getValue(), isSigned));
} else {
return {};
}
@ -1286,13 +1285,9 @@ llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
llvm::SmallVector<APInt> splattrs;
for (auto attr : attrs) {
// Note that i1 is neither signed nor unsigned.
// But we should treat i1 as unsigned; otherwise
// APInt(1,1).getSExtValue() returns an all-ones 64-bit integer.
// So here we only single out signed integers.
bool isSigned = false;
if (auto dense = dyn_cast<ElementsAttr>(attr)) {
isSigned = dyn_cast<IntegerType>(dense.getElementType()).isSigned();
if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
isSigned = cast<IntegerType>(dense.getElementType()).isSigned();
if (dense.isSplat()) {
splattrs.push_back(dense.getSplatValue<APInt>());
} else {
@ -1305,6 +1300,10 @@ llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
return {};
}
// Note that i1 is neither signed nor unsigned.
// But we should treat i1 as unsigned; otherwise
// APInt(1,1).getSExtValue() returns an all-ones 64-bit integer.
// So here we only single out signed integers.
auto &apint = splattrs.back();
if (apint.getBitWidth() < bitwidth) {
if (isSigned) {
@ -1324,12 +1323,14 @@ using NAryFoldIntOperator = std::function<APInt(ArrayRef<APInt>)>;
static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
NAryFoldFpOperator fpFolder,
NAryFoldIntOperator intFolder) {
constexpr int64_t maxFold = 16;
if (!checkSameDTypes(operands))
return nullptr;
constexpr int64_t kMaxFold = 16;
for (auto attr : operands) {
if (!attr)
return nullptr;
}
auto resultTy = dyn_cast<ValueTensorType>(ty);
if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes())
if (!resultTy || !resultTy.hasDtype() || !resultTy.areAllSizesKnown())
return nullptr;
auto dty = resultTy.getDtype();
@ -1341,10 +1342,7 @@ static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
return nullptr;
bool allSplats = checkAllSplats(operands);
bool withinMaxFold =
resultBTy.hasStaticShape() && resultBTy.getNumElements() <= maxFold;
if (!allSplats && !withinMaxFold)
if (!(allSplats || resultBTy.getNumElements() <= kMaxFold))
return nullptr;
// We do not support broadcasting in the non-splat case so validate same
@ -1371,6 +1369,8 @@ static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
llvm::SmallVector<APFloat> folded;
for (int i = 0, s = numValues; i < s; ++i) {
auto inputs = getFoldValueAtIndexFp(operands, i);
if (inputs.size() != operands.size())
return nullptr;
double fold = fpFolder(inputs);
APFloat val(fold);
@ -1387,6 +1387,8 @@ static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
for (int i = 0, s = numValues; i < s; ++i) {
auto inputs =
getFoldValueAtIndexInt(operands, dty.getIntOrFloatBitWidth(), i);
if (inputs.size() != operands.size())
return nullptr;
folded.push_back(intFolder(inputs));
}
return DenseElementsAttr::get(resultBTy, folded);
@ -1649,13 +1651,9 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
constexpr int64_t kMaxFold = 16;
if (!lhs || !rhs || !resultTy)
return nullptr;
if (!resultTy.hasSizes() || !resultTy.hasDtype())
if (!resultTy.areAllSizesKnown() || !resultTy.hasDtype())
return nullptr;
for (auto size : resultTy.getSizes())
if (size == Torch::kUnknownSize)
return nullptr;
auto ctx = lhs.getContext();
auto tensorETy = cast<RankedTensorType>(lhs.getType()).getElementType();
if (lhs.isSplat()) {
@ -1843,75 +1841,21 @@ OpFoldResult AtenNeScalarOp::fold(FoldAdaptor adaptor) {
// AtenLogOp
//===----------------------------------------------------------------------===//
using UnaryPromoteFpOperator = std::function<double(double)>;
using UnaryPromoteIntOperator = std::function<double(APInt, bool)>;
static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand,
ValueTensorType resultTy,
UnaryPromoteFpOperator fpFolder,
UnaryPromoteIntOperator intFolder) {
constexpr int64_t kMaxFold = 16;
if (!resultTy.hasDtype() || !resultTy.hasSizes())
return nullptr;
if (!isa<mlir::FloatType>(resultTy.getDtype()))
return nullptr;
auto fpTy = dyn_cast<mlir::FloatType>(operand.getType().getElementType());
auto intTy = dyn_cast<mlir::IntegerType>(operand.getType().getElementType());
if (!fpTy && !intTy)
return nullptr;
auto resultBTy = resultTy.toBuiltinTensor();
bool splat = operand.isSplat();
bool withinMaxFold =
resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold;
if (!splat && !withinMaxFold)
return nullptr;
const int64_t numValues = splat ? 1 : resultBTy.getNumElements();
llvm::SmallVector<Attribute> operands = {operand};
llvm::SmallVector<APFloat> folded;
for (int i = 0, s = numValues; i < s; ++i) {
double fold = 0.0;
if (fpTy) {
auto inputs = getFoldValueAtIndexFp(operands, i);
fold = fpFolder(inputs[0]);
}
if (intTy) {
auto inputs =
getFoldValueAtIndexInt(operands, intTy.getIntOrFloatBitWidth(), i);
fold = intFolder(inputs[0], intTy.isSigned());
}
APFloat val(fold);
bool unused;
val.convert(
cast<mlir::FloatType>(resultBTy.getElementType()).getFloatSemantics(),
APFloat::rmNearestTiesToEven, &unused);
folded.push_back(val);
}
return DenseElementsAttr::get(resultBTy, folded);
}
OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {
auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
auto resultType = dyn_cast<ValueTensorType>(getType());
if (!self || !resultType)
return nullptr;
// Note that i1 is neither signed nor unsigned.
// But we should treat i1 as unsigned; otherwise
// APInt(1,1).getSExtValue() returns an all-ones 64-bit integer.
auto intFold = [](APInt a, bool isSigned) -> double {
if (isSigned)
return std::log(a.getSExtValue());
else
return std::log(a.getZExtValue());
auto fpFold = [](llvm::ArrayRef<double> inputs) -> double {
assert(inputs.size() == 1);
return std::log(inputs[0]);
};
auto intFold = [](llvm::ArrayRef<APInt> inputs) -> APInt {
assert(false && "should not reach here");
};
auto fpFold = [](double a) -> double { return std::log(a); };
return unaryPromoteFolder(self, resultType, fpFold, intFold);
return naryFolderHelper(adaptor.getOperands(), resultType, fpFold, intFold);
}
//===----------------------------------------------------------------------===//

View File

@ -1964,6 +1964,17 @@ func.func @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],
return %2 : !torch.vtensor<[],si64>
}
// CHECK-LABEL: func.func @torch.aten.sub.Tensor$mixed_dtype() -> !torch.vtensor<[],f64> {
// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<2.750000e+01> : tensor<f64>) : !torch.vtensor<[],f64>
// CHECK: return %[[CST]]
// Verifies mixed-dtype folding: si64 - f64 folds to an f64 constant
// (28 - 0.5 * 1 = 27.5). Note the second directive was previously spelled
// "CEHCK", which FileCheck silently ignores, so the return value was never
// actually checked.
func.func @torch.aten.sub.Tensor$mixed_dtype() -> !torch.vtensor<[],f64> {
  %int1 = torch.constant.int 1
  %0 = torch.vtensor.literal(dense<28> : tensor<si64>) : !torch.vtensor<[],si64>
  %1 = torch.vtensor.literal(dense<5.000000e-01> : tensor<f64>) : !torch.vtensor<[],f64>
  %2 = torch.aten.sub.Tensor %0, %1, %int1 : !torch.vtensor<[],si64>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
  return %2 : !torch.vtensor<[],f64>
}
// CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>