From 6f420019cb2d6f5487f5201553d86f167a2856be Mon Sep 17 00:00:00 2001 From: Matthias Gehre <93204396+mgehre-amd@users.noreply.github.com> Date: Fri, 16 Jun 2023 09:51:24 +0200 Subject: [PATCH] TorchToTosa: Cast float constants to correct type to support bfloat16 (#2239) --- .../TorchToTosa/TosaLegalizeUtils.h | 3 +- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 76 +++++++++++-------- .../TorchToTosa/TosaLegalizeUtils.cpp | 27 +++++-- 3 files changed, 68 insertions(+), 38 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index c3ab1d474..14cf9cba7 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -55,7 +55,8 @@ std::optional getZerosLikeTensor(PatternRewriter &rewriter, // To create INT48 TOSA constant, need to pass in llvm::APInt instead. template std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, - ArrayRef vec, ArrayRef shape); + ArrayRef vec, ArrayRef shape, + std::optional dtype = {}); LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, Value src, Type destType, Value &result); diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index fc1efa364..f5a961cfd 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -145,7 +145,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, if (dtype.isa()) { tosaTensor = tosa::getConstTensor( - rewriter, op, (isFloat ? doubleValue : intValue), dshape) + rewriter, op, (isFloat ? doubleValue : intValue), dshape, dtype) .value(); } else if (auto intType = dtype.dyn_cast()) { auto w = intType.getWidth(); @@ -200,8 +200,9 @@ LogicalResult torchAlphaToTosaTensor(ConversionPatternRewriter &rewriter, return rewriter.notifyMatchFailure(op, "Unsupported integer value for alpha"); - alphaTensor = - mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, alphaValue); + alphaTensor = tosa::getConstTensor( + rewriter, op, {static_cast(alphaValue)}, {}, dtype) + .value(); return success(); } @@ -604,7 +605,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Negative slope needs to be a scalar constant for conversion to " "TOSA LeakyReLU operation"); - auto zero = tosa::getConstTensor(rewriter, op, 0, {}).value(); + auto zero = tosa::getConstTensor(rewriter, op, 0, {}, selfTy.getElementType()).value(); auto cond = rewriter.create( op->getLoc(), RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)), @@ -1006,17 +1007,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization supported"); + auto outType = + getTypeConverter()->convertType(op.getType()).template cast(); + Value expTensor; Value expScalar = op.getExponent(); if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor, - selfTy.getElementType(), {}))) + outType.getElementType(), {}))) return rewriter.notifyMatchFailure( op, "Currently only scalar constants are supported for " "conversion in TOSA Pow operation"); - auto outType = - getTypeConverter()->convertType(op.getType()).template cast(); - auto powOp = tosa::createBinaryOpAndCast(rewriter, op, outType, self, expTensor); rewriter.replaceOp(op, powOp.getResult()); @@ -2159,7 +2160,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "eps must be a scalar constant"); auto epsilonConst = - mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps); + tosa::getConstTensor(rewriter, op.getOperation(), + {static_cast(eps)}, {}, + meanType.getElementType()) + .value(); auto batchNorm = computeBatchNorm(op, rewriter, outType, adaptor.getInput(), varianceVal, @@ -2263,7 +2267,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto elemCntConst = tosa::getConstTensor(rewriter, op.getOperation(), - {static_cast(elemCnt)}, {1}) + {static_cast(elemCnt)}, {1}, elemTy) .value(); Value elemCntRcp = rewriter.create( op.getLoc(), elemCntConst.getType(), elemCntConst); @@ -2318,7 +2322,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) return rewriter.notifyMatchFailure(op, "eps must be a scalar constant"); auto epsilonConst = - mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, eps); + tosa::getConstTensor(rewriter, op.getOperation(), + {static_cast(eps)}, {}, elemTy) + .value(); // Compute layer norm. auto layerNorm = @@ -2471,9 +2477,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Constant value of ln2. SmallVector ln2Shape(selfType.getRank(), 1); - auto ln2Op = - tosa::getConstTensor(rewriter, op, {0.69314718056}, ln2Shape) - .value(); + auto ln2Op = tosa::getConstTensor(rewriter, op, {0.69314718056}, + ln2Shape, selfType.getElementType()) + .value(); auto rcpOp = rewriter.create(op.getLoc(), ln2Op.getType(), ln2Op); @@ -2688,7 +2694,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } static Value approximateErfOp(ConversionPatternRewriter &rewriter, - Operation *op, Value x) { + Operation *op, Value x, Type dtype) { // Using: // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with // maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 = @@ -2699,24 +2705,24 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, auto outType = x.getType().cast(); auto loc = op->getLoc(); auto absX = rewriter.create(loc, outType, x); - auto zero = tosa::getConstTensor(rewriter, op, 0, {}).value(); - auto one = tosa::getConstTensor(rewriter, op, 1, {}).value(); + auto zero = tosa::getConstTensor(rewriter, op, 0, {}, dtype).value(); + auto one = tosa::getConstTensor(rewriter, op, 1, {}, dtype).value(); - auto a1 = tosa::getConstTensor(rewriter, op, 0.278393, {}).value(); + auto a1 = tosa::getConstTensor(rewriter, op, 0.278393, {}, dtype).value(); auto a1X = rewriter.create(loc, outType, a1, absX, /*shift=*/0); auto sum = rewriter.create(loc, outType, a1X, one); - auto a2 = tosa::getConstTensor(rewriter, op, 0.230389, {}).value(); + auto a2 = tosa::getConstTensor(rewriter, op, 0.230389, {}, dtype).value(); auto x2 = rewriter.create(loc, outType, absX, absX, /*shift=*/0); auto a2X = rewriter.create(loc, outType, a2, x2, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a2X); - auto a3 = tosa::getConstTensor(rewriter, op, 0.000972, {}).value(); + auto a3 = tosa::getConstTensor(rewriter, op, 0.000972, {}, dtype).value(); auto x3 = rewriter.create(loc, outType, x2, absX, /*shift=*/0); auto a3X = rewriter.create(loc, outType, a3, x3, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a3X); - auto a4 = tosa::getConstTensor(rewriter, op, 0.078108, {}).value(); + auto a4 = tosa::getConstTensor(rewriter, op, 0.078108, {}, dtype).value(); auto x4 = rewriter.create(loc, outType, x3, absX, /*shift=*/0); auto a4X = rewriter.create(loc, outType, a4, x4, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a4X); @@ -2739,9 +2745,10 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, } static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, - Operation *op, Value x) { - auto zero = tosa::getConstTensor(rewriter, op, 0, {}).value(); - auto one = tosa::getConstTensor(rewriter, op, 1, {}).value(); + Operation *op, Value x, Type dtype) { + auto zero = tosa::getConstTensor(rewriter, op, 0, {}, dtype).value(); + auto one = tosa::getConstTensor(rewriter, op, 1, {}, dtype).value(); + auto loc = op->getLoc(); // buildNormalCdf, mean = zero, sigma = one @@ -2750,12 +2757,14 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Value xMinusMean = rewriter.create(loc, outType, x, mean); // rsqrt of 2 Value rsqrt2 = - tosa::getConstTensor(rewriter, op, 0.70710678, {}).value(); + tosa::getConstTensor(rewriter, op, 0.70710678, {}, dtype).value(); + Value erfArg = rewriter.create(loc, outType, xMinusMean, rsqrt2, /*shift=*/0); - Value erf = approximateErfOp(rewriter, op, erfArg); + Value erf = approximateErfOp(rewriter, op, erfArg, dtype); Value erfPlus1 = rewriter.create(loc, outType, one, erf); - Value oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}).value(); + Value oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}, dtype).value(); + Value normalCdf = rewriter.create(loc, outType, oneHalf, erfPlus1, /*shift=*/0); return normalCdf; @@ -2786,7 +2795,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "Unsupported value of approximate"); } - Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf()); + Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); + cdf = rewriter.createOrFold( + op->getLoc(), cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf, /*shift=*/0); @@ -2827,16 +2839,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( const double kAlpha = cstAlpha0 * cstAlpha1; Value kAlphaHalf = - tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, {}).value(); + tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, {}, selfElemTy).value(); Value negOneHalf = - tosa::getConstTensor(rewriter, op, -0.5, {}).value(); + tosa::getConstTensor(rewriter, op, -0.5, {}, selfElemTy).value(); Value inputSquared = rewriter.create( loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0); Value negHalfInputSquared = rewriter.create( loc, selfType, inputSquared, negOneHalf, /*shift=*/0); Value dinput = rewriter.create(loc, selfType, negHalfInputSquared); - Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf()); + Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); Value dinputInput = rewriter.create( loc, selfType, dinput, adaptor.getSelf(), /*shift=*/0); Value dinputInputAlpha = rewriter.create( @@ -2900,7 +2912,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "Only scalar constant is supported"); } - Value replace = tosa::getConstTensor(rewriter, op, 0, {}).value(); + Value replace = tosa::getConstTensor(rewriter, op, 0, {}, selfElemTy).value(); Type outType = getTypeConverter()->convertType(op.getType()); Value lesser = rewriter.create( diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 1d026a62d..0150ada19 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -174,7 +174,7 @@ std::optional getZerosLikeTensor(PatternRewriter &rewriter, // Default template creates a constant tensor in T. template std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, - ArrayRef vec, ArrayRef shape) { + ArrayRef vec, ArrayRef shape, std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; @@ -191,6 +191,11 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, auto const_op = rewriter.create(op->getLoc(), const_type, const_attr); + + if (dtype) { + return rewriter.createOrFold( + op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); + } return const_op.getResult(); } @@ -198,7 +203,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, template <> std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, ArrayRef vec, - ArrayRef shape) { + ArrayRef shape, std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; @@ -215,6 +220,11 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_op = rewriter.create(op->getLoc(), const_type, const_attr); + + if (dtype) { + return rewriter.createOrFold( + op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); + } return const_op.getResult(); } @@ -222,7 +232,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, template <> std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, ArrayRef vec, - ArrayRef shape) { + ArrayRef shape, std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; @@ -238,6 +248,11 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_op = rewriter.create(op->getLoc(), const_type, const_attr); + + if (dtype) { + return rewriter.createOrFold( + op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); + } return const_op.getResult(); } @@ -329,12 +344,14 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) { template std::optional getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, - ArrayRef shape); + ArrayRef shape, + std::optional dtype); template std::optional getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, - ArrayRef shape); + ArrayRef shape, + std::optional dtype); LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input, TypeAttr &accType) {