TorchToTosa: Cast float constants to correct type to support bfloat16 (#2239)

pull/2241/head snapshot-20230616.871
Matthias Gehre 2023-06-16 09:51:24 +02:00 committed by GitHub
parent 45e2188615
commit 6f420019cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 68 additions and 38 deletions

View File

@ -55,7 +55,8 @@ std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
// To create INT48 TOSA constant, need to pass in llvm::APInt instead.
template <typename T>
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape);
ArrayRef<T> vec, ArrayRef<int64_t> shape,
std::optional<Type> dtype = {});
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Value src, Type destType, Value &result);

View File

@ -145,7 +145,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
if (dtype.isa<mlir::FloatType>()) {
tosaTensor = tosa::getConstTensor<float>(
rewriter, op, (isFloat ? doubleValue : intValue), dshape)
rewriter, op, (isFloat ? doubleValue : intValue), dshape, dtype)
.value();
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
auto w = intType.getWidth();
@ -200,8 +200,9 @@ LogicalResult torchAlphaToTosaTensor(ConversionPatternRewriter &rewriter,
return rewriter.notifyMatchFailure(op,
"Unsupported integer value for alpha");
alphaTensor =
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, alphaValue);
alphaTensor = tosa::getConstTensor<float>(
rewriter, op, {static_cast<float>(alphaValue)}, {}, dtype)
.value();
return success();
}
@ -604,7 +605,7 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
op, "Negative slope needs to be a scalar constant for conversion to "
"TOSA LeakyReLU operation");
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, selfTy.getElementType()).value();
auto cond = rewriter.create<tosa::GreaterEqualOp>(
op->getLoc(),
RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)),
@ -1006,17 +1007,17 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
auto outType =
getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
Value expTensor;
Value expScalar = op.getExponent();
if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
selfTy.getElementType(), {})))
outType.getElementType(), {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA Pow operation");
auto outType =
getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
self, expTensor);
rewriter.replaceOp(op, powOp.getResult());
@ -2159,7 +2160,10 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "eps must be a scalar constant");
auto epsilonConst =
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps);
tosa::getConstTensor<float>(rewriter, op.getOperation(),
{static_cast<float>(eps)}, {},
meanType.getElementType())
.value();
auto batchNorm =
computeBatchNorm(op, rewriter, outType, adaptor.getInput(), varianceVal,
@ -2263,7 +2267,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
auto elemCntConst =
tosa::getConstTensor<float>(rewriter, op.getOperation(),
{static_cast<float>(elemCnt)}, {1})
{static_cast<float>(elemCnt)}, {1}, elemTy)
.value();
Value elemCntRcp = rewriter.create<tosa::ReciprocalOp>(
op.getLoc(), elemCntConst.getType(), elemCntConst);
@ -2318,7 +2322,9 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps)))
return rewriter.notifyMatchFailure(op, "eps must be a scalar constant");
auto epsilonConst =
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps);
tosa::getConstTensor<float>(rewriter, op.getOperation(),
{static_cast<float>(eps)}, {}, elemTy)
.value();
// Compute layer norm.
auto layerNorm =
@ -2471,8 +2477,8 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
// Constant value of ln2.
SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
auto ln2Op =
tosa::getConstTensor<float>(rewriter, op, {0.69314718056}, ln2Shape)
auto ln2Op = tosa::getConstTensor<float>(rewriter, op, {0.69314718056},
ln2Shape, selfType.getElementType())
.value();
auto rcpOp =
rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);
@ -2688,7 +2694,7 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
}
static Value approximateErfOp(ConversionPatternRewriter &rewriter,
Operation *op, Value x) {
Operation *op, Value x, Type dtype) {
// Using:
// https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with
// maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 =
@ -2699,24 +2705,24 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
auto outType = x.getType().cast<TensorType>();
auto loc = op->getLoc();
auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).value();
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();
auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}).value();
auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}, dtype).value();
auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);
auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}).value();
auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}, dtype).value();
auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);
auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}).value();
auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}, dtype).value();
auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);
auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}).value();
auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}, dtype).value();
auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
@ -2739,9 +2745,10 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
}
static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
Operation *op, Value x) {
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).value();
Operation *op, Value x, Type dtype) {
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();
auto loc = op->getLoc();
// buildNormalCdf, mean = zero, sigma = one
@ -2750,12 +2757,14 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
// rsqrt of 2
Value rsqrt2 =
tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).value();
tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}, dtype).value();
Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
/*shift=*/0);
Value erf = approximateErfOp(rewriter, op, erfArg);
Value erf = approximateErfOp(rewriter, op, erfArg, dtype);
Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
Value oneHalf = tosa::getConstTensor<float>(rewriter, op, 0.5, {}).value();
Value oneHalf = tosa::getConstTensor<float>(rewriter, op, 0.5, {}, dtype).value();
Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf,
erfPlus1, /*shift=*/0);
return normalCdf;
@ -2786,7 +2795,10 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
}
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf());
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy);
cdf = rewriter.createOrFold<tosa::CastOp>(
op->getLoc(), cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf);
rewriter.replaceOpWithNewOp<tosa::MulOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf,
/*shift=*/0);
@ -2827,16 +2839,16 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
const double kAlpha = cstAlpha0 * cstAlpha1;
Value kAlphaHalf =
tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {}).value();
tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {}, selfElemTy).value();
Value negOneHalf =
tosa::getConstTensor<float>(rewriter, op, -0.5, {}).value();
tosa::getConstTensor<float>(rewriter, op, -0.5, {}, selfElemTy).value();
Value inputSquared = rewriter.create<tosa::MulOp>(
loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0);
Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
loc, selfType, inputSquared, negOneHalf, /*shift=*/0);
Value dinput =
rewriter.create<tosa::ExpOp>(loc, selfType, negHalfInputSquared);
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf());
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy);
Value dinputInput = rewriter.create<tosa::MulOp>(
loc, selfType, dinput, adaptor.getSelf(), /*shift=*/0);
Value dinputInputAlpha = rewriter.create<tosa::MulOp>(
@ -2900,7 +2912,7 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "Only scalar constant is supported");
}
Value replace = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
Value replace = tosa::getConstTensor<float>(rewriter, op, 0, {}, selfElemTy).value();
Type outType = getTypeConverter()->convertType(op.getType());
Value lesser = rewriter.create<tosa::GreaterOp>(

View File

@ -174,7 +174,7 @@ std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
// Default template creates a constant tensor in T.
template <typename T>
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape) {
ArrayRef<T> vec, ArrayRef<int64_t> shape, std::optional<Type> dtype) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
@ -191,6 +191,11 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
auto const_op =
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
if (dtype) {
return rewriter.createOrFold<tosa::CastOp>(
op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
}
return const_op.getResult();
}
@ -198,7 +203,7 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
template <>
std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
Operation *op, ArrayRef<APInt> vec,
ArrayRef<int64_t> shape) {
ArrayRef<int64_t> shape, std::optional<Type> dtype) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
@ -215,6 +220,11 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
auto const_op =
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
if (dtype) {
return rewriter.createOrFold<tosa::CastOp>(
op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
}
return const_op.getResult();
}
@ -222,7 +232,7 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
template <>
std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
Operation *op, ArrayRef<float> vec,
ArrayRef<int64_t> shape) {
ArrayRef<int64_t> shape, std::optional<Type> dtype) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
@ -238,6 +248,11 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
auto const_op =
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
if (dtype) {
return rewriter.createOrFold<tosa::CastOp>(
op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
}
return const_op.getResult();
}
@ -329,12 +344,14 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
template std::optional<Value> getConstTensor<int32_t>(PatternRewriter &,
Operation *,
ArrayRef<int32_t> vec,
ArrayRef<int64_t> shape);
ArrayRef<int64_t> shape,
std::optional<Type> dtype);
template std::optional<Value> getConstTensor<int64_t>(PatternRewriter &,
Operation *,
ArrayRef<int64_t> vec,
ArrayRef<int64_t> shape);
ArrayRef<int64_t> shape,
std::optional<Type> dtype);
LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
TypeAttr &accType) {