mirror of https://github.com/llvm/torch-mlir
TorchToTosa: Cast float constants to correct type to support bfloat16 (#2239)
parent
45e2188615
commit
6f420019cb
|
@ -55,7 +55,8 @@ std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
|
|||
// To create INT48 TOSA constant, need to pass in llvm::APInt instead.
|
||||
template <typename T>
|
||||
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<T> vec, ArrayRef<int64_t> shape);
|
||||
ArrayRef<T> vec, ArrayRef<int64_t> shape,
|
||||
std::optional<Type> dtype = {});
|
||||
|
||||
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
|
||||
Value src, Type destType, Value &result);
|
||||
|
|
|
@ -145,7 +145,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
tosaTensor = tosa::getConstTensor<float>(
|
||||
rewriter, op, (isFloat ? doubleValue : intValue), dshape)
|
||||
rewriter, op, (isFloat ? doubleValue : intValue), dshape, dtype)
|
||||
.value();
|
||||
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
|
||||
auto w = intType.getWidth();
|
||||
|
@ -200,8 +200,9 @@ LogicalResult torchAlphaToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"Unsupported integer value for alpha");
|
||||
|
||||
alphaTensor =
|
||||
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, alphaValue);
|
||||
alphaTensor = tosa::getConstTensor<float>(
|
||||
rewriter, op, {static_cast<float>(alphaValue)}, {}, dtype)
|
||||
.value();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -604,7 +605,7 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
|
|||
op, "Negative slope needs to be a scalar constant for conversion to "
|
||||
"TOSA LeakyReLU operation");
|
||||
|
||||
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
|
||||
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, selfTy.getElementType()).value();
|
||||
auto cond = rewriter.create<tosa::GreaterEqualOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)),
|
||||
|
@ -1006,17 +1007,17 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point datatype legalization supported");
|
||||
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
|
||||
|
||||
Value expTensor;
|
||||
Value expScalar = op.getExponent();
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
|
||||
selfTy.getElementType(), {})))
|
||||
outType.getElementType(), {})))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Currently only scalar constants are supported for "
|
||||
"conversion in TOSA Pow operation");
|
||||
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
|
||||
|
||||
auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
|
||||
self, expTensor);
|
||||
rewriter.replaceOp(op, powOp.getResult());
|
||||
|
@ -2159,7 +2160,10 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op, "eps must be a scalar constant");
|
||||
|
||||
auto epsilonConst =
|
||||
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps);
|
||||
tosa::getConstTensor<float>(rewriter, op.getOperation(),
|
||||
{static_cast<float>(eps)}, {},
|
||||
meanType.getElementType())
|
||||
.value();
|
||||
|
||||
auto batchNorm =
|
||||
computeBatchNorm(op, rewriter, outType, adaptor.getInput(), varianceVal,
|
||||
|
@ -2263,7 +2267,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
|
||||
auto elemCntConst =
|
||||
tosa::getConstTensor<float>(rewriter, op.getOperation(),
|
||||
{static_cast<float>(elemCnt)}, {1})
|
||||
{static_cast<float>(elemCnt)}, {1}, elemTy)
|
||||
.value();
|
||||
Value elemCntRcp = rewriter.create<tosa::ReciprocalOp>(
|
||||
op.getLoc(), elemCntConst.getType(), elemCntConst);
|
||||
|
@ -2318,7 +2322,9 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps)))
|
||||
return rewriter.notifyMatchFailure(op, "eps must be a scalar constant");
|
||||
auto epsilonConst =
|
||||
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps);
|
||||
tosa::getConstTensor<float>(rewriter, op.getOperation(),
|
||||
{static_cast<float>(eps)}, {}, elemTy)
|
||||
.value();
|
||||
|
||||
// Compute layer norm.
|
||||
auto layerNorm =
|
||||
|
@ -2471,8 +2477,8 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
|
|||
|
||||
// Constant value of ln2.
|
||||
SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
|
||||
auto ln2Op =
|
||||
tosa::getConstTensor<float>(rewriter, op, {0.69314718056}, ln2Shape)
|
||||
auto ln2Op = tosa::getConstTensor<float>(rewriter, op, {0.69314718056},
|
||||
ln2Shape, selfType.getElementType())
|
||||
.value();
|
||||
auto rcpOp =
|
||||
rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);
|
||||
|
@ -2688,7 +2694,7 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
|
|||
}
|
||||
|
||||
static Value approximateErfOp(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value x) {
|
||||
Operation *op, Value x, Type dtype) {
|
||||
// Using:
|
||||
// https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with
|
||||
// maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 =
|
||||
|
@ -2699,24 +2705,24 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
|
|||
auto outType = x.getType().cast<TensorType>();
|
||||
auto loc = op->getLoc();
|
||||
auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
|
||||
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
|
||||
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).value();
|
||||
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
|
||||
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();
|
||||
|
||||
auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}).value();
|
||||
auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}, dtype).value();
|
||||
auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
|
||||
auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);
|
||||
|
||||
auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}).value();
|
||||
auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}, dtype).value();
|
||||
auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
|
||||
auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
|
||||
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);
|
||||
|
||||
auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}).value();
|
||||
auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}, dtype).value();
|
||||
auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
|
||||
auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
|
||||
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);
|
||||
|
||||
auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}).value();
|
||||
auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}, dtype).value();
|
||||
auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
|
||||
auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
|
||||
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
|
||||
|
@ -2739,9 +2745,10 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
|
|||
}
|
||||
|
||||
static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value x) {
|
||||
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
|
||||
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).value();
|
||||
Operation *op, Value x, Type dtype) {
|
||||
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
|
||||
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();
|
||||
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// buildNormalCdf, mean = zero, sigma = one
|
||||
|
@ -2750,12 +2757,14 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
|
|||
Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
|
||||
// rsqrt of 2
|
||||
Value rsqrt2 =
|
||||
tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).value();
|
||||
tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}, dtype).value();
|
||||
|
||||
Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
|
||||
/*shift=*/0);
|
||||
Value erf = approximateErfOp(rewriter, op, erfArg);
|
||||
Value erf = approximateErfOp(rewriter, op, erfArg, dtype);
|
||||
Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
|
||||
Value oneHalf = tosa::getConstTensor<float>(rewriter, op, 0.5, {}).value();
|
||||
Value oneHalf = tosa::getConstTensor<float>(rewriter, op, 0.5, {}, dtype).value();
|
||||
|
||||
Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf,
|
||||
erfPlus1, /*shift=*/0);
|
||||
return normalCdf;
|
||||
|
@ -2786,7 +2795,10 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
|
||||
}
|
||||
|
||||
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf());
|
||||
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy);
|
||||
cdf = rewriter.createOrFold<tosa::CastOp>(
|
||||
op->getLoc(), cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf,
|
||||
/*shift=*/0);
|
||||
|
@ -2827,16 +2839,16 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
|||
const double kAlpha = cstAlpha0 * cstAlpha1;
|
||||
|
||||
Value kAlphaHalf =
|
||||
tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {}).value();
|
||||
tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {}, selfElemTy).value();
|
||||
Value negOneHalf =
|
||||
tosa::getConstTensor<float>(rewriter, op, -0.5, {}).value();
|
||||
tosa::getConstTensor<float>(rewriter, op, -0.5, {}, selfElemTy).value();
|
||||
Value inputSquared = rewriter.create<tosa::MulOp>(
|
||||
loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0);
|
||||
Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
|
||||
loc, selfType, inputSquared, negOneHalf, /*shift=*/0);
|
||||
Value dinput =
|
||||
rewriter.create<tosa::ExpOp>(loc, selfType, negHalfInputSquared);
|
||||
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf());
|
||||
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy);
|
||||
Value dinputInput = rewriter.create<tosa::MulOp>(
|
||||
loc, selfType, dinput, adaptor.getSelf(), /*shift=*/0);
|
||||
Value dinputInputAlpha = rewriter.create<tosa::MulOp>(
|
||||
|
@ -2900,7 +2912,7 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op, "Only scalar constant is supported");
|
||||
}
|
||||
|
||||
Value replace = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
|
||||
Value replace = tosa::getConstTensor<float>(rewriter, op, 0, {}, selfElemTy).value();
|
||||
Type outType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
Value lesser = rewriter.create<tosa::GreaterOp>(
|
||||
|
|
|
@ -174,7 +174,7 @@ std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
|
|||
// Default template creates a constant tensor in T.
|
||||
template <typename T>
|
||||
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<T> vec, ArrayRef<int64_t> shape) {
|
||||
ArrayRef<T> vec, ArrayRef<int64_t> shape, std::optional<Type> dtype) {
|
||||
uint64_t num_total_elements = 1;
|
||||
for (int64_t a : shape) {
|
||||
num_total_elements *= a;
|
||||
|
@ -191,6 +191,11 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
|
|||
|
||||
auto const_op =
|
||||
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
|
||||
|
||||
if (dtype) {
|
||||
return rewriter.createOrFold<tosa::CastOp>(
|
||||
op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
|
||||
}
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -198,7 +203,7 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
|
|||
template <>
|
||||
std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
|
||||
Operation *op, ArrayRef<APInt> vec,
|
||||
ArrayRef<int64_t> shape) {
|
||||
ArrayRef<int64_t> shape, std::optional<Type> dtype) {
|
||||
uint64_t num_total_elements = 1;
|
||||
for (int64_t a : shape) {
|
||||
num_total_elements *= a;
|
||||
|
@ -215,6 +220,11 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
|
|||
|
||||
auto const_op =
|
||||
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
|
||||
|
||||
if (dtype) {
|
||||
return rewriter.createOrFold<tosa::CastOp>(
|
||||
op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
|
||||
}
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -222,7 +232,7 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
|
|||
template <>
|
||||
std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
|
||||
Operation *op, ArrayRef<float> vec,
|
||||
ArrayRef<int64_t> shape) {
|
||||
ArrayRef<int64_t> shape, std::optional<Type> dtype) {
|
||||
uint64_t num_total_elements = 1;
|
||||
for (int64_t a : shape) {
|
||||
num_total_elements *= a;
|
||||
|
@ -238,6 +248,11 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
|
|||
|
||||
auto const_op =
|
||||
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
|
||||
|
||||
if (dtype) {
|
||||
return rewriter.createOrFold<tosa::CastOp>(
|
||||
op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
|
||||
}
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -329,12 +344,14 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
|
|||
template std::optional<Value> getConstTensor<int32_t>(PatternRewriter &,
|
||||
Operation *,
|
||||
ArrayRef<int32_t> vec,
|
||||
ArrayRef<int64_t> shape);
|
||||
ArrayRef<int64_t> shape,
|
||||
std::optional<Type> dtype);
|
||||
|
||||
template std::optional<Value> getConstTensor<int64_t>(PatternRewriter &,
|
||||
Operation *,
|
||||
ArrayRef<int64_t> vec,
|
||||
ArrayRef<int64_t> shape);
|
||||
ArrayRef<int64_t> shape,
|
||||
std::optional<Type> dtype);
|
||||
|
||||
LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
|
||||
TypeAttr &accType) {
|
||||
|
|
Loading…
Reference in New Issue