[RefineTypes] Add Float16Type dtype knowledge support for trivial ops

-- This commit adds Float16Type dtype knowledge support for trivial ops.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
tanyo/fix_upstream
Abhishek Varma 2022-11-30 16:17:51 +00:00 committed by Vivek Khandelwal
parent 0983a7f93a
commit 47f67853ac
1 changed files with 2 additions and 2 deletions

View File

@ -706,7 +706,7 @@ void TypeAnalysis::visitOperation(Operation *op,
return incorporateKnowledge(op->getResult(0), operands[0]->getValue()); return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
} }
// Dtype is always float32, except for bfloat16, float64 and nullptr. // Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
if (isa<AtenTanhOp, AtenExpOp, AtenExpm1Op, AtenSinOp, AtenCosOp, if (isa<AtenTanhOp, AtenExpOp, AtenExpm1Op, AtenSinOp, AtenCosOp,
AtenSigmoidOp, AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenSigmoidOp, AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op,
AtenLog1pOp, AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp>(op)) { AtenLog1pOp, AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp>(op)) {
@ -715,7 +715,7 @@ void TypeAnalysis::visitOperation(Operation *op,
Type dtype = operands[0]->getValue().dtype; Type dtype = operands[0]->getValue().dtype;
if (dtype) { if (dtype) {
knowledge.dtype = Float32Type::get(op->getContext()); knowledge.dtype = Float32Type::get(op->getContext());
if (dtype.isa<BFloat16Type, Float64Type>()) if (dtype.isa<BFloat16Type, Float16Type, Float64Type>())
knowledge.dtype = dtype; knowledge.dtype = dtype;
} }
incorporateKnowledge(op->getResult(0), knowledge); incorporateKnowledge(op->getResult(0), knowledge);