Fix NaN issue for fp16 torch.randn/randn_like in ConvertAtenUniformOp (#3184)

For ops that use ConvertAtenUniformOp (e.g. torch.randn/randn_like),
the fp16 datatype returns NaN values. Trying to lower [this
repro](https://gist.github.com/aviator19941/1c65e658241dea6906ca423f9abaee69)
results in NaNs; this PR fixes the issue by performing the uniform
computation in f64 and truncating to the element type at the end.
pull/3222/head
Avinash Sharma 2024-04-23 23:58:08 -07:00 committed by GitHub
parent fab2696489
commit 678c03b762
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 8 additions and 4 deletions

View File

@ -129,6 +129,7 @@ public:
Value generator = adaptor.getGenerator();
RankedTensorType resultType = self.getType().cast<RankedTensorType>();
Type elemTy = resultType.getElementType();
Type f64Ty = rewriter.getF64Type();
if (!isa<mlir::FloatType>(elemTy))
return rewriter.notifyMatchFailure(op, "This op only support float type");
@ -139,8 +140,8 @@ public:
"generator is supported");
// Get key, min and max used by `linalg.generic` compute payload.
Value key = rewriter.create<TorchConversion::GetNextSeedOp>(loc);
Value min = convertScalarToDtype(rewriter, loc, from, elemTy);
Value max = convertScalarToDtype(rewriter, loc, to, elemTy);
Value min = convertScalarToDtype(rewriter, loc, from, f64Ty);
Value max = convertScalarToDtype(rewriter, loc, to, f64Ty);
// Construct the `linalg.generic` op.
auto resultRank = resultType.getRank();
@ -179,11 +180,14 @@ public:
// res = cast(F64, tempN) * scale + min
Value updateFloat =
b.create<arith::UIToFPOp>(loc, elemTy, randomVal);
b.create<arith::UIToFPOp>(loc, f64Ty, randomVal);
Value updateScaled =
b.create<arith::MulFOp>(loc, updateFloat, scale);
Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
b.create<linalg::YieldOp>(loc, res);
Value truncRes = res;
if (elemTy.isa<Float16Type, Float32Type>())
truncRes = b.create<arith::TruncFOp>(loc, elemTy, res);
b.create<linalg::YieldOp>(loc, truncRes);
})
.getResult(0);