diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index cdec586d5..e5e5082d4 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -699,6 +699,9 @@ STABLEHLO_PASS_SET = {
     "NewZerosStaticModuleLayoutStrided_basic",
     "DropoutEvalIntModule_basic",
     "DropoutEvalFloatModule_basic",
+    "DropoutTrainStaticShapeModule_basic",
+    "NativeDropoutEvalFloatModule_basic",
+    "NativeDropoutTrainStaticShapeModule_basic",
     "ContiguousModule_basic",
     "DropoutModule_basic",
     "ViewCollapseModule_basic",
@@ -1258,6 +1261,9 @@ LTC_XFAIL_SET = {
     "BernoulliModule_basic",
     "BernoulliPModule_basic",
     "DropoutTrainModule_basic",
+    "DropoutTrainStaticShapeModule_basic",
+    "NativeDropoutTrainModule_basic",
+    "NativeDropoutTrainStaticShapeModule_basic",
     "StdCorrectionKeepDimModule_basic",
     "StdCorrectionNoneModule_basic",
     "VarBiasedModule_basic",
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index bc478a5dd..14590d6a1 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6437,6 +6437,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.native_dropout\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<bool>) -> !torch.tuple<list<int>, list<int>> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
+"    return %1 : !torch.tuple<list<int>, list<int>>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.gelu\"(%arg0: !torch.list<int>, %arg1: !torch.str) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -8244,6 +8249,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.native_dropout\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.optional<bool>) -> !torch.tuple<int, int> {\n"
+"    %int11 = torch.constant.int 11\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.TupleConstruct %0#1, %int11 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
+"    return %1 : !torch.tuple<int, int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.expand_as\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 150a1f976..6033a330f 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2128,6 +2128,58 @@ public:
     return success();
   }
 };
+
+class DecomposeAtenNativeDropoutOp
+    : public OpRewritePattern<AtenNativeDropoutOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNativeDropoutOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op->getContext();
+    Value input = op.getInput();
+    Value prob = op.getP();
+    bool train = false;
+    if (!op.getTrain().getType().isa<Torch::NoneType>()) {
+      if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) {
+        return rewriter.notifyMatchFailure(
+            op, "train must be a boolean constant or none");
+      }
+    }
+    Value noneVal = rewriter.create<ConstantNoneOp>(loc);
+    if (!train) {
+      Value i1Type =
+          getDtypeIntValueForType(rewriter, loc, IntegerType::get(context, 1));
+      Value inputSize = rewriter.create<AtenSizeOp>(
+          loc, Torch::ListType::get(Torch::IntType::get(context)), input);
+      Value trueValue = rewriter.create<ConstantIntOp>(loc, 1);
+      Value trueMask = rewriter.create<AtenFullOp>(
+          loc, op->getResultTypes()[1], inputSize, trueValue, i1Type,
+          /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);
+      rewriter.replaceOp(op, ArrayRef<Value>{input, trueMask});
+      return success();
+    }
+    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "only support floating type input for training mode");
+    }
+    Value floatOne =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value oneMinusP = rewriter.create<AtenSubFloatOp>(loc, floatOne, prob);
+    Value boolMask = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
+        loc, inputType, input, oneMinusP, /*generator=*/noneVal);
+    Value maskedInput =
+        rewriter.create<AtenMulTensorOp>(loc, inputType, boolMask, input);
+    Value output = rewriter.create<AtenDivScalarOp>(
+        loc, op->getResultTypes()[0], maskedInput, oneMinusP);
+    rewriter.replaceOp(
+        op, ArrayRef<Value>{
+                output, convertTensorToDtype(rewriter, loc, boolMask,
+                                             IntegerType::get(context, 1))});
+    return success();
+  }
+};
 } // namespace

 // Decompose aten.var into: aten.var.dim op.
@@ -4654,6 +4706,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNativeDropoutOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 7ec4594eb..7fa9c26f8 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -440,6 +440,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenNativeDropoutOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index 216000970..232ef262d 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -203,6 +203,10 @@ def aten〇type_as〡shape(self: List[int], other: List[int]) -> List[int]:
 def aten〇dropout〡shape(input: List[int], p: float, train: bool) -> List[int]:
     return upstream_shape_functions.unary(input)

+def aten〇native_dropout〡shape(input: List[int], p: float, train: Optional[bool]) -> Tuple[List[int], List[int]]:
+    shape = upstream_shape_functions.unary(input)
+    return shape, shape
+
 def aten〇gelu〡shape(self: List[int], approximate: str = "none") -> List[int]:
     return upstream_shape_functions.unary(self)

@@ -1458,6 +1462,11 @@ def aten〇dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: b
     input_rank, input_dtype = input_rank_dtype
     return input_dtype

+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, p=0.5, train=False))
+def aten〇native_dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: Optional[bool]) -> Tuple[int, int]:
+    input_rank, input_dtype = input_rank_dtype
+    return input_dtype, torch.bool
+
 @check_dtype_function(_check_two_tensor_op())
 def aten〇expand_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 563dc6417..5e1c05f97 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1786,6 +1786,94 @@ class DropoutTrainModule(torch.nn.Module):
 def DropoutTrainModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1024, 1536))

+# ==============================================================================
+
+
+class DropoutTrainStaticShapeModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1024, 1536], torch.float32, True),
+    ])
+    def forward(self, x):
+        res = torch.dropout(x, 0.3, train=True)
+        return torch.mean(res), torch.std(res)
+
+
+@register_test_case(module_factory=lambda: DropoutTrainStaticShapeModule())
+def DropoutTrainStaticShapeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1024, 1536))
+
+# ==============================================================================
+
+
+class NativeDropoutEvalFloatModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.native_dropout(x, 0.1, train=False)
+
+
+@register_test_case(module_factory=lambda: NativeDropoutEvalFloatModule())
+def NativeDropoutEvalFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+# ==============================================================================
+
+
+class NativeDropoutTrainModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        res = torch.native_dropout(x, 0.3, train=True)
+        return torch.mean(res[0]), torch.std(res[0]), torch.mean(res[1].to(torch.float32)), torch.std(res[1].to(torch.float32))
+
+
+@register_test_case(module_factory=lambda: NativeDropoutTrainModule())
+def NativeDropoutTrainModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1024, 1536))
+
+
+# ==============================================================================
+
+
+class NativeDropoutTrainStaticShapeModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1024, 1536], torch.float32, True),
+    ])
+    def forward(self, x):
+        res = torch.native_dropout(x, 0.3, train=True)
+        return torch.mean(res[0]), torch.std(res[0]), torch.mean(res[1].to(torch.float32)), torch.std(res[1].to(torch.float32))
+
+
+@register_test_case(module_factory=lambda: NativeDropoutTrainStaticShapeModule())
+def NativeDropoutTrainStaticShapeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1024, 1536))

 # ==============================================================================
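The decomposition above implements standard inverted dropout for aten.native_dropout: in training mode the elements kept by the Bernoulli mask are scaled by 1/(1 - p) and the boolean keep-mask is returned as the second result; in eval mode the input passes through unchanged and the mask is all true. A minimal eager-PyTorch sketch of the semantics the new e2e tests check (illustrative only, assuming the usual CPU eager behavior of torch.native_dropout; not part of the patch):

    import torch

    x = torch.rand(1024, 1536)

    # Training mode: the second result is the boolean keep-mask and the kept
    # elements are scaled by 1 / (1 - p) (inverted dropout).
    out, mask = torch.native_dropout(x, 0.3, train=True)
    assert mask.dtype == torch.bool
    assert torch.allclose(out, x * mask / (1 - 0.3))

    # Eval mode: the input passes through unchanged and the mask is all True.
    out_eval, mask_eval = torch.native_dropout(x, 0.1, train=False)
    assert torch.equal(out_eval, x)
    assert bool(mask_eval.all())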