[tosa] Support for Aten[Gelu|GeluBackward] ops (#720)

Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>

Co-authored-by: Anup Gangwar <anup.gangwar@arm.com>
pull/726/head snapshot-20220331.360
Anup Gangwar 2022-03-30 19:00:55 -05:00 committed by GitHub
parent c17c0a6ba2
commit ccf924d3df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 164 additions and 0 deletions

View File

@ -152,4 +152,6 @@ TOSA_PASS_SET = {
"ViewNoChangeStaticModule_basic", "ViewNoChangeStaticModule_basic",
"UnsafeViewExpandModule_basic", "UnsafeViewExpandModule_basic",
"ReshapeCollapseModule_basic", "ReshapeCollapseModule_basic",
"ElementwiseGeluModule_basic",
"GeluBackwardModule_basic",
} }

View File

@ -2439,6 +2439,166 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
return success(); return success();
} }
// Approximates erf(x) using the polynomial approximation from
// https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with
// maximum error 5 x 10^-4:
//   erf(x) ~= 1 - 1 / (1 + a1*x + a2*x^2 + a3*x^3 + a4*x^4)^4,  for x >= 0
// where a1 = 0.278393, a2 = 0.230389, a3 = 0.000972, a4 = 0.078108.
// Negative inputs are handled via the odd symmetry erf(-x) = -erf(x).
static Value approximateErfOp(ConversionPatternRewriter &rewriter,
                              Operation *op, Value x) {
  auto outType = x.getType().cast<TensorType>();
  auto loc = op->getLoc();
  // Evaluate the polynomial on |x|; the sign is restored at the end.
  auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
  auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).getValue();
  auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).getValue();

  // sum = 1 + a1*|x|
  auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}).getValue();
  auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
  auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);

  // sum += a2*|x|^2
  auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}).getValue();
  auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
  auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);

  // sum += a3*|x|^3
  auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}).getValue();
  auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
  auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);

  // sum += a4*|x|^4
  auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}).getValue();
  auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
  auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);

  // erf(|x|) = 1 - 1/sum^4, with sum^-4 built by reciprocal + two squarings.
  auto rcprl = rewriter.create<tosa::ReciprocalOp>(loc, outType, sum);
  auto rcprl2 =
      rewriter.create<tosa::MulOp>(loc, outType, rcprl, rcprl, /*shift=*/0);
  auto rcprl4 =
      rewriter.create<tosa::MulOp>(loc, outType, rcprl2, rcprl2, /*shift=*/0);
  auto erf = rewriter.create<tosa::SubOp>(loc, outType, one, rcprl4);

  // Deal with negative x: select erf(|x|) for x >= 0, -erf(|x|) otherwise.
  auto cond = rewriter.create<tosa::GreaterEqualOp>(
      loc,
      RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)), x,
      zero);
  auto negateErf = rewriter.create<tosa::NegateOp>(loc, outType, erf);

  return rewriter.create<tosa::SelectOp>(loc, outType, cond, erf, negateErf);
}
// Builds the CDF of the standard normal distribution:
//   cdf(x) = 0.5 * (1 + erf((x - mean) / sqrt(2)))   with mean = 0, sigma = 1.
static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
                                Operation *op, Value x) {
  auto mean = tosa::getConstTensor<float>(rewriter, op, 0, {}).getValue();
  auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).getValue();
  auto loc = op->getLoc();
  auto resultTy = x.getType();

  // Center the input (the mean of the unit normal is zero).
  Value centered = rewriter.create<tosa::SubOp>(loc, resultTy, x, mean);
  // Scale by 1/sqrt(2) before feeding the erf approximation.
  Value invSqrt2 =
      tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).getValue();
  Value scaled = rewriter.create<tosa::MulOp>(loc, resultTy, centered, invSqrt2,
                                              /*shift=*/0);
  Value erfVal = approximateErfOp(rewriter, op, scaled);
  Value onePlusErf = rewriter.create<tosa::AddOp>(loc, resultTy, one, erfVal);
  Value half = tosa::getConstTensor<float>(rewriter, op, 0.5, {}).getValue();
  return rewriter.create<tosa::MulOp>(loc, resultTy, half, onePlusErf,
                                      /*shift=*/0);
}
// This lowering is based on the Torch to LinAlg lowering:
// gelu(x) = x * cdf(x), where cdf is the unit-normal CDF.
template <>
LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
    AtenGeluOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Reject non-tensor inputs.
  auto selfTy = adaptor.self().getType().dyn_cast<TensorType>();
  if (!selfTy)
    return op.emitError("Only tensor types are currently supported");

  // Only float element types are legalized.
  if (!selfTy.getElementType().isa<mlir::FloatType>())
    return op.emitError("Only floating-point datatype legalization supported");

  // TODO: Handle approximate.
  std::string approximate;
  if (!matchPattern(op.approximate(), m_TorchConstantStr(approximate)) ||
      approximate != "none")
    return op.emitError("Unsupported value of approximate");

  Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.self());
  rewriter.replaceOpWithNewOp<tosa::MulOp>(
      op, getTypeConverter()->convertType(op.getType()), adaptor.self(), cdf,
      /*shift=*/0);

  return success();
}
// This lowering is based on Torch to LinAlg lowering.
// Computes the GELU gradient:
//   grad_input = grad_output * (cdf(x) + x * pdf(x))
// where cdf is the unit-normal CDF and
//   pdf(x) = exp(-x^2 / 2) / sqrt(2*pi) = exp(-x^2 / 2) * kAlpha * 0.5.
template <>
LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
    AtenGeluBackwardOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Not a tensor type.
  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
  if (!selfType)
    return op.emitError("Only tensor types are currently supported");

  auto selfElemTy = selfType.getElementType();
  if (!selfElemTy.isa<mlir::FloatType>()) {
    return op.emitError("Only floating-point datatype legalization supported");
  }

  // TODO: Handle approximate.
  std::string approximate;
  if (!matchPattern(op.approximate(), m_TorchConstantStr(approximate)) ||
      approximate != "none") {
    return op.emitError("Unsupported value of approximate");
  }

  auto loc = op->getLoc();

  // cstAlpha0 = 2/sqrt(pi), cstAlpha1 = 1/sqrt(2);
  // kAlpha = cstAlpha0 * cstAlpha1 = sqrt(2/pi).
  const double cstAlpha0 = 1.12837916709551257390;
  const double cstAlpha1 = 0.70710678118654752440;
  const double oneHalf = 0.5;
  const double kAlpha = cstAlpha0 * cstAlpha1;

  // kAlpha * 0.5 = 1/sqrt(2*pi): the unit-normal pdf normalization factor.
  Value kAlphaHalf =
      tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {})
          .getValue();
  Value negOneHalf =
      tosa::getConstTensor<float>(rewriter, op, -0.5, {}).getValue();
  // dinput = exp(-x^2 / 2), the unnormalized pdf.
  Value inputSquared = rewriter.create<tosa::MulOp>(
      loc, selfType, adaptor.self(), adaptor.self(), /*shift=*/0);
  Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
      loc, selfType, inputSquared, negOneHalf, /*shift=*/0);
  Value dinput =
      rewriter.create<tosa::ExpOp>(loc, selfType, negHalfInputSquared);
  Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.self());
  // dinputInputAlpha = x * pdf(x).
  Value dinputInput = rewriter.create<tosa::MulOp>(loc, selfType, dinput,
                                                   adaptor.self(), /*shift=*/0);
  Value dinputInputAlpha = rewriter.create<tosa::MulOp>(
      loc, selfType, dinputInput, kAlphaHalf, /*shift=*/0);
  // cdfExt = cdf(x) + x * pdf(x), the derivative of gelu.
  Value cdfExt =
      rewriter.create<tosa::AddOp>(loc, selfType, dinputInputAlpha, cdf);
  // grad_input = grad_output * cdfExt.
  rewriter.replaceOpWithNewOp<tosa::MulOp>(
      op, getTypeConverter()->convertType(op.getType()), adaptor.grad_output(),
      cdfExt,
      /*shift=*/0);

  return success();
}
template <typename AtenOpT, typename TosaOpT> template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> { class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
public: public:
@ -2993,6 +3153,8 @@ public:
INSERT_ATENOP_PATTERN(AtenContiguousOp); INSERT_ATENOP_PATTERN(AtenContiguousOp);
INSERT_ATENOP_PATTERN(AtenDropoutOp); INSERT_ATENOP_PATTERN(AtenDropoutOp);
INSERT_ATENOP_PATTERN(AtenViewOp); INSERT_ATENOP_PATTERN(AtenViewOp);
INSERT_ATENOP_PATTERN(AtenGeluOp);
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,