mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add E2E support for `aten.gelu_backward` operation. (#418)
This commit adds new operation `aten.gelu_backward` in the aten dialect and adds lowering of this operation from aten to linalg. Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>pull/426/head snapshot-20211117.89
parent
0fe70994e5
commit
ecf78b9849
|
@ -53,3 +53,26 @@ class TanhBackwardModule(torch.nn.Module):
|
|||
def TanhBackward_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3), torch.randn(3, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class GeluBackwardModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gelu = torch.nn.GELU()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.gelu_backward(grad, input)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: GeluBackwardModule())
|
||||
def GeluBackwardModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 3), tu.rand(5, 3))
|
||||
|
||||
|
||||
|
|
|
@ -2845,3 +2845,18 @@ def Torch_AtenTanhBackwardOp : Torch_Op<"aten.tanh_backward", [
|
|||
let assemblyFormat = "$grad_output `,` $output attr-dict `:` type($grad_output) `,` type($output) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenGeluBackwardOp : Torch_Op<"aten.gelu_backward", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::gelu_backward : (Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$grad,
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$grad `,` $self attr-dict `:` type($grad) `,` type($self) `->` type($result)";
|
||||
}
|
||||
|
||||
|
|
|
@ -1353,6 +1353,37 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]);
|
||||
return b.create<arith::MulFOp>(loc, payloadArgs[0], cdf);
|
||||
}
|
||||
if (auto geluBackward = dyn_cast<AtenGeluBackwardOp>(op)) {
|
||||
if (!geluBackward.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
geluBackward.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Type elementType = payloadArgs[1].getType();
|
||||
Value constant0 = b.create<arith::ConstantOp>(
|
||||
loc, FloatAttr::get(elementType, 1.12837916709551257390));
|
||||
Value constant1 = b.create<arith::ConstantOp>(
|
||||
loc, FloatAttr::get(elementType, 0.70710678118654752440));
|
||||
Value oneHalf =
|
||||
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
|
||||
Value kAlpha = b.create<arith::MulFOp>(loc, constant0, constant1);
|
||||
Value kAlphaHalf = b.create<arith::MulFOp>(loc, kAlpha, oneHalf);
|
||||
Value negOneHalf =
|
||||
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, -0.5));
|
||||
Value inputSquared =
|
||||
b.create<arith::MulFOp>(loc, payloadArgs[1], payloadArgs[1]);
|
||||
Value negHalfInputSquared =
|
||||
b.create<arith::MulFOp>(loc, inputSquared, negOneHalf);
|
||||
Value dinput = b.create<math::ExpOp>(loc, negHalfInputSquared);
|
||||
Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[1]);
|
||||
Value dinputInput = b.create<arith::MulFOp>(loc, dinput, payloadArgs[1]);
|
||||
Value dinputInputAlpha =
|
||||
b.create<arith::MulFOp>(loc, dinputInput, kAlphaHalf);
|
||||
Value cdfExt = b.create<arith::AddFOp>(loc, dinputInputAlpha, cdf);
|
||||
return b.create<arith::MulFOp>(loc, payloadArgs[0], cdfExt);
|
||||
}
|
||||
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
|
||||
AtenAddTensorOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(add.getType())
|
||||
|
@ -1716,8 +1747,8 @@ struct ConvertElementwiseOp : ConversionPattern {
|
|||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
|
||||
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
|
||||
if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
|
||||
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
|
||||
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
|
||||
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
|
||||
AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp,
|
||||
|
@ -2871,12 +2902,12 @@ public:
|
|||
patterns.add<ConvertAtenLinearOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenBatchNormOp>();
|
||||
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
|
||||
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
|
||||
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
|
||||
AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
||||
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
|
||||
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp>();
|
||||
target.addIllegalOp<
|
||||
AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
|
||||
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
|
||||
AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
||||
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
|
||||
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenUnsqueezeOp>();
|
||||
patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
|
||||
|
|
|
@ -224,13 +224,14 @@ public:
|
|||
visitOperation(Operation *op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) final {
|
||||
if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
|
||||
AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp, AtenEqScalarOp,
|
||||
AtenGeScalarOp, AtenGtScalarOp, AtenNeScalarOp, AtenBitwiseNotOp,
|
||||
AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
|
||||
AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
|
||||
AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
|
||||
AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp,
|
||||
AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
|
||||
AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp,
|
||||
AtenGeluBackwardOp, AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp,
|
||||
AtenNeScalarOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp, AtenCosOp,
|
||||
AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
|
||||
AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp,
|
||||
AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
|
||||
AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
|
||||
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
|
||||
AtenTanhBackwardOp>(op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
|
|
|
@ -625,6 +625,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
# backprop ops
|
||||
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::gelu_backward : (Tensor, Tensor) -> (Tensor)")
|
||||
|
||||
|
||||
def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
|
||||
|
|
Loading…
Reference in New Issue