diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
index e861a1877..724430401 100644
--- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
@@ -226,7 +226,8 @@ public:
         typeConverter->convertType(op.getType()).cast<RankedTensorType>();
     Type resultElementType;
     if (op.getDtype().getType().isa<Torch::NoneType>()) {
-      resultElementType = resultType.getElementType();
+      resultElementType = getDefaultDtypeForTorchScalar(
+          Torch::FloatType::get(op->getContext()));
     } else {
       int64_t dtypeInt;
       if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index e9553cd07..040cc5f4e 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3923,8 +3923,14 @@
 public:
   LogicalResult matchAndRewrite(AtenRandnGeneratorOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Type resultType = op.getType();
+    auto resultType = op.getType().cast<BaseTensorType>();
+    if (!resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+
+    Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
     Value none = rewriter.create<ConstantNoneOp>(loc);
     Value low = rewriter.create<ConstantFloatOp>(
         loc, rewriter.getF64FloatAttr((double)0.0));
@@ -3936,11 +3942,13 @@
         loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159)));
 
     Value emptyTensorA = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, resultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
+        loc, resultType, op.getSize(), /*dtype=*/dtype,
+        /*layout=*/op.getLayout(),
         /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
         /*memory_format=*/none);
 
     Value emptyTensorB = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, resultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
+        loc, resultType, op.getSize(), /*dtype=*/dtype,
+        /*layout=*/op.getLayout(),
         /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
         /*memory_format=*/none);
diff --git a/python/torch_mlir_e2e_test/test_suite/rng.py b/python/torch_mlir_e2e_test/test_suite/rng.py
index c292d2438..22076e031 100644
--- a/python/torch_mlir_e2e_test/test_suite/rng.py
+++ b/python/torch_mlir_e2e_test/test_suite/rng.py
@@ -459,6 +459,30 @@ def RandnGeneratorModule_basic(module, tu: TestUtils):
 
 
 # ==============================================================================
+
+class RandnGeneratorF64Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        a = torch.ops.aten.randn([4, 512, 1024], generator=None, dtype=torch.float64)
+        std = torch.std(a)
+        return std
+
+
+@register_test_case(module_factory=lambda: RandnGeneratorF64Module())
+def RandnGeneratorF64Module_basic(module, tu: TestUtils):
+    module.forward()
+
+
+# ==============================================================================
+
+
 class RandnLikeModule(torch.nn.Module):
     def __init__(self):
         super().__init__()