mirror of https://github.com/llvm/torch-mlir

Fixing implicit double to float casts. (#2476)

MSVC (and other compilers with implicit narrowing warnings) don't like this type mismatch.

pull/2480/head snapshot-20230921.968
parent 023fc90072
commit b9847b1904
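
For context, a minimal sketch of the warning class this commit silences. The snippet and the MSVC diagnostic numbers in its comments are illustrative, not taken from the patch:

// narrowing.cpp -- e.g. cl /W4 narrowing.cpp
float takesFloat(float x) { return x * 2.0f; }

int main() {
  float a = 0.69314718056;    // MSVC C4305: truncation from 'double' to 'float'
  float b = takesFloat(0.5);  // implicit double -> float conversion at the call
  float c = 0.69314718056f;   // OK: the 'f' suffix makes the literal a float
  return static_cast<int>(a + b + c);
}

The fix throughout the patch is the same: write float literals with the 'f' suffix, keep float constants in float variables, and make the remaining conversions explicit with static_cast.
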
@@ -121,8 +121,8 @@ static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
     return (doubleValue == static_cast<double>(static_cast<T>(doubleValue)));
   } else {
     assert(isInt);
-    return (intValue >= std::numeric_limits<T>::min()) &&
-           (intValue <= std::numeric_limits<T>::max());
+    return (intValue >= static_cast<int64_t>(std::numeric_limits<T>::min())) &&
+           (intValue <= static_cast<int64_t>(std::numeric_limits<T>::max()));
   }
   return true;
 }
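
std::numeric_limits<T>::min() and max() have type T, so comparing them against the int64_t intValue mixes widths (and, for unsigned T, signedness). A standalone sketch of the same guard, assuming T is a signed integer type no wider than int64_t (hypothetical helper, not from the patch):

#include <cassert>
#include <cstdint>
#include <limits>

// Casting both limits to int64_t keeps the whole comparison in one type,
// leaving no implicit conversion for the compiler to warn about.
template <typename T>
bool fitsInT(int64_t intValue) {
  return intValue >= static_cast<int64_t>(std::numeric_limits<T>::min()) &&
         intValue <= static_cast<int64_t>(std::numeric_limits<T>::max());
}

int main() {
  assert(fitsInT<int8_t>(127) && !fitsInT<int8_t>(128));
  return 0;
}
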
@@ -145,8 +145,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
         "Unable to extract the scalar constant");

   if (dtype.isa<mlir::FloatType>()) {
-    tosaTensor = tosa::getConstTensor<float>(
-        rewriter, op, (isFloat ? doubleValue : intValue), dshape, dtype)
+    tosaTensor = tosa::getConstTensor<float>(rewriter, op,
+                                             (isFloat ? doubleValue : intValue),
+                                             dshape, dtype)
                      .value();
   } else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
     auto w = intType.getWidth();
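
One detail worth knowing about this call: with doubleValue a double and intValue an int64_t, the conditional expression has type double in both arms (the usual arithmetic conversions promote the int64_t arm), so the value still narrows to float inside getConstTensor<float>. A minimal sketch of that rule:

#include <cstdint>
#include <type_traits>

double doubleValue = 0.5;
int64_t intValue = 2;
// The ternary never yields int64_t here; both arms unify to double.
static_assert(std::is_same_v<decltype(true ? doubleValue : intValue), double>,
              "both arms unify to double");
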
@@ -616,7 +617,9 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
         op, "Negative slope needs to be a scalar constant for conversion to "
             "TOSA LeakyReLU operation");

-  auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, selfTy.getElementType()).value();
+  auto zero =
+      tosa::getConstTensor<float>(rewriter, op, 0, {}, selfTy.getElementType())
+          .value();
   auto cond = rewriter.create<tosa::GreaterEqualOp>(
       op->getLoc(),
       RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)),
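
For reference, the scalar computation this zero constant and comparison feed is plain leaky_relu; the tosa::SelectOp that consumes the comparison sits below the hunk:

// Scalar form of leaky_relu (illustrative):
float leakyRelu(float x, float negativeSlope) {
  return x >= 0.0f ? x : negativeSlope * x;
}
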
@@ -2253,8 +2256,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
   if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps)))
     return rewriter.notifyMatchFailure(op, "eps must be a scalar constant");

-  auto epsilonConst =
-      tosa::getConstTensor<float>(rewriter, op.getOperation(),
+  auto epsilonConst = tosa::getConstTensor<float>(rewriter, op.getOperation(),
                                                   {static_cast<float>(eps)}, {},
                                                   meanType.getElementType())
                           .value();
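
The epsilon constant built here feeds the usual inference-time batch-norm formula. A scalar sketch, shapes and broadcasting elided (illustrative only):

#include <cmath>

// eps keeps the denominator away from zero when var is tiny.
float batchNorm(float x, float mean, float var, float weight, float bias,
                float eps) {
  return (x - mean) / std::sqrt(var + eps) * weight + bias;
}
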
@@ -2571,7 +2573,7 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(

   // Constant value of ln2.
   SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
-  auto ln2Op = tosa::getConstTensor<float>(rewriter, op, {0.69314718056},
+  auto ln2Op = tosa::getConstTensor<float>(rewriter, op, {0.69314718056f},
                                            ln2Shape, selfType.getElementType())
                    .value();
   auto rcpOp =
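
The lowering uses the change-of-base identity log2(x) = ln(x) / ln(2), materializing ln 2 as a constant and dividing via the reciprocal op. In scalar form:

#include <cmath>

// Scalar form of the Log2 lowering sketched above (illustrative).
float log2ViaLn(float x) {
  const float ln2 = 0.69314718056f;  // the constant from the hunk
  return std::log(x) * (1.0f / ln2);
}
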
@@ -2802,21 +2804,25 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
   auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
   auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();

-  auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}, dtype).value();
+  auto a1 =
+      tosa::getConstTensor<float>(rewriter, op, 0.278393f, {}, dtype).value();
   auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
   auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);

-  auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}, dtype).value();
+  auto a2 =
+      tosa::getConstTensor<float>(rewriter, op, 0.230389f, {}, dtype).value();
   auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
   auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
   sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);

-  auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}, dtype).value();
+  auto a3 =
+      tosa::getConstTensor<float>(rewriter, op, 0.000972f, {}, dtype).value();
   auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
   auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
   sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);

-  auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}, dtype).value();
+  auto a4 =
+      tosa::getConstTensor<float>(rewriter, op, 0.078108f, {}, dtype).value();
   auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
   auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
   sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
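
The four coefficients are the classic Abramowitz-Stegun rational approximation erf(x) ≈ 1 - 1/(1 + a1*x + a2*x^2 + a3*x^3 + a4*x^4)^4 for x >= 0, accurate to roughly 5e-4. A scalar sketch of what the MulOp/AddOp chain computes; the full lowering, outside this hunk, restores the sign for negative inputs since erf(-x) = -erf(x):

#include <cmath>
#include <cstdio>

float approxErf(float x) {
  const float a1 = 0.278393f, a2 = 0.230389f, a3 = 0.000972f, a4 = 0.078108f;
  float ax = std::fabs(x);
  float denom = 1.0f + a1 * ax + a2 * ax * ax + a3 * ax * ax * ax +
                a4 * ax * ax * ax * ax;
  float erfAbs = 1.0f - 1.0f / (denom * denom * denom * denom);
  return x < 0.0f ? -erfAbs : erfAbs;
}

int main() {
  std::printf("%f vs %f\n", approxErf(1.0f), std::erf(1.0));  // both ~0.8427
  return 0;
}
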
@@ -2851,13 +2857,14 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
   Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
   // rsqrt of 2
   Value rsqrt2 =
-      tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}, dtype).value();
+      tosa::getConstTensor<float>(rewriter, op, 0.70710678f, {}, dtype).value();

   Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
                                               /*shift=*/0);
   Value erf = approximateErfOp(rewriter, op, erfArg, dtype);
   Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
-  Value oneHalf = tosa::getConstTensor<float>(rewriter, op, 0.5, {}, dtype).value();
+  Value oneHalf =
+      tosa::getConstTensor<float>(rewriter, op, 0.5, {}, dtype).value();

   Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf,
                                                  erfPlus1, /*shift=*/0);
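
buildUnitNormalCdf realizes the standard identity Phi(x) = 0.5 * (1 + erf((x - mean) / sqrt(2))); the 0.70710678f constant is 1/sqrt(2). Scalar sketch for mean = 0:

#include <cmath>

float unitNormalCdf(float x) {
  const float rsqrt2 = 0.70710678f;  // 1/sqrt(2), as in the hunk
  return 0.5f * (1.0f + std::erf(x * rsqrt2));
}
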
@@ -2891,7 +2898,8 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(

   Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy);
   cdf = rewriter.createOrFold<tosa::CastOp>(
-      op->getLoc(), cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf);
+      op->getLoc(),
+      cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf);

   rewriter.replaceOpWithNewOp<tosa::MulOp>(
       op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf,
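
The final multiply implements exact (erf-based) GELU, GELU(x) = x * Phi(x); the cast keeps the CDF in the input's element type. Scalar sketch:

#include <cmath>

float gelu(float x) {
  float cdf = 0.5f * (1.0f + std::erf(x * 0.70710678f));  // Phi(x)
  return x * cdf;  // the tosa::MulOp that replaces the op
}
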
@@ -2927,15 +2935,16 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(

   auto loc = op->getLoc();

-  const double cstAlpha0 = 1.12837916709551257390;
-  const double cstAlpha1 = 0.70710678118654752440;
-  const double oneHalf = 0.5;
-  const double kAlpha = cstAlpha0 * cstAlpha1;
+  const float cstAlpha0 = 1.12837916709551257390f;
+  const float cstAlpha1 = 0.70710678118654752440f;
+  const float oneHalf = 0.5f;
+  const float kAlpha = cstAlpha0 * cstAlpha1;

-  Value kAlphaHalf =
-      tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {}, selfElemTy).value();
+  Value kAlphaHalf = tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf,
+                                                 {}, selfElemTy)
+                         .value();
   Value negOneHalf =
-      tosa::getConstTensor<float>(rewriter, op, -0.5, {}, selfElemTy).value();
+      tosa::getConstTensor<float>(rewriter, op, -0.5f, {}, selfElemTy).value();
   Value inputSquared = rewriter.create<tosa::MulOp>(
       loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0);
   Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
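
Here cstAlpha0 = 2/sqrt(pi) and cstAlpha1 = 1/sqrt(2), so kAlpha * oneHalf = 1/sqrt(2*pi) ≈ 0.3989423, the normalization constant of the standard normal pdf. The op lowers the GELU gradient d/dx GELU(x) = Phi(x) + x * phi(x); a scalar sketch of the local derivative (the full lowering also multiplies by the incoming grad_output):

#include <cmath>

float geluGrad(float x) {
  const float cstAlpha0 = 1.12837916709551257390f;        // 2/sqrt(pi)
  const float cstAlpha1 = 0.70710678118654752440f;        // 1/sqrt(2)
  const float kAlphaHalf = cstAlpha0 * cstAlpha1 * 0.5f;  // 1/sqrt(2*pi)
  const float cdf = 0.5f * (1.0f + std::erf(x * cstAlpha1));  // Phi(x)
  const float pdf = kAlphaHalf * std::exp(-0.5f * x * x);     // phi(x)
  return cdf + x * pdf;
}
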
@@ -3006,7 +3015,8 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "Only scalar constant is supported");
   }

-  Value replace = tosa::getConstTensor<float>(rewriter, op, 0, {}, selfElemTy).value();
+  Value replace =
+      tosa::getConstTensor<float>(rewriter, op, 0, {}, selfElemTy).value();
   Type outType = getTypeConverter()->convertType(op.getType());

   Value lesser = rewriter.create<tosa::GreaterOp>(
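
The zero constant is the fill value for the clamped region: hardtanh's gradient passes grad_output through strictly inside (min_val, max_val) and is zero elsewhere. Scalar sketch, assuming semantics matching PyTorch's hardtanh_backward:

float hardtanhBackward(float gradOutput, float x, float minVal, float maxVal) {
  return (x > minVal && x < maxVal) ? gradOutput : 0.0f;
}
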
@@ -3620,10 +3630,11 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
       indicesTfConcatTensors, lastDim);

   if (!indicesTf) {
-    return rewriter.notifyMatchFailure(
-        op, "Convert TorchIndex To TfIndices fail.");
+    return rewriter.notifyMatchFailure(op,
+                                       "Convert TorchIndex To TfIndices fail.");
   }
-  // do the tf scatterNd algorithm with tf style indices as input, algorithm mostly take from convertGatherNdOp.
+  // do the tf scatterNd algorithm with tf style indices as input, algorithm
+  // mostly take from convertGatherNdOp.
   auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
                                          indicesTf.getResult(), fillValues);

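
convertScatterNdOp follows the TF scatter_nd recipe the comment references: copy the input, then overwrite the slots addressed by the TF-style indices. A deliberately simplified 1-D sketch (the real op generalizes to rank-N index tuples):

#include <cstdint>
#include <vector>

std::vector<float> scatterNd1D(std::vector<float> input,
                               const std::vector<int64_t> &indices,
                               const std::vector<float> &updates) {
  for (size_t r = 0; r < indices.size(); ++r)
    input[static_cast<size_t>(indices[r])] = updates[r];  // later writes win
  return input;
}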