mirror of https://github.com/llvm/torch-mlir
Add e2e testing for aten_tanh_backward op.
E2e testing for the aten_tanh_backward op has been added. The testing is done for ref_backend.
pull/411/head snapshot-20211109.74
parent 2764e86f02
commit 909f7d7171
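For context (added here, not part of the original commit message), the decomposition introduced below rests on the derivative of tanh. A short derivation in LaTeX:

    \frac{d}{dx}\tanh(x) = 1 - \tanh(x)^2, \qquad \text{output} = \tanh(x)
    \Rightarrow\ \text{grad\_input} = \text{grad\_output}\,(1 - \text{output}^2)
                = \text{grad\_output} - \text{grad\_output}\cdot\text{output}^2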
@@ -33,3 +33,20 @@ class SoftmaxBackwardModule(torch.nn.Module):
 def SoftmaxBackwardModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))
 
+
+class TanhBackwardModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, out_grad, output):
+        return torch.ops.aten.tanh_backward(out_grad, output)
+
+@register_test_case(module_factory=lambda: TanhBackwardModule())
+def TanhBackward_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3), torch.randn(3, 3))
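The e2e harness compares the module's results on the reference backend against eager PyTorch. As an illustrative eager-mode sketch of the property being exercised (not part of the commit), aten.tanh_backward should match autograd's gradient through tanh:

import torch

# Gradient of tanh computed two ways (illustrative check only).
x = torch.randn(3, 3, requires_grad=True)
output = torch.tanh(x)
out_grad = torch.randn(3, 3)

# 1) Via autograd, back-propagating out_grad through tanh.
(autograd_grad,) = torch.autograd.grad(output, x, grad_outputs=out_grad)

# 2) Via the backward op under test.
direct_grad = torch.ops.aten.tanh_backward(out_grad, output)

assert torch.allclose(autograd_grad, direct_grad)
assert torch.allclose(direct_grad, out_grad * (1 - output ** 2))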
@@ -2830,3 +2830,18 @@ def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
   let assemblyFormat = "$grad_output `,` $output `,` $dim `,` $input_dtype attr-dict `:` type($grad_output) `,` type($output) `,` type($dim) `,` type($input_dtype) `->` type($result)";
 }
 
+def Torch_AtenTanhBackwardOp : Torch_Op<"aten.tanh_backward", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::tanh_backward : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_output,
+    AnyTorchTensorType:$output
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$grad_output `,` $output attr-dict `:` type($grad_output) `,` type($output) `->` type($result)";
+}
+
@@ -178,6 +178,43 @@ public:
 };
 } // namespace
 
+// AtenTanhBackwardOp(gradOutput, output) =>
+// result = gradOutput * (1 - output^2)
+// To get away from broadcasts the above formula is expanded i.e.,
+// result = gradOutput - (gradOutput * output^2)
+namespace {
+class DecomposeAtenTanhBackwardOp
+    : public OpRewritePattern<AtenTanhBackwardOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenTanhBackwardOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value gradOutput = op.grad_output();
+
+    // `output` is the value flowing out from tanh. Hence, tanh(x) = output.
+    // Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
+    Value output = op.output();
+
+    BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
+    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
+      return rewriter.notifyMatchFailure(op, "Only support floating type");
+
+    Value tanhSquare =
+        rewriter.create<AtenMulTensorOp>(loc, tensorType, output, output);
+    Value gradMulTanhSquare = rewriter.create<AtenMulTensorOp>(
+        loc, tensorType, tanhSquare, gradOutput);
+
+    Value alpha =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
+    Value newGrad = rewriter.create<AtenSubTensorOp>(
+        loc, tensorType, gradOutput, gradMulTanhSquare, alpha);
+    rewriter.replaceOp(op, newGrad);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.log_softmax op into: log(softmax(x))
 namespace {
 class DecomposeAtenLogSoftmaxIntOp
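As a quick numerical sanity check of the expanded formula used by the pattern (an eager-mode sketch, not part of the commit), the two multiplies plus the alpha=1 subtraction reproduce the op itself:

import torch

# Eager-mode mirror of the rewrite above:
#   tanhSquare        = output * output
#   gradMulTanhSquare = tanhSquare * gradOutput
#   newGrad           = gradOutput - 1 * gradMulTanhSquare
grad_output = torch.randn(3, 3)
output = torch.tanh(torch.randn(3, 3))

tanh_square = output * output
grad_mul_tanh_square = tanh_square * grad_output
new_grad = torch.sub(grad_output, grad_mul_tanh_square, alpha=1)

assert torch.allclose(new_grad, torch.ops.aten.tanh_backward(grad_output, output))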
@@ -271,6 +308,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<AtenSizeOp>();
     patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
     target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
+    patterns.add<DecomposeAtenTanhBackwardOp>(context);
+    target.addIllegalOp<AtenTanhBackwardOp>();
     patterns.add<DecomposeAtenMatmulOp>(context);
     target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
       int lhsRank = getTensorRank(op.self());
@@ -230,7 +230,8 @@ public:
       AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
       AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
       AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp,
-      AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp>(op)) {
+      AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
+      AtenTanhBackwardOp>(op)) {
     return getLatticeElement(op->getResult(0)).join(*operands[0]);
   }
@@ -624,6 +624,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
 
     # backprop ops
     emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
+    emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)")
 
 
 def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
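The schema string passed to emit() must match the operator's registered ATen signature. One way to inspect that signature from Python (an illustrative sketch using a private PyTorch API, assumed available in a stock build; not part of the commit):

import torch

# Expected to print something like:
#   aten::tanh_backward(Tensor grad_output, Tensor output) -> Tensor
# which is what the emit(...) line above transcribes.
for schema in torch._C._jit_get_schemas_for_operator("aten::tanh_backward"):
    print(schema)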