From 539511c19b1acef704d8399f814e3a847cc6a719 Mon Sep 17 00:00:00 2001
From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
Date: Mon, 29 Nov 2021 12:30:03 -0600
Subject: [PATCH] Add dropout op (#436)

Co-authored-by: dan
---
 e2e_testing/torchscript/basic.py              | 20 +++++++++++++
 .../Dialect/Torch/IR/GeneratedAtenOps.td      | 16 ++++++++++
 .../TorchToLinalg/TorchToLinalg.cpp           | 29 +++++++++++++++++++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  2 +-
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 5 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index 63fedc348..6d5034d8d 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -738,3 +738,23 @@ class AddCDivModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: AddCDivModule())
 def AddCDivModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1,3), tu.rand(1,3), tu.rand(1,3))
+
+
+# ==============================================================================
+
+class DropoutModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.dropout(x, 0.0, False)
+
+
+@register_test_case(module_factory=lambda: DropoutModule())
+def DropoutModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index 2f776353c..c322f0ae8 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -2199,6 +2199,22 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
   let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
 }
 
+def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::dropout : (Tensor, float, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    Torch_FloatType:$p,
+    Torch_BoolType:$train
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$input `,` $p `,` $train attr-dict `:` type($input) `,` type($p) `,` type($train) `->` type($result)";
+}
+
 def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 93f0dfa55..018b95581 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1132,6 +1132,33 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenDropoutOp : public OpConversionPattern<AtenDropoutOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenDropoutOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    bool train;
+    if (!matchPattern(op.train(), m_TorchConstantBool(&train)))
+      return rewriter.notifyMatchFailure(op,
+                                         "Expected train to be constant bool.");
+
+    if (train)
+      return failure();
+    auto resultType = getTypeConverter()
+                          ->convertType(op->getResult(0).getType())
+                          .cast<RankedTensorType>();
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
+                                                adaptor.input());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // See comments at in convertMmOp and the heading for this section for general
 // considerations. This function needs to be auto-generated.
@@ -3035,6 +3062,8 @@ public:
     patterns.add(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
+    target.addIllegalOp<AtenDropoutOp>();
+    patterns.add<ConvertAtenDropoutOp>(typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 0697a174b..746cf112f 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -231,7 +231,7 @@ public:
             AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp,
             AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
             AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
-            AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
+            AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
             AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp>(
             op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 37d7f4d68..e446cc14a 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -569,6 +569,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::IntImplicit : (Tensor) -> (int)")
     emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
     emit("aten::Int.Tensor : (Tensor) -> (int)")
+    emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
 
     # Dict ops.
     emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)
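
Note on the lowering (not part of the patch): aten::dropout with train=False is a
no-op in PyTorch, which is why the inference-only ConvertAtenDropoutOp pattern can
simply forward adaptor.input() to the refined result type and reject the training
case. A minimal standalone sketch of that assumption in plain PyTorch, independent
of the e2e harness:

    import torch

    # Eval-mode dropout (train=False) returns the input unchanged, so the
    # lowering may legally treat aten.dropout as an identity in this case.
    x = torch.rand(3, 4)
    assert torch.equal(torch.dropout(x, 0.0, False), x)

    # Any nonzero p is still a no-op as long as train is False.
    assert torch.equal(torch.dropout(x, 0.5, False), x)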