From b7082a8d4ec1168f05271b6f16f592932e14c640 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Fri, 3 Jun 2022 14:05:57 -0400 Subject: [PATCH] Added support for native_dropout (#891) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 26 +++++++++++++++++++ .../jit_ir/build_tools/torch_ods_gen.py | 1 + 2 files changed, 27 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index d2027cbd0..69009dd6f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -5644,6 +5644,32 @@ def Torch_AtenDropout_Op : Torch_Op<"aten.dropout_", [ }]; } +def Torch_AtenNativeDropoutOp : Torch_Op<"aten.native_dropout", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + Torch_FloatType:$p, + AnyTorchOptionalBoolType:$train + ); + let results = (outs + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNativeDropoutOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 2); + } + void AtenNativeDropoutOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 2); + } + }]; +} + def Torch_AtenTOp : Torch_Op<"aten.t", [ AllowsTypeRefinement, ReadOnly diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 5dea74bea..7f6199f85 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -446,6 +446,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True) emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True) emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)") + emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)") emit("aten::t : (Tensor) -> (Tensor)") emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")