Fixing implicit double to float casts. (#2476)

MSVC (and other compilers with implicit narrowing warnings) don't like
this type mismatch.
pull/2480/head snapshot-20230921.968
Ben Vanik 2023-09-20 10:48:40 -07:00 committed by GitHub
parent 023fc90072
commit b9847b1904
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 43 additions and 32 deletions

View File

@ -121,8 +121,8 @@ static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
return (doubleValue == static_cast<double>(static_cast<T>(doubleValue)));
} else {
assert(isInt);
return (intValue >= std::numeric_limits<T>::min()) &&
(intValue <= std::numeric_limits<T>::max());
return (intValue >= static_cast<int64_t>(std::numeric_limits<T>::min())) &&
(intValue <= static_cast<int64_t>(std::numeric_limits<T>::max()));
}
return true;
}
@ -145,8 +145,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
"Unable to extract the scalar constant");
if (dtype.isa<mlir::FloatType>()) {
tosaTensor = tosa::getConstTensor<float>(
rewriter, op, (isFloat ? doubleValue : intValue), dshape, dtype)
tosaTensor = tosa::getConstTensor<float>(rewriter, op,
(isFloat ? doubleValue : intValue),
dshape, dtype)
.value();
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
auto w = intType.getWidth();
@ -162,7 +163,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
"of destination type");
}
bool d = isFloat ? static_cast<bool>(doubleValue)
: static_cast<bool>(intValue);
: static_cast<bool>(intValue);
tosaTensor =
tosa::getConstTensor<bool>(rewriter, op, {d}, dshape).value();
} else if (w == 32) {
@ -616,7 +617,9 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
op, "Negative slope needs to be a scalar constant for conversion to "
"TOSA LeakyReLU operation");
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, selfTy.getElementType()).value();
auto zero =
tosa::getConstTensor<float>(rewriter, op, 0, {}, selfTy.getElementType())
.value();
auto cond = rewriter.create<tosa::GreaterEqualOp>(
op->getLoc(),
RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)),
@ -2253,11 +2256,10 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps)))
return rewriter.notifyMatchFailure(op, "eps must be a scalar constant");
auto epsilonConst =
tosa::getConstTensor<float>(rewriter, op.getOperation(),
{static_cast<float>(eps)}, {},
meanType.getElementType())
.value();
auto epsilonConst = tosa::getConstTensor<float>(rewriter, op.getOperation(),
{static_cast<float>(eps)}, {},
meanType.getElementType())
.value();
auto batchNorm =
computeBatchNorm(op, rewriter, outType, adaptor.getInput(), varianceVal,
@ -2571,7 +2573,7 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
// Constant value of ln2.
SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
auto ln2Op = tosa::getConstTensor<float>(rewriter, op, {0.69314718056},
auto ln2Op = tosa::getConstTensor<float>(rewriter, op, {0.69314718056f},
ln2Shape, selfType.getElementType())
.value();
auto rcpOp =
@ -2802,21 +2804,25 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();
auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}, dtype).value();
auto a1 =
tosa::getConstTensor<float>(rewriter, op, 0.278393f, {}, dtype).value();
auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);
auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}, dtype).value();
auto a2 =
tosa::getConstTensor<float>(rewriter, op, 0.230389f, {}, dtype).value();
auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);
auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}, dtype).value();
auto a3 =
tosa::getConstTensor<float>(rewriter, op, 0.000972f, {}, dtype).value();
auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);
auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}, dtype).value();
auto a4 =
tosa::getConstTensor<float>(rewriter, op, 0.078108f, {}, dtype).value();
auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
@ -2851,13 +2857,14 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
// rsqrt of 2
Value rsqrt2 =
tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}, dtype).value();
tosa::getConstTensor<float>(rewriter, op, 0.70710678f, {}, dtype).value();
Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
/*shift=*/0);
Value erf = approximateErfOp(rewriter, op, erfArg, dtype);
Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
Value oneHalf = tosa::getConstTensor<float>(rewriter, op, 0.5, {}, dtype).value();
Value oneHalf =
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, dtype).value();
Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf,
erfPlus1, /*shift=*/0);
@ -2891,7 +2898,8 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy);
cdf = rewriter.createOrFold<tosa::CastOp>(
op->getLoc(), cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf);
op->getLoc(),
cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf);
rewriter.replaceOpWithNewOp<tosa::MulOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf,
@ -2927,15 +2935,16 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
auto loc = op->getLoc();
const double cstAlpha0 = 1.12837916709551257390;
const double cstAlpha1 = 0.70710678118654752440;
const double oneHalf = 0.5;
const double kAlpha = cstAlpha0 * cstAlpha1;
const float cstAlpha0 = 1.12837916709551257390f;
const float cstAlpha1 = 0.70710678118654752440f;
const float oneHalf = 0.5f;
const float kAlpha = cstAlpha0 * cstAlpha1;
Value kAlphaHalf =
tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {}, selfElemTy).value();
Value kAlphaHalf = tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf,
{}, selfElemTy)
.value();
Value negOneHalf =
tosa::getConstTensor<float>(rewriter, op, -0.5, {}, selfElemTy).value();
tosa::getConstTensor<float>(rewriter, op, -0.5f, {}, selfElemTy).value();
Value inputSquared = rewriter.create<tosa::MulOp>(
loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0);
Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
@ -3006,7 +3015,8 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "Only scalar constant is supported");
}
Value replace = tosa::getConstTensor<float>(rewriter, op, 0, {}, selfElemTy).value();
Value replace =
tosa::getConstTensor<float>(rewriter, op, 0, {}, selfElemTy).value();
Type outType = getTypeConverter()->convertType(op.getType());
Value lesser = rewriter.create<tosa::GreaterOp>(
@ -3553,7 +3563,7 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
// convert None to [0,0,0]
auto indexNext = indexTensors[i + 1];
auto indexNextTorch = tensorsTorchType[i + 1];
if (indexNextTorch.getType().isa<Torch::NoneType>()){
if (indexNextTorch.getType().isa<Torch::NoneType>()) {
return rewriter.notifyMatchFailure(
op, "Multiple None index is not support for now.");
}
@ -3620,12 +3630,13 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
indicesTfConcatTensors, lastDim);
if (!indicesTf) {
return rewriter.notifyMatchFailure(
op, "Convert TorchIndex To TfIndices fail.");
return rewriter.notifyMatchFailure(op,
"Convert TorchIndex To TfIndices fail.");
}
// do the tf scatterNd algorithm with tf style indices as input, algorithm mostly take from convertGatherNdOp.
// do the tf scatterNd algorithm with tf style indices as input, algorithm
// mostly take from convertGatherNdOp.
auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
indicesTf.getResult(), fillValues);
indicesTf.getResult(), fillValues);
if (!result) {
return rewriter.notifyMatchFailure(