mirror of https://github.com/llvm/torch-mlir
[tosa] Support for Aten[Gelu|GeluBackward] ops (#720)
Signed-off-by: Anup Gangwar <anup.gangwar@arm.com> Co-authored-by: Anup Gangwar <anup.gangwar@arm.com>pull/726/head snapshot-20220331.360
parent
c17c0a6ba2
commit
ccf924d3df
|
@ -152,4 +152,6 @@ TOSA_PASS_SET = {
|
|||
"ViewNoChangeStaticModule_basic",
|
||||
"UnsafeViewExpandModule_basic",
|
||||
"ReshapeCollapseModule_basic",
|
||||
"ElementwiseGeluModule_basic",
|
||||
"GeluBackwardModule_basic",
|
||||
}
|
||||
|
|
|
@ -2439,6 +2439,166 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Approximates erf(x) elementwise with the polynomial approximation from
// https://en.wikipedia.org/wiki/Error_function#Numerical_approximations
// (maximum error 5 x 10^-4):
//
//   erf(x) ~= 1 - 1 / (1 + a1*x + a2*x^2 + a3*x^3 + a4*x^4)^4,  for x >= 0,
//
// where a1 = 0.278393, a2 = 0.230389, a3 = 0.000972, a4 = 0.078108.
// erf is odd, so the approximation is evaluated on |x| and the sign is
// restored with a select at the end.
static Value approximateErfOp(ConversionPatternRewriter &rewriter,
                              Operation *op, Value x) {
  auto outType = x.getType().cast<TensorType>();
  auto loc = op->getLoc();
  auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
  auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).getValue();
  auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}).getValue();

  // sum = 1 + a1*|x|
  auto a1 = tosa::getConstTensor<float>(rewriter, op, 0.278393, {}).getValue();
  auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
  auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);

  // sum += a2*|x|^2
  auto a2 = tosa::getConstTensor<float>(rewriter, op, 0.230389, {}).getValue();
  auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
  auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);

  // sum += a3*|x|^3
  auto a3 = tosa::getConstTensor<float>(rewriter, op, 0.000972, {}).getValue();
  auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
  auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);

  // sum += a4*|x|^4
  auto a4 = tosa::getConstTensor<float>(rewriter, op, 0.078108, {}).getValue();
  auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
  auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);

  // erf(|x|) = 1 - (1/sum)^4, computed via reciprocal and two squarings.
  auto rcprl = rewriter.create<tosa::ReciprocalOp>(loc, outType, sum);
  auto rcprl2 =
      rewriter.create<tosa::MulOp>(loc, outType, rcprl, rcprl, /*shift=*/0);
  auto rcprl4 =
      rewriter.create<tosa::MulOp>(loc, outType, rcprl2, rcprl2, /*shift=*/0);
  auto erf = rewriter.create<tosa::SubOp>(loc, outType, one, rcprl4);

  // Deal with negative x: erf(-x) = -erf(x), so select the negated result
  // wherever x < 0.
  auto cond = rewriter.create<tosa::GreaterEqualOp>(
      loc,
      RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)), x,
      zero);
  auto negateErf = rewriter.create<tosa::NegateOp>(loc, outType, erf);

  return rewriter.create<tosa::SelectOp>(loc, outType, cond, erf, negateErf);
}
|
||||
|
||||
// Builds Phi(x), the CDF of the standard (unit) normal distribution:
//   Phi(x) = 0.5 * (1 + erf((x - mean) / sqrt(2)))  with mean = 0, sigma = 1.
// The erf itself comes from the polynomial approximation above.
static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
                                Operation *op, Value x) {
  Value meanVal = tosa::getConstTensor<float>(rewriter, op, 0, {}).getValue();
  Value oneVal = tosa::getConstTensor<float>(rewriter, op, 1, {}).getValue();
  auto loc = op->getLoc();
  auto resultTy = x.getType();

  // Center on the (zero) mean so the helper mirrors a general normal CDF.
  Value centered = rewriter.create<tosa::SubOp>(loc, resultTy, x, meanVal);
  // 1/sqrt(2) scales the erf argument.
  Value invSqrt2 =
      tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).getValue();
  Value erfInput = rewriter.create<tosa::MulOp>(loc, resultTy, centered,
                                                invSqrt2, /*shift=*/0);
  Value erfVal = approximateErfOp(rewriter, op, erfInput);
  Value onePlusErf = rewriter.create<tosa::AddOp>(loc, resultTy, oneVal, erfVal);
  Value half = tosa::getConstTensor<float>(rewriter, op, 0.5, {}).getValue();
  return rewriter.create<tosa::MulOp>(loc, resultTy, half, onePlusErf,
                                      /*shift=*/0);
}
|
||||
|
||||
// This lowering is based on Torch to LinAlg lowering.
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
||||
AtenGeluOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
|
||||
if (!selfType)
|
||||
return op.emitError("Only tensor types are currently supported");
|
||||
|
||||
auto selfElemTy = selfType.getElementType();
|
||||
if (!selfElemTy.isa<mlir::FloatType>()) {
|
||||
return op.emitError("Only floating-point datatype legalization supported");
|
||||
}
|
||||
|
||||
// TODO: Handle approximate.
|
||||
std::string approximate;
|
||||
if (!matchPattern(op.approximate(), m_TorchConstantStr(approximate)) ||
|
||||
approximate != "none") {
|
||||
return op.emitError("Unsupported value of approximate");
|
||||
}
|
||||
|
||||
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.self());
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.self(), cdf,
|
||||
/*shift=*/0);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// This lowering is based on Torch to LinAlg lowering.
// Lowers aten.gelu_backward (approximate == "none"). The gradient of
// gelu(x) = x * Phi(x) is:
//   d/dx gelu(x) = Phi(x) + x * pdf(x),   pdf(x) = exp(-x^2/2) / sqrt(2*pi)
// and the result is grad_output * (Phi(x) + x * pdf(x)).
template <>
LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
    AtenGeluBackwardOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Not a tensor type.
  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
  if (!selfType)
    return op.emitError("Only tensor types are currently supported");

  // Only floating-point element types are legalized.
  auto selfElemTy = selfType.getElementType();
  if (!selfElemTy.isa<mlir::FloatType>()) {
    return op.emitError("Only floating-point datatype legalization supported");
  }

  // TODO: Handle approximate.
  std::string approximate;
  if (!matchPattern(op.approximate(), m_TorchConstantStr(approximate)) ||
      approximate != "none") {
    return op.emitError("Unsupported value of approximate");
  }

  auto loc = op->getLoc();

  // cstAlpha0 = 2/sqrt(pi), cstAlpha1 = 1/sqrt(2); thus
  // kAlpha * oneHalf = 1/sqrt(2*pi), the normal pdf's normalization constant.
  const double cstAlpha0 = 1.12837916709551257390;
  const double cstAlpha1 = 0.70710678118654752440;
  const double oneHalf = 0.5;
  const double kAlpha = cstAlpha0 * cstAlpha1;

  Value kAlphaHalf =
      tosa::getConstTensor<float>(rewriter, op, kAlpha * oneHalf, {})
          .getValue();
  Value negOneHalf =
      tosa::getConstTensor<float>(rewriter, op, -0.5, {}).getValue();
  // dinput = exp(-x^2/2): the unnormalized unit normal pdf.
  Value inputSquared = rewriter.create<tosa::MulOp>(
      loc, selfType, adaptor.self(), adaptor.self(), /*shift=*/0);
  Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
      loc, selfType, inputSquared, negOneHalf, /*shift=*/0);
  Value dinput =
      rewriter.create<tosa::ExpOp>(loc, selfType, negHalfInputSquared);
  Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.self());
  // x * exp(-x^2/2), then scaled by 1/sqrt(2*pi) to give x * pdf(x).
  Value dinputInput = rewriter.create<tosa::MulOp>(loc, selfType, dinput,
                                                   adaptor.self(), /*shift=*/0);
  Value dinputInputAlpha = rewriter.create<tosa::MulOp>(
      loc, selfType, dinputInput, kAlphaHalf, /*shift=*/0);
  // cdfExt = Phi(x) + x * pdf(x)
  Value cdfExt =
      rewriter.create<tosa::AddOp>(loc, selfType, dinputInputAlpha, cdf);
  // grad_input = grad_output * cdfExt
  rewriter.replaceOpWithNewOp<tosa::MulOp>(
      op, getTypeConverter()->convertType(op.getType()), adaptor.grad_output(),
      cdfExt,
      /*shift=*/0);

  return success();
}
|
||||
|
||||
template <typename AtenOpT, typename TosaOpT>
|
||||
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
|
@ -2993,6 +3153,8 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenContiguousOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDropoutOp);
|
||||
INSERT_ATENOP_PATTERN(AtenViewOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGeluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
|
|
Loading…
Reference in New Issue