diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp
index 3b18844df..6519a2723 100644
--- a/lib/Conversion/TorchToLinalg/Random.cpp
+++ b/lib/Conversion/TorchToLinalg/Random.cpp
@@ -129,6 +129,7 @@ public:
     Value generator = adaptor.getGenerator();
     RankedTensorType resultType = self.getType().cast<RankedTensorType>();
     Type elemTy = resultType.getElementType();
+    Type f64Ty = rewriter.getF64Type();

     if (!isa<mlir::FloatType>(elemTy))
       return rewriter.notifyMatchFailure(op, "This op only support float type");
@@ -139,8 +140,8 @@ public:
                                          "generator is supported");
     // Get key, min and max used by `linalg.generic` compute payload.
     Value key = rewriter.create<TorchConversion::GetNextSeedOp>(loc);
-    Value min = convertScalarToDtype(rewriter, loc, from, elemTy);
-    Value max = convertScalarToDtype(rewriter, loc, to, elemTy);
+    Value min = convertScalarToDtype(rewriter, loc, from, f64Ty);
+    Value max = convertScalarToDtype(rewriter, loc, to, f64Ty);

     // Construct the `linalg.generic` op.
     auto resultRank = resultType.getRank();
@@ -179,11 +180,14 @@
               // res = cast(F64, tempN) * scale + min
               Value updateFloat =
-                  b.create<arith::UIToFPOp>(loc, elemTy, randomVal);
+                  b.create<arith::UIToFPOp>(loc, f64Ty, randomVal);
               Value updateScaled =
                   b.create<arith::MulFOp>(loc, updateFloat, scale);
               Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
-              b.create<linalg::YieldOp>(loc, res);
+              Value truncRes = res;
+              if (elemTy.isa<mlir::Float16Type, mlir::Float32Type>())
+                truncRes = b.create<arith::TruncFOp>(loc, elemTy, res);
+              b.create<linalg::YieldOp>(loc, truncRes);
             })
         .getResult(0);