[RefineTypes] Add Float16Type dtype knowledge support for trivial ops

-- This commit adds Float16Type dtype knowledge support for trivial ops.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
tanyo/fix_upstream
Abhishek Varma 2022-11-30 16:17:51 +00:00 committed by Vivek Khandelwal
parent 0983a7f93a
commit 47f67853ac
1 changed files with 2 additions and 2 deletions

View File

@ -706,7 +706,7 @@ void TypeAnalysis::visitOperation(Operation *op,
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}
// Dtype is always float32, except for bfloat16, float64 and nullptr.
// Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
if (isa<AtenTanhOp, AtenExpOp, AtenExpm1Op, AtenSinOp, AtenCosOp,
AtenSigmoidOp, AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op,
AtenLog1pOp, AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp>(op)) {
@ -715,7 +715,7 @@ void TypeAnalysis::visitOperation(Operation *op,
Type dtype = operands[0]->getValue().dtype;
if (dtype) {
knowledge.dtype = Float32Type::get(op->getContext());
if (dtype.isa<BFloat16Type, Float64Type>())
if (dtype.isa<BFloat16Type, Float16Type, Float64Type>())
knowledge.dtype = dtype;
}
incorporateKnowledge(op->getResult(0), knowledge);