[TORCH][MLIR] Add E2E support for `aten.gelu_backward` operation. (#418)

This commit adds new operation `aten.gelu_backward` in the aten
dialect and adds lowering of this operation from aten to linalg.

Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>
pull/426/head snapshot-20211117.89
Prateek Gupta 2021-11-17 14:59:38 +05:30 committed by GitHub
parent 0fe70994e5
commit ecf78b9849
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 86 additions and 15 deletions

View File

@ -53,3 +53,26 @@ class TanhBackwardModule(torch.nn.Module):
def TanhBackward_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3), torch.randn(3, 3))
# ==============================================================================
class GeluBackwardModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.gelu = torch.nn.GELU()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, grad, input):
return torch.ops.aten.gelu_backward(grad, input)
@register_test_case(module_factory=lambda: GeluBackwardModule())
def GeluBackwardModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3), tu.rand(5, 3))

View File

@ -2845,3 +2845,18 @@ def Torch_AtenTanhBackwardOp : Torch_Op<"aten.tanh_backward", [
let assemblyFormat = "$grad_output `,` $output attr-dict `:` type($grad_output) `,` type($output) `->` type($result)";
}
def Torch_AtenGeluBackwardOp : Torch_Op<"aten.gelu_backward", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::gelu_backward : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad,
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$grad `,` $self attr-dict `:` type($grad) `,` type($self) `->` type($result)";
}

View File

@ -1353,6 +1353,37 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]);
return b.create<arith::MulFOp>(loc, payloadArgs[0], cdf);
}
if (auto geluBackward = dyn_cast<AtenGeluBackwardOp>(op)) {
if (!geluBackward.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
geluBackward.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Type elementType = payloadArgs[1].getType();
Value constant0 = b.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 1.12837916709551257390));
Value constant1 = b.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.70710678118654752440));
Value oneHalf =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
Value kAlpha = b.create<arith::MulFOp>(loc, constant0, constant1);
Value kAlphaHalf = b.create<arith::MulFOp>(loc, kAlpha, oneHalf);
Value negOneHalf =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, -0.5));
Value inputSquared =
b.create<arith::MulFOp>(loc, payloadArgs[1], payloadArgs[1]);
Value negHalfInputSquared =
b.create<arith::MulFOp>(loc, inputSquared, negOneHalf);
Value dinput = b.create<math::ExpOp>(loc, negHalfInputSquared);
Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[1]);
Value dinputInput = b.create<arith::MulFOp>(loc, dinput, payloadArgs[1]);
Value dinputInputAlpha =
b.create<arith::MulFOp>(loc, dinputInput, kAlphaHalf);
Value cdfExt = b.create<arith::AddFOp>(loc, dinputInputAlpha, cdf);
return b.create<arith::MulFOp>(loc, payloadArgs[0], cdfExt);
}
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
AtenAddTensorOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(add.getType())
@ -1716,8 +1747,8 @@ struct ConvertElementwiseOp : ConversionPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp,
@ -2871,12 +2902,12 @@ public:
patterns.add<ConvertAtenLinearOp>(typeConverter, context);
target.addIllegalOp<AtenBatchNormOp>();
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp>();
target.addIllegalOp<
AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenUnsqueezeOp>();
patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);

View File

@ -224,13 +224,14 @@ public:
visitOperation(Operation *op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp, AtenEqScalarOp,
AtenGeScalarOp, AtenGtScalarOp, AtenNeScalarOp, AtenBitwiseNotOp,
AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp,
AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp,
AtenGeluBackwardOp, AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp,
AtenNeScalarOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp, AtenCosOp,
AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp,
AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
AtenTanhBackwardOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]);
}

View File

@ -625,6 +625,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
# backprop ops
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::gelu_backward : (Tensor, Tensor) -> (Tensor)")
def emit_quantized_ops(torch_ir_dir: str, registry: Registry):