From ccf924d3df2061c707f1a758e5f8c5682d15aa5d Mon Sep 17 00:00:00 2001
From: Anup Gangwar
Date: Wed, 30 Mar 2022 19:00:55 -0500
Subject: [PATCH] [tosa] Support for Aten[Gelu|GeluBackward] ops (#720)

Signed-off-by: Anup Gangwar

Co-authored-by: Anup Gangwar
---
 e2e_testing/torchscript/xfail_sets.py      |   2 +
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 162 +++++++++++++++++++++
 2 files changed, 164 insertions(+)

diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index a6895e069..d40389072 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -152,4 +152,6 @@ TOSA_PASS_SET = {
     "ViewNoChangeStaticModule_basic",
     "UnsafeViewExpandModule_basic",
     "ReshapeCollapseModule_basic",
+    "ElementwiseGeluModule_basic",
+    "GeluBackwardModule_basic",
 }
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 197b3ae45..b4b623519 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2439,6 +2439,166 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+static Value approximateErfOp(ConversionPatternRewriter &rewriter,
+                              Operation *op, Value x) {
+  // Using:
+  // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations
+  // with maximum error of 5 x 10^-4, where a1 = 0.278393, a2 = 0.230389,
+  // a3 = 0.000972, a4 = 0.078108.
+  //
+  // Erf = 1 - 1 / (1 + a1X + a2X^2 + a3X^3 + a4X^4)^4
+
+  auto outType = x.getType().cast<TensorType>();
+  auto loc = op->getLoc();
+  auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
+  auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).getValue();
+  auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).getValue();
+
+  auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}).getValue();
+  auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
+  auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);
+
+  auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}).getValue();
+  auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
+  auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
+  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);
+
+  auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}).getValue();
+  auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
+  auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
+  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);
+
+  auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}).getValue();
+  auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
+  auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
+  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
+
+  auto rcprl = rewriter.create<tosa::ReciprocalOp>(loc, outType, sum);
+  auto rcprl2 =
+      rewriter.create<tosa::MulOp>(loc, outType, rcprl, rcprl, /*shift=*/0);
+  auto rcprl4 =
+      rewriter.create<tosa::MulOp>(loc, outType, rcprl2, rcprl2, /*shift=*/0);
+  auto erf = rewriter.create<tosa::SubOp>(loc, outType, one, rcprl4);
+
+  // Deal with negative x.
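+  // erf is an odd function, erf(-x) = -erf(x); the polynomial above was
+  // evaluated on |x|, so negate the result wherever the original input is
+  // negative.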
+  auto cond = rewriter.create<tosa::GreaterEqualOp>(
+      loc,
+      RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)), x,
+      zero);
+  auto negateErf = rewriter.create<tosa::NegateOp>(loc, outType, erf);
+
+  return rewriter.create<tosa::SelectOp>(loc, outType, cond, erf, negateErf);
+}
+
+static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
+                                Operation *op, Value x) {
+  auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).getValue();
+  auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).getValue();
+  auto loc = op->getLoc();
+
+  // buildNormalCdf, mean = zero, sigma = one
+  auto outType = x.getType();
+  auto mean = zero;
+  Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
+  // rsqrt of 2
+  Value rsqrt2 =
+      tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).getValue();
+  Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
+                                              /*shift=*/0);
+  Value erf = approximateErfOp(rewriter, op, erfArg);
+  Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
+  Value oneHalf =
+      tosa::getConstTensor<float>(rewriter, op, 0.5, {}).getValue();
+  Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf,
+                                                 erfPlus1, /*shift=*/0);
+  return normalCdf;
+}
+
+// This lowering is based on the Torch to Linalg lowering.
+template <>
+LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
+    AtenGeluOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType)
+    return op.emitError("Only tensor types are currently supported");
+
+  auto selfElemTy = selfType.getElementType();
+  if (!selfElemTy.isa<mlir::FloatType>()) {
+    return op.emitError("Only floating-point datatype legalization supported");
+  }
+
+  // TODO: Handle approximate.
+  std::string approximate;
+  if (!matchPattern(op.approximate(), m_TorchConstantStr(approximate)) ||
+      approximate != "none") {
+    return op.emitError("Unsupported value of approximate");
+  }
+
+  Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.self());
+  rewriter.replaceOpWithNewOp<tosa::MulOp>(
+      op, getTypeConverter()->convertType(op.getType()), adaptor.self(), cdf,
+      /*shift=*/0);
+
+  return success();
+}
+
+// This lowering is based on the Torch to Linalg lowering.
+template <>
+LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
+    AtenGeluBackwardOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType)
+    return op.emitError("Only tensor types are currently supported");
+
+  auto selfElemTy = selfType.getElementType();
+  if (!selfElemTy.isa<mlir::FloatType>()) {
+    return op.emitError("Only floating-point datatype legalization supported");
+  }
+
+  // TODO: Handle approximate.
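+  // Only the exact (approximate == "none") formulation is lowered for now;
+  // the closed form built below is d/dx gelu(x) = cdf(x) + x * pdf(x), with
+  // pdf(x) = exp(-x^2 / 2) / sqrt(2 * pi) the unit normal density
+  // (kAlpha * oneHalf below equals 1 / sqrt(2 * pi)), scaled by grad_output.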
+  std::string approximate;
+  if (!matchPattern(op.approximate(), m_TorchConstantStr(approximate)) ||
+      approximate != "none") {
+    return op.emitError("Unsupported value of approximate");
+  }
+
+  auto loc = op->getLoc();
+
+  const double cstAlpha0 = 1.12837916709551257390;
+  const double cstAlpha1 = 0.70710678118654752440;
+  const double oneHalf = 0.5;
+  const double kAlpha = cstAlpha0 * cstAlpha1;
+
+  Value kAlphaHalf =
+      tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {})
+          .getValue();
+  Value negOneHalf =
+      tosa::getConstTensor<float>(rewriter, op, -0.5, {}).getValue();
+  Value inputSquared = rewriter.create<tosa::MulOp>(
+      loc, selfType, adaptor.self(), adaptor.self(), /*shift=*/0);
+  Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
+      loc, selfType, inputSquared, negOneHalf, /*shift=*/0);
+  Value dinput =
+      rewriter.create<tosa::ExpOp>(loc, selfType, negHalfInputSquared);
+  Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.self());
+  Value dinputInput = rewriter.create<tosa::MulOp>(loc, selfType, dinput,
+                                                   adaptor.self(), /*shift=*/0);
+  Value dinputInputAlpha = rewriter.create<tosa::MulOp>(
+      loc, selfType, dinputInput, kAlphaHalf, /*shift=*/0);
+  Value cdfExt =
+      rewriter.create<tosa::AddOp>(loc, selfType, dinputInputAlpha, cdf);
+  rewriter.replaceOpWithNewOp<tosa::MulOp>(
+      op, getTypeConverter()->convertType(op.getType()), adaptor.grad_output(),
+      cdfExt,
+      /*shift=*/0);
+
+  return success();
+}
+
 template <typename AtenOpT, typename TosaOpT>
 class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
 public:
@@ -2993,6 +3153,8 @@ public:
     INSERT_ATENOP_PATTERN(AtenContiguousOp);
     INSERT_ATENOP_PATTERN(AtenDropoutOp);
     INSERT_ATENOP_PATTERN(AtenViewOp);
+    INSERT_ATENOP_PATTERN(AtenGeluOp);
+    INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
 #undef INSERT_ATENOP_PATTERN
 
     if (failed(applyPartialConversion(getOperation(), target,
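
For reference, the following standalone snippet (not part of the patch) sanity-checks the two formulas the lowering relies on: the a1..a4 polynomial approximation of erf quoted above (documented maximum error of about 5e-4), and the closed-form derivative gelu'(x) = cdf(x) + x * pdf(x) that the backward lowering builds from TOSA ops, compared against a finite-difference derivative of gelu(x) = x * cdf(x). The helper names erfApprox, cdf, and gelu are local to this sketch; only the C++ standard library is used.

#include <cmath>
#include <cstdio>

int main() {
  const double a1 = 0.278393, a2 = 0.230389, a3 = 0.000972, a4 = 0.078108;
  const double kPi = 3.14159265358979323846;

  // Polynomial erf approximation, mirroring approximateErfOp: evaluate on
  // |x| and negate for negative inputs.
  auto erfApprox = [&](double x) {
    double ax = std::fabs(x);
    double d = 1.0 + a1 * ax + a2 * ax * ax + a3 * ax * ax * ax +
               a4 * ax * ax * ax * ax;
    double e = 1.0 - 1.0 / (d * d * d * d);
    return x < 0 ? -e : e;
  };
  // Unit normal CDF, mirroring buildUnitNormalCdf (but with std::erf).
  auto cdf = [](double x) {
    return 0.5 * (1.0 + std::erf(x * 0.70710678118654752440));
  };
  auto gelu = [&](double x) { return x * cdf(x); };

  double maxErfErr = 0.0, maxGradErr = 0.0;
  for (double x = -5.0; x <= 5.0; x += 0.01) {
    maxErfErr = std::fmax(maxErfErr, std::fabs(erfApprox(x) - std::erf(x)));
    double pdf = std::exp(-0.5 * x * x) / std::sqrt(2.0 * kPi);
    double analytic = cdf(x) + x * pdf;  // closed form used by the lowering
    double numeric = (gelu(x + 1e-5) - gelu(x - 1e-5)) / 2e-5;  // central diff
    maxGradErr = std::fmax(maxGradErr, std::fabs(analytic - numeric));
  }
  // The first value should stay within the documented 5e-4 bound; the second
  // should be near zero (limited only by the finite-difference step).
  std::printf("max |erfApprox - std::erf|     = %g\n", maxErfErr);
  std::printf("max |analytic - numeric gelu'| = %g\n", maxGradErr);
  return 0;
}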