From 47f67853ac0022389154911c64c3319a394b2969 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <!-- author email stripped by extraction; unrecoverable -->
Date: Wed, 30 Nov 2022 16:17:51 +0000
Subject: [PATCH] [RefineTypes] Add Float16Type dtype knowledge support for
 trivial ops

-- This commit adds Float16Type dtype knowledge support for trivial ops.

Signed-off-by: Abhishek Varma <!-- signoff email stripped by extraction; unrecoverable -->
---
 lib/Dialect/Torch/Transforms/RefineTypes.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 96e9e78a2..cbe1d5b48 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -706,7 +706,7 @@ void TypeAnalysis::visitOperation(Operation *op,
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
-  // Dtype is always float32, except for bfloat16, float64 and nullptr.
+  // Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
   if (isa</* trivial-op type list stripped by extraction — recover from upstream commit 47f67853 */>(op)) {
@@ -715,7 +715,7 @@ void TypeAnalysis::visitOperation(Operation *op,
     Type dtype = operands[0]->getValue().dtype;
     if (dtype) {
       knowledge.dtype = Float32Type::get(op->getContext());
-      if (dtype.isa<BFloat16Type, Float64Type>())
+      if (dtype.isa<BFloat16Type, Float16Type, Float64Type>())
        knowledge.dtype = dtype;
     }
     incorporateKnowledge(op->getResult(0), knowledge);