Add e2e testing for aten_tanh_backward op.

E2e testing for the aten_tanh_backward op has been added. The test runs against the ref_backend.
Prashant Kumar 2021-11-09 12:25:04 +00:00 committed by Yi Zhang
parent 2764e86f02
commit 909f7d7171
5 changed files with 74 additions and 1 deletion


@@ -33,3 +33,20 @@ class SoftmaxBackwardModule(torch.nn.Module):
def SoftmaxBackwardModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))


class TanhBackwardModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, out_grad, output):
        return torch.ops.aten.tanh_backward(out_grad, output)


@register_test_case(module_factory=lambda: TanhBackwardModule())
def TanhBackward_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 3), torch.randn(3, 3))
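
For intuition about what this test exercises: aten::tanh_backward takes the upstream gradient and the forward output of tanh, and returns grad_output * (1 - output^2). A minimal sketch (not part of this commit) checking that against PyTorch's autograd:

import torch

x = torch.randn(3, 3, requires_grad=True)
output = torch.tanh(x)
out_grad = torch.randn(3, 3)

# Reference gradient via autograd: out_grad * (1 - tanh(x)^2).
(via_autograd,) = torch.autograd.grad(output, x, grad_outputs=out_grad)

# Direct call to the backward op; note that it takes the output of tanh,
# not the original input x.
via_op = torch.ops.aten.tanh_backward(out_grad, output.detach())

assert torch.allclose(via_autograd, via_op)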


@@ -2830,3 +2830,18 @@ def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
  let assemblyFormat = "$grad_output `,` $output `,` $dim `,` $input_dtype attr-dict `:` type($grad_output) `,` type($output) `,` type($dim) `,` type($input_dtype) `->` type($result)";
}

def Torch_AtenTanhBackwardOp : Torch_Op<"aten.tanh_backward", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::tanh_backward : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$grad_output,
    AnyTorchTensorType:$output
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$grad_output `,` $output attr-dict `:` type($grad_output) `,` type($output) `->` type($result)";
}
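
The arguments and result above mirror the aten operator schema registered with TorchScript. One way to inspect that schema from Python — using torch._C._jit_get_schemas_for_operator, a private API, so treat this as illustrative only:

import torch

# Print the JIT-registered schema(s) the generated ODS corresponds to.
for schema in torch._C._jit_get_schemas_for_operator("aten::tanh_backward"):
    print(schema)
# Expected to include (modulo out-variant overloads):
#   aten::tanh_backward(Tensor grad_output, Tensor output) -> Tensor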


@@ -178,6 +178,43 @@ public:
};
} // namespace

// AtenTanhBackwardOp(gradOutput, output) =>
//   result = gradOutput * (1 - output^2)
// To avoid broadcasting the constant 1 against the tensor operands, the
// formula is expanded, i.e.,
//   result = gradOutput - (gradOutput * output^2)
namespace {
class DecomposeAtenTanhBackwardOp
    : public OpRewritePattern<AtenTanhBackwardOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenTanhBackwardOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value gradOutput = op.grad_output();

    // `output` is the value flowing out of tanh, i.e. tanh(x) = output.
    // Since d/dx tanh(x) = 1 - tanh(x)^2, the local gradient is 1 - output^2.
    Value output = op.output();
    BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
      return rewriter.notifyMatchFailure(op,
                                         "only floating-point type supported");

    Value tanhSquare =
        rewriter.create<AtenMulTensorOp>(loc, tensorType, output, output);
    Value gradMulTanhSquare = rewriter.create<AtenMulTensorOp>(
        loc, tensorType, tanhSquare, gradOutput);
    Value alpha =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
    Value newGrad = rewriter.create<AtenSubTensorOp>(
        loc, tensorType, gradOutput, gradMulTanhSquare, alpha);
    rewriter.replaceOp(op, newGrad);
    return success();
  }
};
} // namespace
// Decompose aten.log_softmax op into: log(softmax(x))
namespace {
class DecomposeAtenLogSoftmaxIntOp
@@ -271,6 +308,8 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<AtenSizeOp>();
    patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
    target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
    patterns.add<DecomposeAtenTanhBackwardOp>(context);
    target.addIllegalOp<AtenTanhBackwardOp>();
    patterns.add<DecomposeAtenMatmulOp>(context);
    target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
      int lhsRank = getTensorRank(op.self());
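
A quick numeric sanity check (not part of the commit) that the expanded form the pattern emits agrees with the closed form and with the aten op itself:

import torch

grad_output = torch.randn(4, 4)
output = torch.tanh(torch.randn(4, 4))

closed_form = grad_output * (1 - output ** 2)
# The pattern emits mul(output, output), mul(., grad_output), sub -- i.e.:
expanded = grad_output - (output * output) * grad_output

assert torch.allclose(closed_form, expanded)
assert torch.allclose(torch.ops.aten.tanh_backward(grad_output, output), expanded)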


@@ -230,7 +230,8 @@ public:
          AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
          AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
          AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp,
-         AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp>(op)) {
+         AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
+         AtenTanhBackwardOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
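
Joining the result's lattice element with operands[0] encodes that tanh_backward's result carries the same dtype and shape knowledge as its first operand. In eager PyTorch terms, the invariant being relied on is simply:

import torch

grad_output = torch.randn(2, 5, dtype=torch.float32)
output = torch.tanh(torch.randn(2, 5, dtype=torch.float32))
result = torch.ops.aten.tanh_backward(grad_output, output)
# The refined type can safely mirror grad_output's dtype and shape.
assert result.dtype == grad_output.dtype
assert result.shape == grad_output.shape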


@@ -624,6 +624,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    # backprop ops
    emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
    emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)")

def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
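
This emit() line records the aten schema string; re-running the ODS generator from it is what produces the generated Torch_AtenTanhBackwardOp TableGen definition shown earlier, so the schema string and the generated op must stay in sync.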