diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 38a88fe16..a558db372 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7655,6 +7655,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.normal_functional\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.arange.start_step\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
 "    %0 = torch.derefine %arg0 : !torch.float to !torch.union<float, int>\n"
 "    %1 = torch.derefine %arg1 : !torch.float to !torch.union<float, int>\n"
@@ -11557,6 +11560,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
 "    %int6 = torch.constant.int 6\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 9c4776231..8afccbba0 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3669,9 +3669,38 @@ public:
     return success();
   }
 };
-} // namespace
 
-namespace {
+// aten.normal_functional(mean, sigma) = randn() * sigma + mean.
+class DecomposeAtenNormalFunctionalOp
+    : public OpRewritePattern<AtenNormalFunctionalOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNormalFunctionalOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.getGenerator().getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "The generator has to be None because only global default "
+              "generator is supported");
+
+    Location loc = op.getLoc();
+    Type resultType = op.getType();
+    Value std = op.getStd();
+    Value mean = op.getMean();
+
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    Value one =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value randN = rewriter.create<AtenRandnLikeOp>(
+        loc, resultType, op.getSelf(), /*dtype=*/none, /*layout=*/none,
+        /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
+    Value stdRandN =
+        rewriter.create<AtenMulScalarOp>(loc, resultType, randN, std);
+    rewriter.replaceOpWithNewOp<AtenAddScalarOp>(op, resultType, stdRandN,
+                                                 mean, /*alpha=*/one);
+    return success();
+  }
+};
+
 template <typename OpTy, typename T1T2Op>
 class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -6591,6 +6620,7 @@ public:
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNormalFunctionalOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index e76adb9b8..da7811ad0 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -494,6 +494,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenRandnOp>();
   target.addIllegalOp<AtenRandnGeneratorOp>();
   target.addIllegalOp<AtenRandnLikeOp>();
+  target.addIllegalOp<AtenNormalFunctionalOp>();
   target.addIllegalOp<AtenVarMeanOp>();
   target.addIllegalOp<AtenNewEmptyStridedOp>();
   target.addIllegalOp<AtenEmptyStridedOp>();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 8a440c16b..7ba4c309c 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1484,6 +1484,7 @@ LTC_XFAIL_SET = {
     "VarMeanUnbiasedModule_basic",
     "RandnLikeModule_basic",
     "RandnLikeDtypeModule_basic",
+    "NormalFunctionalModule_basic",
     "BernoulliFloatModule_basic",
     "BernoulliModule_basic",
     "BernoulliPModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 640c0bbfd..bf2f45e37 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -902,6 +902,9 @@ def aten〇randn〡shape(size: List[int], dtype: Optional[int] = None, layout: O
 def aten〇randn〇generator〡shape(size: List[int], generator: Any, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
     return size
 
+def aten〇normal_functional〡shape(self: List[int], mean: float = 0., std: float = 1., generator: Any = None) -> List[int]:
+    return self
+
 def aten〇arange〇start_step〡shape(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
     return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory)
 
@@ -3822,6 +3825,16 @@ def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: O
     assert not is_integer_dtype(dtype)
     return dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(
+    num_of_tensors=1,
+    error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
+def aten〇normal_functional〡dtype(self_rank_dtype: Tuple[int, int], mean: float = 0., std: float = 1., generator: Any = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    if self_dtype is None:
+        return torch.float32
+    assert not is_integer_dtype(self_dtype)
+    return self_dtype
+
 @check_dtype_function([Invocation(size=[1], generator=None),
                        Invocation(size=[1], generator=None, dtype=torch.float32),
                        ErrorInvocation(size=[1], generator=None, dtype=torch.int32),
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
index dedd2b398..2b8e186ff 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
@@ -605,3 +605,24 @@ class RandnLikeDtypeModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: RandnLikeDtypeModule())
 def RandnLikeDtypeModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(256, 1024).double())
+# ==============================================================================
+
+class NormalFunctionalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float64, True),
+    ])
+    def forward(self, x):
+        a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0)
+        mean = torch.mean(a)
+        std = torch.std(a)
+        return mean, std
+
+
+@register_test_case(module_factory=lambda: NormalFunctionalModule())
+def NormalFunctionalModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2048, 4096).double())
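
Note: the rewrite implemented by DecomposeAtenNormalFunctionalOp is the usual affine reparameterization of a normal sample: if Z ~ N(0, 1), then std * Z + mean ~ N(mean, std^2). A short PyTorch sketch of the same identity (illustrative only, not part of the patch):

    import torch

    # The decomposition rewrites aten.normal_functional(self, mean, std)
    # as aten.randn_like(self) * std + mean.
    x = torch.empty(2048, 4096, dtype=torch.float64)
    mean, std = -5.0, 2.0
    decomposed = torch.randn_like(x) * std + mean

    # As in NormalFunctionalModule_basic above, a nondeterministic op is
    # checked through its sample moments rather than its element values.
    print(decomposed.mean())  # close to -5.0
    print(decomposed.std())   # close to 2.0

This is also why the e2e test returns torch.mean(a) and torch.std(a) instead of the tensor itself: across RNG draws, only the sample statistics are stable enough to compare against the eager-mode reference.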