[MLIR][TORCH] Add support for tanh approximation for Gelu op (#2941)

Fixes https://github.com/nod-ai/SHARK-Turbine/issues/461

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
Vivek Khandelwal 2024-02-27 19:26:01 +05:30 committed by GitHub
parent d81747eadb
commit d628b5fd06
3 changed files with 59 additions and 4 deletions

@@ -511,11 +511,42 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
-    // TODO: Take approximation into account.
     std::string approximate;
-    if (!matchPattern(gelu.getApproximate(), m_TorchConstantStr(approximate)) ||
-        approximate != "none")
+    if (!matchPattern(gelu.getApproximate(), m_TorchConstantStr(approximate))) {
+      gelu.emitError(
+          "unimplemented: expected approximate to be a constant str");
       return nullptr;
-    Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]);
-    return b.create<arith::MulFOp>(loc, payloadArgs[0], cdf);
+    }
+    if (approximate == "none") {
+      Value multiplier = buildUnitNormalCdf(b, loc, payloadArgs[0]);
+      return b.create<arith::MulFOp>(loc, payloadArgs[0], multiplier);
+    }
+    if (approximate == "tanh") {
+      // GELU(x) = 0.5 * x * (1 + Tanh((2/π)^(1/2) * (x + 0.044715 * x^3)))
+      // Ref: https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
+      Value cstThree = b.create<arith::ConstantOp>(
+          loc, IntegerAttr::get(IntegerType::get(op->getContext(), 64), 3));
+      Value xCube = b.create<math::FPowIOp>(loc, payloadArgs[0], cstThree);
+      Type elementType = payloadArgs[0].getType();
+      Value cstAlpha = b.create<arith::ConstantOp>(
+          loc, FloatAttr::get(elementType, 0.044715));
+      Value xCubeMulAlpha = b.create<arith::MulFOp>(loc, xCube, cstAlpha);
+      Value xPlusXCubeMulAlpha =
+          b.create<arith::AddFOp>(loc, payloadArgs[0], xCubeMulAlpha);
+      Value cstBeta = b.create<arith::ConstantOp>(
+          loc, FloatAttr::get(elementType, 0.7977240352174656));
+      Value betaMulX =
+          b.create<arith::MulFOp>(loc, cstBeta, xPlusXCubeMulAlpha);
+      Value tanh = b.create<math::TanhOp>(loc, betaMulX);
+      Value cstOne =
+          b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
+      Value onePlusTanh = b.create<arith::AddFOp>(loc, cstOne, tanh);
+      Value cstHalf =
+          b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
+      Value multiplier = b.create<arith::MulFOp>(loc, cstHalf, onePlusTanh);
+      return b.create<arith::MulFOp>(loc, payloadArgs[0], multiplier);
+    }
+    gelu.emitError("unimplemented: approximate value should be none or tanh");
+    return nullptr;
   }
   if (auto geluBackward = dyn_cast<AtenGeluBackwardOp>(op)) {
     if (!geluBackward.getType()
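
For reference, a quick eager-mode check of the expression the "tanh" branch lowers to. This is a minimal sketch, not part of the commit: it assumes only PyTorch is installed, the sample input and tolerance are illustrative, and the helper name gelu_tanh_reference is made up here. It mirrors the op chain above (x^3, *0.044715, +x, *sqrt(2/π), tanh, +1, *0.5, *x) and compares against torch.nn.functional.gelu with approximate="tanh"; note the lowering hard-codes 0.7977240352174656, a constant close to sqrt(2/π).

import math
import torch

def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
    # Hand-expanded tanh approximation of GELU, step for step as in the lowering.
    beta = math.sqrt(2.0 / math.pi)
    inner = beta * (x + 0.044715 * x.pow(3))
    return 0.5 * x * (1.0 + torch.tanh(inner))

x = torch.randn(5, 3)
expected = torch.nn.functional.gelu(x, approximate="tanh")
assert torch.allclose(gelu_tanh_reference(x), expected, atol=1e-6)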

@@ -518,6 +518,7 @@ STABLEHLO_PASS_SET = {
     "ElementwiseFloorIntModule_basic",
     "ElementwiseFloorModule_basic",
     "ElementwiseGeluModule_basic",
+    "ElementwiseGeluApproximateTanhModule_basic",
     "ElementwiseLeakyReluStaticModule_basic",
     "ElementwiseLogModule_basic",
     "ElementwiseNanToNumModule_Basic",

@@ -853,6 +853,29 @@ def ElementwiseGeluModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class ElementwiseGeluApproximateTanhModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gelu = torch.nn.GELU(approximate="tanh")
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.gelu(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGeluApproximateTanhModule())
+def ElementwiseGeluApproximateTanhModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3, low=-0.5, high=0.5))
+
+
+# ==============================================================================
 class ElementwiseSeluModule(torch.nn.Module):
     def __init__(self):
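
As a rough stand-in for what the new e2e case exercises: the harness runs the module eagerly on the generated input to produce reference values and compares them with the compiled backend's output. The sketch below only shows the eager side, assuming plain PyTorch; the uniform_-based input is an illustrative substitute for tu.rand(5, 3, low=-0.5, high=0.5).

import torch

# Same module as the test, run eagerly; the compiled result must match this.
gelu = torch.nn.GELU(approximate="tanh")
x = torch.empty(5, 3).uniform_(-0.5, 0.5)
print(gelu(x))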