mirror of https://github.com/llvm/torch-mlir
Added support for native_dropout_backward (#892)
parent
b7082a8d4e
commit
650f5a5008
|
@ -7574,6 +7574,31 @@ def Torch_AtenNativeBatchNormBackwardOp : Torch_Op<"aten.native_batch_norm_backw
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenNativeDropoutBackwardOp : Torch_Op<"aten.native_dropout_backward", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$grad_output,
|
||||
AnyTorchTensorType:$mask,
|
||||
Torch_FloatType:$scale
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenNativeDropoutBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenNativeDropoutBackwardOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -540,6 +540,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::embedding_dense_backward : (Tensor, Tensor, int, int, bool) -> (Tensor)")
|
||||
emit("aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)")
|
||||
|
||||
# ==========================================================================
|
||||
# `prim::` namespace.
|
||||
|
|
Loading…
Reference in New Issue