Added support for native_dropout (#891)

pull/899/head
Henry Tu 2022-06-03 14:05:57 -04:00 committed by GitHub
parent a635fd2287
commit b7082a8d4e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 0 deletions

View File

@ -5644,6 +5644,32 @@ def Torch_AtenDropout_Op : Torch_Op<"aten.dropout_", [
}]; }];
} }
def Torch_AtenNativeDropoutOp : Torch_Op<"aten.native_dropout", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
Torch_FloatType:$p,
AnyTorchOptionalBoolType:$train
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNativeDropoutOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 2);
}
void AtenNativeDropoutOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 2);
}
}];
}
def Torch_AtenTOp : Torch_Op<"aten.t", [ def Torch_AtenTOp : Torch_Op<"aten.t", [
AllowsTypeRefinement, AllowsTypeRefinement,
ReadOnly ReadOnly

View File

@ -446,6 +446,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True) emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True) emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)") emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
emit("aten::t : (Tensor) -> (Tensor)") emit("aten::t : (Tensor) -> (Tensor)")
emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")