Fixing implicit double to float casts. (#2476)

MSVC (and other compilers with implicit narrowing warnings) don't like
this type mismatch.
pull/2480/head snapshot-20230921.968
Ben Vanik 2023-09-20 10:48:40 -07:00 committed by GitHub
parent 023fc90072
commit b9847b1904
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 43 additions and 32 deletions

View File

@ -121,8 +121,8 @@ static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
return (doubleValue == static_cast<double>(static_cast<T>(doubleValue))); return (doubleValue == static_cast<double>(static_cast<T>(doubleValue)));
} else { } else {
assert(isInt); assert(isInt);
return (intValue >= std::numeric_limits<T>::min()) && return (intValue >= static_cast<int64_t>(std::numeric_limits<T>::min())) &&
(intValue <= std::numeric_limits<T>::max()); (intValue <= static_cast<int64_t>(std::numeric_limits<T>::max()));
} }
return true; return true;
} }
@ -145,8 +145,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
"Unable to extract the scalar constant"); "Unable to extract the scalar constant");
if (dtype.isa<mlir::FloatType>()) { if (dtype.isa<mlir::FloatType>()) {
tosaTensor = tosa::getConstTensor<float>( tosaTensor = tosa::getConstTensor<float>(rewriter, op,
rewriter, op, (isFloat ? doubleValue : intValue), dshape, dtype) (isFloat ? doubleValue : intValue),
dshape, dtype)
.value(); .value();
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) { } else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
auto w = intType.getWidth(); auto w = intType.getWidth();
@ -616,7 +617,9 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
op, "Negative slope needs to be a scalar constant for conversion to " op, "Negative slope needs to be a scalar constant for conversion to "
"TOSA LeakyReLU operation"); "TOSA LeakyReLU operation");
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, selfTy.getElementType()).value(); auto zero =
tosa::getConstTensor<float>(rewriter, op, 0, {}, selfTy.getElementType())
.value();
auto cond = rewriter.create<tosa::GreaterEqualOp>( auto cond = rewriter.create<tosa::GreaterEqualOp>(
op->getLoc(), op->getLoc(),
RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)), RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)),
@ -2253,8 +2256,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps)))
return rewriter.notifyMatchFailure(op, "eps must be a scalar constant"); return rewriter.notifyMatchFailure(op, "eps must be a scalar constant");
auto epsilonConst = auto epsilonConst = tosa::getConstTensor<float>(rewriter, op.getOperation(),
tosa::getConstTensor<float>(rewriter, op.getOperation(),
{static_cast<float>(eps)}, {}, {static_cast<float>(eps)}, {},
meanType.getElementType()) meanType.getElementType())
.value(); .value();
@ -2571,7 +2573,7 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
// Constant value of ln2. // Constant value of ln2.
SmallVector<int64_t> ln2Shape(selfType.getRank(), 1); SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
auto ln2Op = tosa::getConstTensor<float>(rewriter, op, {0.69314718056}, auto ln2Op = tosa::getConstTensor<float>(rewriter, op, {0.69314718056f},
ln2Shape, selfType.getElementType()) ln2Shape, selfType.getElementType())
.value(); .value();
auto rcpOp = auto rcpOp =
@ -2802,21 +2804,25 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value(); auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value(); auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();
auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}, dtype).value(); auto a1 =
tosa::getConstTensor<float>(rewriter, op, 0.278393f, {}, dtype).value();
auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0); auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one); auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);
auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}, dtype).value(); auto a2 =
tosa::getConstTensor<float>(rewriter, op, 0.230389f, {}, dtype).value();
auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0); auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0); auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X); sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);
auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}, dtype).value(); auto a3 =
tosa::getConstTensor<float>(rewriter, op, 0.000972f, {}, dtype).value();
auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0); auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0); auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X); sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);
auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}, dtype).value(); auto a4 =
tosa::getConstTensor<float>(rewriter, op, 0.078108f, {}, dtype).value();
auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0); auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0); auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X); sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
@ -2851,13 +2857,14 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean); Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
// rsqrt of 2 // rsqrt of 2
Value rsqrt2 = Value rsqrt2 =
tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}, dtype).value(); tosa::getConstTensor<float>(rewriter, op, 0.70710678f, {}, dtype).value();
Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2, Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
/*shift=*/0); /*shift=*/0);
Value erf = approximateErfOp(rewriter, op, erfArg, dtype); Value erf = approximateErfOp(rewriter, op, erfArg, dtype);
Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf); Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
Value oneHalf = tosa::getConstTensor<float>(rewriter, op, 0.5, {}, dtype).value(); Value oneHalf =
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, dtype).value();
Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf, Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf,
erfPlus1, /*shift=*/0); erfPlus1, /*shift=*/0);
@ -2891,7 +2898,8 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy);
cdf = rewriter.createOrFold<tosa::CastOp>( cdf = rewriter.createOrFold<tosa::CastOp>(
op->getLoc(), cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf); op->getLoc(),
cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf);
rewriter.replaceOpWithNewOp<tosa::MulOp>( rewriter.replaceOpWithNewOp<tosa::MulOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf, op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf,
@ -2927,15 +2935,16 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
auto loc = op->getLoc(); auto loc = op->getLoc();
const double cstAlpha0 = 1.12837916709551257390; const float cstAlpha0 = 1.12837916709551257390f;
const double cstAlpha1 = 0.70710678118654752440; const float cstAlpha1 = 0.70710678118654752440f;
const double oneHalf = 0.5; const float oneHalf = 0.5f;
const double kAlpha = cstAlpha0 * cstAlpha1; const float kAlpha = cstAlpha0 * cstAlpha1;
Value kAlphaHalf = Value kAlphaHalf = tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf,
tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {}, selfElemTy).value(); {}, selfElemTy)
.value();
Value negOneHalf = Value negOneHalf =
tosa::getConstTensor<float>(rewriter, op, -0.5, {}, selfElemTy).value(); tosa::getConstTensor<float>(rewriter, op, -0.5f, {}, selfElemTy).value();
Value inputSquared = rewriter.create<tosa::MulOp>( Value inputSquared = rewriter.create<tosa::MulOp>(
loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0); loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0);
Value negHalfInputSquared = rewriter.create<tosa::MulOp>( Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
@ -3006,7 +3015,8 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "Only scalar constant is supported"); return rewriter.notifyMatchFailure(op, "Only scalar constant is supported");
} }
Value replace = tosa::getConstTensor<float>(rewriter, op, 0, {}, selfElemTy).value(); Value replace =
tosa::getConstTensor<float>(rewriter, op, 0, {}, selfElemTy).value();
Type outType = getTypeConverter()->convertType(op.getType()); Type outType = getTypeConverter()->convertType(op.getType());
Value lesser = rewriter.create<tosa::GreaterOp>( Value lesser = rewriter.create<tosa::GreaterOp>(
@ -3620,10 +3630,11 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
indicesTfConcatTensors, lastDim); indicesTfConcatTensors, lastDim);
if (!indicesTf) { if (!indicesTf) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(op,
op, "Convert TorchIndex To TfIndices fail."); "Convert TorchIndex To TfIndices fail.");
} }
// do the tf scatterNd algorithm with tf style indices as input, algorithm mostly take from convertGatherNdOp. // do the tf scatterNd algorithm with tf style indices as input, algorithm
// mostly take from convertGatherNdOp.
auto result = tosa::convertScatterNdOp(rewriter, op, outType, input, auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
indicesTf.getResult(), fillValues); indicesTf.getResult(), fillValues);