mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for tanh approximation for Gelu op (#2941)
Fixes https://github.com/nod-ai/SHARK-Turbine/issues/461

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>

parent d81747eadb
commit d628b5fd06
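For context, the feature added here is the lowering of PyTorch's GELU with the tanh approximation enabled; only approximate="none" was handled before. A minimal PyTorch-level sketch of the op in question (illustrative only, not part of the commit):

    import torch

    x = torch.randn(5, 3)
    # Exact, erf-based GELU: the variant the lowering already supported.
    y_exact = torch.nn.GELU(approximate="none")(x)
    # Tanh-approximated GELU: the variant this commit adds a lowering for.
    y_tanh = torch.nn.GELU(approximate="tanh")(x)
    # The two agree closely but not bit-for-bit.
    print((y_exact - y_tanh).abs().max())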
@@ -511,11 +511,42 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
-    // TODO: Take approximation into account.
     std::string approximate;
-    if (!matchPattern(gelu.getApproximate(), m_TorchConstantStr(approximate)) ||
-        approximate != "none")
+    if (!matchPattern(gelu.getApproximate(), m_TorchConstantStr(approximate))) {
+      gelu.emitError(
+          "unimplemented: expected approximate to be a constant str");
       return nullptr;
-    Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]);
-    return b.create<arith::MulFOp>(loc, payloadArgs[0], cdf);
+    }
+    if (approximate == "none") {
+      Value multiplier = buildUnitNormalCdf(b, loc, payloadArgs[0]);
+      return b.create<arith::MulFOp>(loc, payloadArgs[0], multiplier);
+    }
+    if (approximate == "tanh") {
+      // GELU(x)=0.5∗x∗(1+Tanh((2/π)^1/2 * (x+0.044715∗x^3)))
+      // Ref: https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
+      Value cstThree = b.create<arith::ConstantOp>(
+          loc, IntegerAttr::get(IntegerType::get(op->getContext(), 64), 3));
+      Value xCube = b.create<math::FPowIOp>(loc, payloadArgs[0], cstThree);
+      Type elementType = payloadArgs[0].getType();
+      Value cstAlpha = b.create<arith::ConstantOp>(
+          loc, FloatAttr::get(elementType, 0.044715));
+      Value xCubeMulAlpha = b.create<arith::MulFOp>(loc, xCube, cstAlpha);
+      Value xPlusXCubeMulAlpha =
+          b.create<arith::AddFOp>(loc, payloadArgs[0], xCubeMulAlpha);
+      Value cstBeta = b.create<arith::ConstantOp>(
+          loc, FloatAttr::get(elementType, 0.7977240352174656));
+      Value betaMulX =
+          b.create<arith::MulFOp>(loc, cstBeta, xPlusXCubeMulAlpha);
+      Value tanh = b.create<math::TanhOp>(loc, betaMulX);
+      Value cstOne =
+          b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
+      Value onePlusTanh = b.create<arith::AddFOp>(loc, cstOne, tanh);
+      Value cstHalf =
+          b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
+      Value multiplier = b.create<arith::MulFOp>(loc, cstHalf, onePlusTanh);
+      return b.create<arith::MulFOp>(loc, payloadArgs[0], multiplier);
+    }
+    gelu.emitError("unimplemented: approximate value should be none or tanh");
+    return nullptr;
   }
   if (auto geluBackward = dyn_cast<AtenGeluBackwardOp>(op)) {
     if (!geluBackward.getType()
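The tanh branch above emits a fixed sequence of arith/math ops computing GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), with cstBeta holding the lowering's hard-coded value for sqrt(2/pi). A minimal Python sketch (not part of the commit; the helper name is invented) that mirrors that op sequence and checks it against PyTorch's own tanh-approximate GELU:

    import torch

    def gelu_tanh_payload(x: torch.Tensor) -> torch.Tensor:
        # Mirrors the lowering step by step.
        x_cube = x ** 3                               # math.fpowi with cstThree = 3
        x_plus = x + 0.044715 * x_cube                # cstAlpha
        beta_mul_x = 0.7977240352174656 * x_plus      # cstBeta, the hard-coded sqrt(2/pi)
        one_plus_tanh = 1.0 + torch.tanh(beta_mul_x)  # cstOne, math.tanh
        return x * (0.5 * one_plus_tanh)              # cstHalf and the final multiplies

    x = torch.randn(5, 3)
    ref = torch.nn.functional.gelu(x, approximate="tanh")
    print((gelu_tanh_payload(x) - ref).abs().max())   # expected to be small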
@@ -518,6 +518,7 @@ STABLEHLO_PASS_SET = {
     "ElementwiseFloorIntModule_basic",
     "ElementwiseFloorModule_basic",
     "ElementwiseGeluModule_basic",
+    "ElementwiseGeluApproximateTanhModule_basic",
     "ElementwiseLeakyReluStaticModule_basic",
     "ElementwiseLogModule_basic",
     "ElementwiseNanToNumModule_Basic",
@@ -853,6 +853,29 @@ def ElementwiseGeluModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseGeluApproximateTanhModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.gelu = torch.nn.GELU(approximate="tanh")
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.gelu(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseGeluApproximateTanhModule())
+def ElementwiseGeluApproximateTanhModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3, low=-0.5, high=0.5))
+
+
+# ==============================================================================
+
+
 class ElementwiseSeluModule(torch.nn.Module):
 
     def __init__(self):
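The new e2e test simply pushes a random 5x3 tensor drawn from [-0.5, 0.5] through torch.nn.GELU(approximate="tanh") and lets the harness compare the compiled result against eager PyTorch. Outside the torch-mlir test framework, the same computation can be sketched in plain PyTorch (the class name below is illustrative):

    import torch

    class GeluApproximateTanh(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.gelu = torch.nn.GELU(approximate="tanh")

        def forward(self, x):
            return self.gelu(x)

    module = GeluApproximateTanh()
    # Same input distribution as tu.rand(5, 3, low=-0.5, high=0.5).
    x = torch.empty(5, 3).uniform_(-0.5, 0.5)
    print(module(x))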