mirror of https://github.com/llvm/torch-mlir
Fix nan issue for fp16 torch.randn/randn_like in ConvertAtenUniformOp (#3184)
For ops that use ConvertAtenUniformOp (e.g. torch.randn/randn_like), fp16 datatype returns nan values. Trying to lower [this repro](https://gist.github.com/aviator19941/1c65e658241dea6906ca423f9abaee69) will result in nan's, this PR fixes the issue.pull/3222/head
parent
fab2696489
commit
678c03b762
|
@ -129,6 +129,7 @@ public:
|
|||
Value generator = adaptor.getGenerator();
|
||||
RankedTensorType resultType = self.getType().cast<RankedTensorType>();
|
||||
Type elemTy = resultType.getElementType();
|
||||
Type f64Ty = rewriter.getF64Type();
|
||||
|
||||
if (!isa<mlir::FloatType>(elemTy))
|
||||
return rewriter.notifyMatchFailure(op, "This op only support float type");
|
||||
|
@ -139,8 +140,8 @@ public:
|
|||
"generator is supported");
|
||||
// Get key, min and max used by `linalg.generic` compute payload.
|
||||
Value key = rewriter.create<TorchConversion::GetNextSeedOp>(loc);
|
||||
Value min = convertScalarToDtype(rewriter, loc, from, elemTy);
|
||||
Value max = convertScalarToDtype(rewriter, loc, to, elemTy);
|
||||
Value min = convertScalarToDtype(rewriter, loc, from, f64Ty);
|
||||
Value max = convertScalarToDtype(rewriter, loc, to, f64Ty);
|
||||
|
||||
// Construct the `linalg.generic` op.
|
||||
auto resultRank = resultType.getRank();
|
||||
|
@ -179,11 +180,14 @@ public:
|
|||
|
||||
// res = cast(F64, tempN) * scale + min
|
||||
Value updateFloat =
|
||||
b.create<arith::UIToFPOp>(loc, elemTy, randomVal);
|
||||
b.create<arith::UIToFPOp>(loc, f64Ty, randomVal);
|
||||
Value updateScaled =
|
||||
b.create<arith::MulFOp>(loc, updateFloat, scale);
|
||||
Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
|
||||
b.create<linalg::YieldOp>(loc, res);
|
||||
Value truncRes = res;
|
||||
if (elemTy.isa<Float16Type, Float32Type>())
|
||||
truncRes = b.create<arith::TruncFOp>(loc, elemTy, res);
|
||||
b.create<linalg::YieldOp>(loc, truncRes);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
|
|
Loading…
Reference in New Issue