mirror of https://github.com/llvm/torch-mlir
[Torch] enhance naryFolderHelper to support mixed dtypes (#3559)
* so that it could support like `i64 + f64 => f64`. * also unify `aten.log`'s folder code to use `naryFolderHelper`.pull/3569/head
parent
aad1604046
commit
003b06dfa1
|
@ -1224,30 +1224,6 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
|||
// NAry folder helpers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static bool checkSameDTypes(llvm::ArrayRef<Attribute> attrs) {
|
||||
bool allFp = true;
|
||||
bool allInt = true;
|
||||
|
||||
for (auto attr : attrs) {
|
||||
if (!attr)
|
||||
return false;
|
||||
|
||||
Type attrty;
|
||||
if (auto dense = dyn_cast_or_null<ElementsAttr>(attr))
|
||||
attrty = dense.getType();
|
||||
if (auto fp = dyn_cast_or_null<mlir::FloatAttr>(attr))
|
||||
attrty = fp.getType();
|
||||
if (auto integer = dyn_cast_or_null<mlir::IntegerAttr>(attr))
|
||||
attrty = integer.getType();
|
||||
if (auto shaped = dyn_cast_or_null<ShapedType>(attrty))
|
||||
attrty = shaped.getElementType();
|
||||
allFp &= isa<mlir::FloatType>(attrty);
|
||||
allInt &= isa<mlir::IntegerType>(attrty);
|
||||
}
|
||||
|
||||
return allFp || allInt;
|
||||
}
|
||||
|
||||
static bool checkAllSplats(llvm::ArrayRef<Attribute> attrs) {
|
||||
for (auto attr : attrs) {
|
||||
if (auto dense = dyn_cast_or_null<ElementsAttr>(attr)) {
|
||||
|
@ -1263,15 +1239,38 @@ llvm::SmallVector<double> getFoldValueAtIndexFp(llvm::ArrayRef<Attribute> attrs,
|
|||
int64_t idx = 0) {
|
||||
llvm::SmallVector<double> splattrs;
|
||||
|
||||
// Note that i1 is neither signed nor unsigned.
|
||||
// But we should trait i1 as unsigned, otherwise that
|
||||
// APInt(1,1).getSExtValue() return allOnes 64-bit integer.
|
||||
// So here only distinguish signed integer.
|
||||
auto convertAPIntToDouble = [](APInt value, bool isSigned) -> double {
|
||||
if (isSigned)
|
||||
return static_cast<double>(value.getSExtValue());
|
||||
else
|
||||
return static_cast<double>(value.getZExtValue());
|
||||
};
|
||||
|
||||
for (auto attr : attrs) {
|
||||
if (auto dense = dyn_cast<ElementsAttr>(attr)) {
|
||||
if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
|
||||
if (dense.isSplat()) {
|
||||
splattrs.push_back(dense.getSplatValue<APFloat>().convertToDouble());
|
||||
} else {
|
||||
splattrs.push_back(dense.getValues<APFloat>()[idx].convertToDouble());
|
||||
}
|
||||
} else if (auto intattr = dyn_cast<FloatAttr>(attr)) {
|
||||
splattrs.push_back(intattr.getValueAsDouble());
|
||||
} else if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
|
||||
bool isSigned = cast<IntegerType>(dense.getElementType()).isSigned();
|
||||
if (dense.isSplat()) {
|
||||
splattrs.push_back(
|
||||
convertAPIntToDouble(dense.getSplatValue<APInt>(), isSigned));
|
||||
} else {
|
||||
splattrs.push_back(
|
||||
convertAPIntToDouble(dense.getValues<APInt>()[idx], isSigned));
|
||||
}
|
||||
} else if (auto fpattr = dyn_cast<FloatAttr>(attr)) {
|
||||
splattrs.push_back(fpattr.getValueAsDouble());
|
||||
} else if (auto intattr = dyn_cast<IntegerAttr>(attr)) {
|
||||
bool isSigned = cast<IntegerType>(intattr.getType()).isSigned();
|
||||
splattrs.push_back(convertAPIntToDouble(intattr.getValue(), isSigned));
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
|
@ -1286,13 +1285,9 @@ llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
|
|||
llvm::SmallVector<APInt> splattrs;
|
||||
|
||||
for (auto attr : attrs) {
|
||||
// Note that i1 is neither signed nor unsigned.
|
||||
// But we should trait i1 as unsigned, otherwise that
|
||||
// APInt(1,1).getSExtValue() return allOnes 64-bit integer.
|
||||
// So here only distinguish signed integer.
|
||||
bool isSigned = false;
|
||||
if (auto dense = dyn_cast<ElementsAttr>(attr)) {
|
||||
isSigned = dyn_cast<IntegerType>(dense.getElementType()).isSigned();
|
||||
if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
|
||||
isSigned = cast<IntegerType>(dense.getElementType()).isSigned();
|
||||
if (dense.isSplat()) {
|
||||
splattrs.push_back(dense.getSplatValue<APInt>());
|
||||
} else {
|
||||
|
@ -1305,6 +1300,10 @@ llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
|
|||
return {};
|
||||
}
|
||||
|
||||
// Note that i1 is neither signed nor unsigned.
|
||||
// But we should trait i1 as unsigned, otherwise that
|
||||
// APInt(1,1).getSExtValue() return allOnes 64-bit integer.
|
||||
// So here only distinguish signed integer.
|
||||
auto &apint = splattrs.back();
|
||||
if (apint.getBitWidth() < bitwidth) {
|
||||
if (isSigned) {
|
||||
|
@ -1324,12 +1323,14 @@ using NAryFoldIntOperator = std::function<APInt(ArrayRef<APInt>)>;
|
|||
static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
|
||||
NAryFoldFpOperator fpFolder,
|
||||
NAryFoldIntOperator intFolder) {
|
||||
constexpr int64_t maxFold = 16;
|
||||
if (!checkSameDTypes(operands))
|
||||
return nullptr;
|
||||
constexpr int64_t kMaxFold = 16;
|
||||
for (auto attr : operands) {
|
||||
if (!attr)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto resultTy = dyn_cast<ValueTensorType>(ty);
|
||||
if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes())
|
||||
if (!resultTy || !resultTy.hasDtype() || !resultTy.areAllSizesKnown())
|
||||
return nullptr;
|
||||
|
||||
auto dty = resultTy.getDtype();
|
||||
|
@ -1341,10 +1342,7 @@ static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
|
|||
return nullptr;
|
||||
|
||||
bool allSplats = checkAllSplats(operands);
|
||||
bool withinMaxFold =
|
||||
resultBTy.hasStaticShape() && resultBTy.getNumElements() <= maxFold;
|
||||
|
||||
if (!allSplats && !withinMaxFold)
|
||||
if (!(allSplats || resultBTy.getNumElements() <= kMaxFold))
|
||||
return nullptr;
|
||||
|
||||
// We do not support broadcasting in the non-splat case so validate same
|
||||
|
@ -1371,6 +1369,8 @@ static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
|
|||
llvm::SmallVector<APFloat> folded;
|
||||
for (int i = 0, s = numValues; i < s; ++i) {
|
||||
auto inputs = getFoldValueAtIndexFp(operands, i);
|
||||
if (inputs.size() != operands.size())
|
||||
return nullptr;
|
||||
double fold = fpFolder(inputs);
|
||||
|
||||
APFloat val(fold);
|
||||
|
@ -1387,6 +1387,8 @@ static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
|
|||
for (int i = 0, s = numValues; i < s; ++i) {
|
||||
auto inputs =
|
||||
getFoldValueAtIndexInt(operands, dty.getIntOrFloatBitWidth(), i);
|
||||
if (inputs.size() != operands.size())
|
||||
return nullptr;
|
||||
folded.push_back(intFolder(inputs));
|
||||
}
|
||||
return DenseElementsAttr::get(resultBTy, folded);
|
||||
|
@ -1649,13 +1651,9 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
|
|||
constexpr int64_t kMaxFold = 16;
|
||||
if (!lhs || !rhs || !resultTy)
|
||||
return nullptr;
|
||||
if (!resultTy.hasSizes() || !resultTy.hasDtype())
|
||||
if (!resultTy.areAllSizesKnown() || !resultTy.hasDtype())
|
||||
return nullptr;
|
||||
|
||||
for (auto size : resultTy.getSizes())
|
||||
if (size == Torch::kUnknownSize)
|
||||
return nullptr;
|
||||
|
||||
auto ctx = lhs.getContext();
|
||||
auto tensorETy = cast<RankedTensorType>(lhs.getType()).getElementType();
|
||||
if (lhs.isSplat()) {
|
||||
|
@ -1843,75 +1841,21 @@ OpFoldResult AtenNeScalarOp::fold(FoldAdaptor adaptor) {
|
|||
// AtenLogOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
using UnaryPromoteFpOperator = std::function<double(double)>;
|
||||
using UnaryPromoteIntOperator = std::function<double(APInt, bool)>;
|
||||
|
||||
static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand,
|
||||
ValueTensorType resultTy,
|
||||
UnaryPromoteFpOperator fpFolder,
|
||||
UnaryPromoteIntOperator intFolder) {
|
||||
constexpr int64_t kMaxFold = 16;
|
||||
if (!resultTy.hasDtype() || !resultTy.hasSizes())
|
||||
return nullptr;
|
||||
if (!isa<mlir::FloatType>(resultTy.getDtype()))
|
||||
return nullptr;
|
||||
|
||||
auto fpTy = dyn_cast<mlir::FloatType>(operand.getType().getElementType());
|
||||
auto intTy = dyn_cast<mlir::IntegerType>(operand.getType().getElementType());
|
||||
if (!fpTy && !intTy)
|
||||
return nullptr;
|
||||
|
||||
auto resultBTy = resultTy.toBuiltinTensor();
|
||||
bool splat = operand.isSplat();
|
||||
bool withinMaxFold =
|
||||
resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold;
|
||||
if (!splat && !withinMaxFold)
|
||||
return nullptr;
|
||||
|
||||
const int64_t numValues = splat ? 1 : resultBTy.getNumElements();
|
||||
|
||||
llvm::SmallVector<Attribute> operands = {operand};
|
||||
llvm::SmallVector<APFloat> folded;
|
||||
for (int i = 0, s = numValues; i < s; ++i) {
|
||||
double fold = 0.0;
|
||||
if (fpTy) {
|
||||
auto inputs = getFoldValueAtIndexFp(operands, i);
|
||||
fold = fpFolder(inputs[0]);
|
||||
}
|
||||
if (intTy) {
|
||||
auto inputs =
|
||||
getFoldValueAtIndexInt(operands, intTy.getIntOrFloatBitWidth(), i);
|
||||
fold = intFolder(inputs[0], intTy.isSigned());
|
||||
}
|
||||
|
||||
APFloat val(fold);
|
||||
bool unused;
|
||||
val.convert(
|
||||
cast<mlir::FloatType>(resultBTy.getElementType()).getFloatSemantics(),
|
||||
APFloat::rmNearestTiesToEven, &unused);
|
||||
folded.push_back(val);
|
||||
}
|
||||
return DenseElementsAttr::get(resultBTy, folded);
|
||||
}
|
||||
|
||||
OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {
|
||||
auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
|
||||
auto resultType = dyn_cast<ValueTensorType>(getType());
|
||||
if (!self || !resultType)
|
||||
return nullptr;
|
||||
|
||||
// Note that i1 is neither signed nor unsigned.
|
||||
// But we should trait i1 as unsigned, otherwise that
|
||||
// APInt(1,1).getSExtValue() return allOnes 64-bit integer.
|
||||
auto intFold = [](APInt a, bool isSigned) -> double {
|
||||
if (isSigned)
|
||||
return std::log(a.getSExtValue());
|
||||
else
|
||||
return std::log(a.getZExtValue());
|
||||
auto fpFold = [](llvm::ArrayRef<double> inputs) -> double {
|
||||
assert(inputs.size() == 1);
|
||||
return std::log(inputs[0]);
|
||||
};
|
||||
auto intFold = [](llvm::ArrayRef<APInt> inputs) -> APInt {
|
||||
assert(false && "should not reach here");
|
||||
};
|
||||
auto fpFold = [](double a) -> double { return std::log(a); };
|
||||
|
||||
return unaryPromoteFolder(self, resultType, fpFold, intFold);
|
||||
return naryFolderHelper(adaptor.getOperands(), resultType, fpFold, intFold);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1964,6 +1964,17 @@ func.func @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],
|
|||
return %2 : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.sub.Tensor$mixed_dtype() -> !torch.vtensor<[],f64> {
|
||||
// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<2.750000e+01> : tensor<f64>) : !torch.vtensor<[],f64>
|
||||
// CEHCK: return %[[CST]]
|
||||
func.func @torch.aten.sub.Tensor$mixed_dtype() -> !torch.vtensor<[],f64> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.vtensor.literal(dense<28> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||
%1 = torch.vtensor.literal(dense<5.000000e-01> : tensor<f64>) : !torch.vtensor<[],f64>
|
||||
%2 = torch.aten.sub.Tensor %0, %1, %int1 : !torch.vtensor<[],si64>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
|
||||
return %2 : !torch.vtensor<[],f64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
|
||||
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
|
||||
|
|
Loading…
Reference in New Issue